TensorRT INT8量化代码
1 简介
在之前的文章中7-TensorRT中的INT8
介绍了TensorRT的量化理论基础,这里就根据理论实现相关的代码
2 PTQ
2.1 trtexec
int8量化 使用trtexec
参数--int8
来生成对应的--int8
的engine
,但是精度损失会比较大。也可使用int8和fp16混合精度同时使用--fp16 --int8
1 | trtexec --onnx=XX.onnx --saveEngine=model.plan --int8 |
上面的方式没有使用量化文件进行校正,因此精度损失非常多。可以使用trtexec
添加--calib=
参数指定calibration file
文件进行校准。如下面的指令
这个是我目前采用的方法
1 | trtexec --onnx=test.onnx --verbose --dumpLayerInfo --dumpProfile --int8 --calib=calibratorfile_test.txt --saveEngine=test.onnx.INT8.trtmodel --exportProfile=build.log |
上面指定的--calib=calibratorfile_test.txt
是需要我们自己生成的,使用我们自己的数据集,在build engine阶段,tensorrt会根据数据分布和我们指定的校准器来生成的。如在7-TensorRT中的INT8
中介绍的熵校准Entropy calibration
使用的校准器就是IInt8EntropyCalibrator2
。
将ONNX转换为INT8的TensorRT引擎,需要:
- 准备一个校准集,用于在转换过程中寻找使得转换后的激活值分布与原来的FP32类型的激活值分布差异最小的阈值;
- 并写一个校准器类,该类需继承trt.IInt8EntropyCalibrator2父类,并重写get_batch_size, get_batch, read_calibration_cache, write_calibration_cache这几个方法。可直接使用
myCalibrator.py
,需传入图片文件夹地址
下面是具体的实现代码
2.1.1 python版本代码生成calibratorfile和int8engine
参考:https://github.com/aiLiwensong/Pytorch2TensorRT
官方代码在/samples/python/int8_caffe_mnist
下面回答来自GPT
在 TensorRT 中,trt.Builder
是用于构建 TensorRT 引擎的主要类。当你使用 trt.Builder
来构建一个 TensorRT 引擎时,通过调用 builder.build_engine
函数来生成引擎。在执行这个函数时,如果你使用了 int8 量化,TensorRT 将会在量化过程中使用校准数据,并在量化过程中生成校准数据文件。
具体而言,当你调用 builder.build_engine
函数时,TensorRT 会根据设置执行 int8 量化过程,包括使用指定的校准方法和数据集来生成校准数据。在量化过程中,TensorRT 将会不断地调用 IInt8Calibrator::getBatch
函数来获取数据样本,并根据这些样本来执行量化。在执行量化过程时,TensorRT 会在合适的时机将校准数据写入到校准数据文件中。
main.py
1 | # main.py |
对应的myCalibrator.py
1 | #myCalibrator.py |
trt_convertor.py
如下
1 | # trt_convertor.py |
2.1.2 c++代码版本生成calibratorfile和int8engine
参考代码
- Miacis
- tensorrt_starter
- https://www.cnblogs.com/TaipKang/p/15542329.html 这篇博客中也有全套的代码供参考,但是对应的tensorrt版本比较老。我现在是8.4.有些函数对应不上,仅供参考流程。
2.1.2.1 简单版本 只有IInt8EntropyCalibrator2
下面代码来自 tensorrt_starter
强烈推荐这个 tensorrt_starter 里面有很多教程。这个例子包含了整个engine的生成和推理过程。
1 |
|
src/model/calibrator.hpp代码如下
1 |
|
对应的头文件如下
1 | //trt_calibrator.hpp |
2.1.2.2 完相对完善的量化器方案
上面的代码之实现了一个量化的方案。
下面的实现了三种,也很有参考价值,参考网址
1 | //calibrator.hpp |
对应的调用代码如下
其中calibration_image_list_file
是一个txt文件,里面包含了量化需要的文件路径文件内容类似
1 | /mnt/data/dataset/ImageNet/ILSVRC2012/test/ILSVRC2012_test_00000001.JPEG |
1 | int max_batch_size = 1; |
2.2 python onnx转trt
- 操作流程:按照常规方案导出onnx,onnx序列化为tensorrt engine之前打开int8量化模式并采用校正数据集进行校正;
- 优点:
- 导出onnx之前的所有操作都为常规操作;
- 相比在pytorch中进行PTQ int8量化,所需显存小;
- 缺点:
- 量化过程为黑盒子,无法看到中间过程;
- 校正过程需在实际运行的tensorrt版本中进行并保存tensorrt engine;
- 量化过程中发现,即使模型为动态输入,校正数据集使用时也必须与推理时的输入shape[N, C, H, W]完全一致,否则,效果非常非常差,动态模型慎用。
- 操作示例参看onnx2trt_ptq.py
下面的代码其实就是上面的2.1.1
1 | import tensorrt as trt |
2.3 polygraphy工具
- 操作流程:按照常规方案导出onnx,onnx序列化为tensorrt engine之前打开int8量化模式并采用校正数据集进行校正;
- 优点:1. 相较于1.1,代码量更少,只需完成校正数据的处理代码;
- 缺点:1
- 同上所有;
- 动态尺寸时,校正数据需与–trt-opt-shapes相同
- 内部默认最多校正20个epoch;
- 安装polygraphy
1 | pip install colored polygraphy --extra-index-url https://pypi.ngc.nvidia.com |
- 量化
1 | polygraphy convert XX.onnx --int8 --data-loader-script loader_data.py --calibration-cache XX.cache -o XX.plan --trt-min-shapes images:[1,3,384,1280] --trt-opt-shapes images:[26,3,384,1280] --trt-max-shapes images:[26,3,384,1280] #量化 |
- loader_data.py为较正数据集加载过程,自动调用脚本中的load_data()函数:
2.4 pytorch中执行(推荐)
实际上是使用pytorch-quantization
PyTorch 中进行量化(Quantization)的库,支持PTQ和QAT的量化。这里给出就是PTQ的例子
注:在pytorch中执行导出的onnx将产生一个明确量化的模型,属于显示量化
操作流程:安装pytorch_quantization库->加载校正数据->加载模型(在加载模型之前,启用quant_modules.initialize() 以保证原始模型层替换为量化层)->校正->导出onnx;
优点:
- 通过导出的onnx能够看到每层量化的过程;
- onnx导出为tensort engine时可以采用trtexec(注:命令行需加–int8,需要fp16和int8混合精度时,再添加–fp16,这里有些疑问,GPT说导出 ONNX 模型时进行了量化,那么在使用
trtexec
转换为 TensorRT Engine 时,你不需要添加任何特别的参数。因为 ONNX 模型中已经包含了量化后的信息,TensorRT 在转换过程中会自动识别并保留这些信息。因此不知道是不是需要--int8
,我感觉不需要了。),比较简单; - pytorch校正过程可在任意设备中进行;
- 相较上述方法,校正数据集使用shape无需与推理shape一致,也能获得较好的结果,动态输入时,推荐采用此种方式。
缺点:导出onnx时,显存占用非常大;
操作示例参看:pytorch模型进行量化导出yolov5_pytorch_ptq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import models, datasets
import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization import calib
from tqdm import tqdm
print(pytorch_quantization.__version__)
import os
import tensorrt as trt
import numpy as np
import time
import wget
import tarfile
import shutil
import cv2
import random
from models.yolo import Model
from models.experimental import End2End
def compute_amax(model, **kwargs):
# Load calib result
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
if isinstance(module._calibrator, calib.MaxCalibrator):
module.load_calib_amax()
else:
module.load_calib_amax(**kwargs)
model.cuda()
def collect_stats(model, data_loader):
"""Feed data to the network and collect statistics"""
# Enable calibrators
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module.disable_quant()
module.enable_calib()
else:
module.disable()
# Feed data to the network for collecting stats
for i, image in tqdm(enumerate(data_loader)):
model(image.cuda())
# Disable calibrators
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module.enable_quant()
module.disable_calib()
else:
module.enable()
def get_crop_bbox(img, crop_size):
"""Randomly get a crop bounding box."""
margin_h = max(img.shape[0] - crop_size[0], 0)
margin_w = max(img.shape[1] - crop_size[1], 0)
offset_h = np.random.randint(0, margin_h + 1)
offset_w = np.random.randint(0, margin_w + 1)
crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
return crop_x1, crop_y1, crop_x2, crop_y2
def crop(img, crop_bbox):
"""Crop from ``img``"""
crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox
img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
return img
class CaliData(data.Dataset):
def __init__(self, path, num, inputsize=[384, 1280]):
self.img_files = [os.path.join(path, p) for p in os.listdir(path) if p.endswith('jpg')]
random.shuffle(self.img_files)
self.img_files = self.img_files[:num]
self.height = inputsize[0]
self.width = inputsize[1]
def __getitem__(self, index):
f = self.img_files[index]
img = cv2.imread(f) # BGR
crop_size = [self.height, self.width]
crop_bbox = get_crop_bbox(img, crop_size)
# crop the image
img = crop(img, crop_bbox)
img = img.transpose((2, 0, 1))[::-1, :, :] # BHWC to BCHW ,BGR to RGB
img = np.ascontiguousarray(img)
img = img.astype(np.float32) / 255.
return img
def __len__(self):
return len(self.img_files)
if __name__ == '__main__':
pt_file = 'runs/train/exp/weights/best.pt'
calib_path = 'XX/train'
num = 2000 # 用来校正的数目
batchsize = 4
# 准备数据
dataset = CaliData(calib_path, num)
dataloader = data.DataLoader(dataset, batch_size=batchsize)
# 模型加载
quant_modules.initialize() #保证原始模型层替换为量化层
device = torch.device('cuda:0')
ckpt = torch.load(pt_file, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
# QAT
q_model = ckpt['model']
yaml = ckpt['model'].yaml
q_model = Model(yaml, ch=yaml['ch'], nc=yaml['nc']).to(device) # creat
q_model.eval()
q_model = End2End(q_model).cuda()
ckpt = ckpt['model']
modified_state_dict = {}
for key, val in ckpt.state_dict().items():
# Remove 'module.' from the key names
if key.startswith('module'):
modified_state_dict[key[7:]] = val
else:
modified_state_dict[key] = val
q_model.model.load_state_dict(modified_state_dict)
# Calibrate the model using calibration technique.
with torch.no_grad():
collect_stats(q_model, dataloader)
compute_amax(q_model, method="entropy")
# Set static member of TensorQuantizer to use Pytorch’s own fake quantization functions
quant_nn.TensorQuantizer.use_fb_fake_quant = True
# Exporting to ONNX
dummy_input = torch.randn(26, 3, 384, 1280, device='cuda')
input_names = ["images"]
output_names = ["num_dets", 'det_boxes']
# output_names = ['outputs']
save_path = '/'.join(pt_file.split('/')[:-1])
onnx_file = os.path.join(save_path, 'best_ptq.onnx')
dynamic = dict()
dynamic['images'] = {0: 'batch'}
dynamic['num_dets'] = {0: 'batch'}
dynamic['det_boxes'] = {0: 'batch'}
torch.onnx.export(
q_model,
dummy_input,
onnx_file,
verbose=False,
opset_version=13,
do_constant_folding=False,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic)上面的代码,生成的 ONNX 模型是已经量化过的。以下是代码中的量化过程:
- 导入 PyTorch Quantization 库:
- 通过
import pytorch_quantization
以及其他相关模块的导入,使用了 PyTorch Quantization 库中的功能。
- 通过
- 量化模型:
- 在加载模型后,执行了量化操作。在
__main__
中,通过collect_stats
和compute_amax
函数执行了量化统计和计算最大值的操作。这是典型的 QAT(Quantization Aware Training)过程,其中使用校准数据集来估计量化参数。 - 在执行
compute_amax
函数时,传递了method="entropy"
参数,这表明使用的是熵方法来计算量化参数。 - 最后,通过
torch.onnx.export
函数将量化后的模型导出为 ONNX 格式。
- 在加载模型后,执行了量化操作。在
- 导入 PyTorch Quantization 库:
3 QAT
实际上是使用pytorch-quantization
PyTorch 中进行量化(Quantization)的库,支持PTQ和QAT的量化。这里给出就是QAT的例子
注:在pytorch中执行导出的onnx将产生一个明确量化的模型,属于显式量化
- 操作流程:安装pytorch_quantization库->加载训练数据->加载模型(在加载模型之前,启用quant_modules.initialize() 以保证原始模型层替换为量化层)->训练->导出onnx;
- 优点:1. 模型量化参数重新训练,训练较好时,精度下降较少; 2. 通过导出的onnx能够看到每层量化的过程;2. onnx导出为tensort engine时可以采用trtexec(注:命令行需加–int8,需要fp16和int8混合精度时,再添加–fp16),比较简单;3.训练过程可在任意设备中进行;
- 缺点:1.导出onnx时,显存占用非常大;2.最终精度取决于训练好坏;3. QAT训练shape需与推理shape一致才能获得好的推理结果;4. 导出onnx时需采用真实的图片输入作为输入设置
- 操作示例参看yolov5_pytorch_qat.py感知训练,参看export_onnx_qat.py
1 | import torch |
附录:
- tensorrt官方int8量化方法汇总
- 参考代码:python版本 onnx转int8 trtengine https://github.com/aiLiwensong/Pytorch2TensorRT/tree/master
- 参考代码 Miacis
- 参考代码vtensorrt_starter
- 强烈推荐参考代码 tensorrt_starter
- c++实现所有量化方法代码
- 强烈推荐知乎的一个文章,包含PTQ和QAT的流程和python代码实现yolov5_tensorrt_int8
- 参考代码 https://www.cnblogs.com/TaipKang/p/15542329.html
- pytorch-quantization’s documentation