forked from meituan/YOLOv6

Commit 6aa2b25, committed by liqingyuan02 on Jul 14, 2022 (1 parent: ccf8b43).
Showing 7 changed files with 605 additions and 0 deletions.
**tools/partial_quantization/README.md** (file paths in this commit are inferred from the imports and defaults in the code below):

# Partial Quantization

The mAP of YOLOv6s degrades heavily after traditional PTQ, from 42.4% to 35.6%, which is unacceptable. To resolve this issue, we propose **partial quantization**: we first analyze the quantization sensitivity of every layer, then leave the most sensitive layers in full precision as a compromise.

With partial quantization we finally reach 42.1% mAP, only a 0.3% loss in accuracy, while the throughput of the partially quantized model is about 1.56 times that of the FP16 model at a batch size of 32. This method achieves a good tradeoff between accuracy and throughput.
## Prerequisites

```shell
pip install --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com nvidia-pyindex
pip install --extra-index-url=https://pypi.ngc.nvidia.com --trusted-host pypi.ngc.nvidia.com pytorch_quantization
```
## Sensitivity analysis

Use the following command to perform sensitivity analysis. Since 128 images are randomly sampled from the training set on each run, the resulting sensitivity files will differ slightly from run to run.
```shell
python3 sensitivity_analyse.py --weights yolov6s_reopt.pt \
                               --batch-size 32 \
                               --batch-number 4 \
                               --data-root train_data_path
```
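The file written by this step is what `partial_quant.py` (included in this commit, see below) later consumes. Judging from that script, each record is a tuple whose first field is a layer name and whose third field is the accuracy the model retains when that layer is quantized. A minimal sketch of the selection logic, with illustrative layer names and values (the middle field is not used by the selection):

```python
# Sketch of how partial_quant.py consumes the sensitivity file.
# Entries below are illustrative, not real measurements.
quant_sensitivity = [
    ('backbone.ERBlock_5.2.cv2.conv', 8, 0.418),       # barely hurt by quantization
    ('neck.Rep_p4.conv1.conv', 8, 0.412),
    ('neck.upsample0.upsample_transpose', 8, 0.356),   # very sensitive
]
# Highest retained accuracy first: the least sensitive layers lead the list.
quant_sensitivity.sort(key=lambda tup: tup[2], reverse=True)
boundary = 1  # corresponds to --quant-boundary
quantable_ops = [qops[0] for qops in quant_sensitivity[:boundary + 1]]
# Only layers in quantable_ops get quantized; the rest stay in full precision.
```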
## Partial quantization

With the sensitivity file at hand, we then proceed with partial quantization as follows.
```shell
python3 partial_quant.py --weights yolov6s_reopt.pt \
                         --calib-weights yolov6s_reopt_calib.pt \
                         --sensitivity-file yolov6s_reopt_sensivitiy_128_calib.txt \
                         --quant-boundary 55 \
                         --export-batch-size 1
```
## Deployment

Build a TensorRT engine. The `.sim.onnx` suffix suggests the exported model is first simplified (e.g. with onnx-simplifier); that step is not part of this commit.

```shell
trtexec --workspace=1024 --percentile=99 --streams=1 --int8 --fp16 --avgRuns=10 --onnx=yolov6s_reopt_partial_bs1.sim.onnx --saveEngine=yolov6s_reopt_partial_bs1.sim.trt
```
## Performance

| Model | Size | Precision | mAP<sup>val</sup><br/>0.5:0.95 | Speed<sup>T4</sup><br/>trt b1<br/>(fps) | Speed<sup>T4</sup><br/>trt b32<br/>(fps) |
| :---- | :--- | :-------- | :----------------------------- | :-------------------------------------- | :--------------------------------------- |
| **YOLOv6-s-partial**<br/>[bs1](https://github.com/lippman1125/YOLOv6/releases/download/0.1.0/yolov6s_reopt_partial_bs1.sim.onnx)<br/>[bs32](https://github.com/lippman1125/YOLOv6/releases/download/0.1.0/yolov6s_reopt_partial_bs32.sim.onnx) | 640 | INT8 | 42.1 | 503 | 811 |
| **YOLOv6-s** | 640 | FP16 | 42.4 | 373 | 520 |
**tools/partial_quantization/eval.py**:
import os
import torch
from yolov6.core.evaler import Evaler


class EvalerWrapper(object):
    def __init__(self, eval_cfg):
        task = eval_cfg['task']
        save_dir = eval_cfg['save_dir']
        half = eval_cfg['half']
        data = eval_cfg['data']
        batch_size = eval_cfg['batch_size']
        img_size = eval_cfg['img_size']
        device = eval_cfg['device']
        dataloader = None

        Evaler.check_task(task)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # reload thres/device/half/data according to task
        conf_thres, iou_thres = Evaler.reload_thres(conf_thres=0.001, iou_thres=0.65, task=task)
        device = Evaler.reload_device(device, None, task)
        data = Evaler.reload_dataset(data) if isinstance(data, str) else data

        # init
        val = Evaler(data, batch_size, img_size, conf_thres,
                     iou_thres, device, half, save_dir)
        val.stride = eval_cfg['stride']
        dataloader = val.init_data(dataloader, task)

        self.eval_cfg = eval_cfg
        self.half = half
        self.device = device
        self.task = task
        self.val = val
        self.val_loader = dataloader

    def eval(self, model):
        model.eval()
        model.to(self.device)
        if self.half is True:
            model.half()

        with torch.no_grad():
            pred_result = self.val.predict_model(model, self.val_loader, self.task)
            eval_result = self.val.eval_model(pred_result, model, self.val_loader, self.task)

        return eval_result
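A minimal usage sketch, mirroring how `partial_quant.py` below drives this wrapper (file paths are illustrative):

```python
# Usage sketch: evaluate a checkpoint with EvalerWrapper.
from yolov6.utils.events import load_yaml
from yolov6.utils.checkpoint import load_checkpoint
from tools.partial_quantization.eval import EvalerWrapper

evaler = EvalerWrapper(eval_cfg=load_yaml('./eval.yaml'))
model = load_checkpoint('./yolov6s_reopt.pt', map_location='cuda:0',
                        inplace=True, fuse=True)
mAP = evaler.eval(model)
print(mAP)
```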
**tools/partial_quantization/eval.yaml**:
task: 'val'
save_dir: 'runs/val/exp'
half: True
data: '../../data/coco.yaml'
batch_size: 32
img_size: 640
device: '0'
stride: 32
**tools/partial_quantization/partial_quant.py**:
import argparse
import time
import sys
import os

# Explicit imports for clarity; these names are also re-exported by
# the star import from yolov6.layers.common below.
import torch
import torch.nn as nn

ROOT = os.getcwd()
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

sys.path.append('../../')

from yolov6.models.effidehead import Detect
from yolov6.layers.common import *
from yolov6.utils.events import LOGGER, load_yaml
from yolov6.utils.checkpoint import load_checkpoint

from tools.partial_quantization.eval import EvalerWrapper
from tools.partial_quantization.utils import get_module, concat_quant_amax_fuse, quant_sensitivity_load
from tools.partial_quantization.ptq import load_ptq, partial_quant

from pytorch_quantization import nn as quant_nn

# concat_fusion_list = [
#     ('backbone.ERBlock_5.2.m', 'backbone.ERBlock_5.2.cv2.conv'),
#     ('backbone.ERBlock_5.0.rbr_reparam', 'neck.Rep_p4.conv1.rbr_reparam'),
#     ('backbone.ERBlock_4.0.rbr_reparam', 'neck.Rep_p3.conv1.rbr_reparam'),
#     ('neck.upsample1.upsample_transpose', 'neck.Rep_n3.conv1.rbr_reparam'),
#     ('neck.upsample0.upsample_transpose', 'neck.Rep_n4.conv1.rbr_reparam')
# ]

# Pairs of layers that feed a common concat; their input quantizers must
# share one amax so the concat sees a single scale.
opt_concat_fusion_list = [
    ('backbone.ERBlock_5.2.m', 'backbone.ERBlock_5.2.cv2.conv'),
    ('backbone.ERBlock_5.0.conv', 'neck.Rep_p4.conv1.conv'),
    ('backbone.ERBlock_4.0.conv', 'neck.Rep_p3.conv1.conv'),
    ('neck.upsample1.upsample_transpose', 'neck.Rep_n3.conv1.conv'),
    ('neck.upsample0.upsample_transpose', 'neck.Rep_n4.conv1.conv')
]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, default='./yolov6s_reopt.pt', help='weights path')
    parser.add_argument('--calib-weights', type=str, default='./yolov6s_reopt_calib.pt', help='calib weights path')
    parser.add_argument('--data-root', type=str, default=None, help='train data path')
    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
    parser.add_argument('--export-batch-size', type=int, default=None, help='export batch size')
    parser.add_argument('--half', action='store_true', help='FP16 half-precision export')  # was missing; the assert below reads args.half
    parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
    parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--sensitivity-file', type=str, default=None, help='quantization sensitivity file')
    parser.add_argument('--quant-boundary', type=int, default=None, help='quantization boundary')
    parser.add_argument('--eval-yaml', type=str, default='./eval.yaml', help='evaluation config')
    args = parser.parse_args()
    args.img_size *= 2 if len(args.img_size) == 1 else 1  # expand
    print(args)
    t = time.time()

    # Check device
    cuda = args.device != 'cpu' and torch.cuda.is_available()
    device = torch.device('cuda:0' if cuda else 'cpu')
    assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'
    # Load PyTorch model
    model = load_checkpoint(args.weights, map_location=device, inplace=True, fuse=True)  # load FP32 model
    model.eval()
    yolov6_evaler = EvalerWrapper(eval_cfg=load_yaml(args.eval_yaml))
    orig_mAP = yolov6_evaler.eval(model)

    # Fuse the RepVGG training-time branches into single convolutions.
    for layer in model.modules():
        if isinstance(layer, RepVGGBlock):
            layer.switch_to_deploy()

    for k, m in model.named_modules():
        if isinstance(m, Conv):  # assign export-friendly activations
            if isinstance(m.act, nn.SiLU):
                m.act = SiLU()
        elif isinstance(m, Detect):
            m.inplace = args.inplace

    # Rebuild the model with quantized modules and load calibrated weights.
    model_ptq = load_ptq(model, args.calib_weights, device)

    quant_sensitivity = quant_sensitivity_load(args.sensitivity_file)
    quant_sensitivity.sort(key=lambda tup: tup[2], reverse=True)
    boundary = args.quant_boundary
    quantable_ops = [qops[0] for qops in quant_sensitivity[:boundary + 1]]
    # only quantize ops in quantable_ops list
    partial_quant(model_ptq, quantable_ops=quantable_ops)
    # concat amax fusion
    for sub_fusion_list in opt_concat_fusion_list:
        ops = [get_module(model_ptq, op_name) for op_name in sub_fusion_list]
        concat_quant_amax_fuse(ops)

    part_mAP = yolov6_evaler.eval(model_ptq)
    print(part_mAP)

    # ONNX export
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    if args.export_batch_size is None:
        img = torch.zeros(1, 3, *args.img_size).to(device)
        export_file = args.weights.replace('.pt', '_partial_dynamic.onnx')  # filename
        dynamic_axes = {"image_arrays": {0: "batch"}, "outputs": {0: "batch"}}
        torch.onnx.export(model_ptq, img, export_file,
                          verbose=False,
                          opset_version=13,
                          training=torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=True,
                          input_names=['image_arrays'],
                          output_names=['outputs'],
                          dynamic_axes=dynamic_axes)
    else:
        img = torch.zeros(args.export_batch_size, 3, *args.img_size).to(device)
        export_file = args.weights.replace('.pt', '_partial_bs{}.onnx'.format(args.export_batch_size))  # filename
        torch.onnx.export(model_ptq, img, export_file,
                          verbose=False,
                          opset_version=13,
                          training=torch.onnx.TrainingMode.EVAL,
                          do_constant_folding=True,
                          input_names=['image_arrays'],
                          output_names=['outputs'])
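Before simplification and engine building, the raw export can be smoke-tested. The following is a sketch under the assumption that onnxruntime is installed (it is not a dependency of this commit); the input name `image_arrays` comes from the export call above.

```python
# Hypothetical smoke test of the exported ONNX graph (not part of this commit).
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('yolov6s_reopt_partial_bs1.onnx',
                            providers=['CPUExecutionProvider'])
dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)  # --export-batch-size 1, --img-size 640
outputs = sess.run(None, {'image_arrays': dummy})     # input name from the export above
print([o.shape for o in outputs])
```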
**tools/partial_quantization/ptq.py**:
import torch
import torch.nn as nn
import copy

from pytorch_quantization import nn as quant_nn
from pytorch_quantization import tensor_quant
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor

from tools.partial_quantization.utils import set_module, module_quant_disable


def collect_stats(model, data_loader, batch_number, device='cuda'):
    """Feed data to the network and collect statistics."""
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, data_tuple in enumerate(data_loader):
        image = data_tuple[0]
        image = image.float() / 255.0
        model(image.to(device))
        if i + 1 >= batch_number:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()


def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
            print(f"{name:40}: {module}")


def quantable_op_check(k, quantable_ops):
    # No list given means every op is quantable.
    if quantable_ops is None:
        return True
    return k in quantable_ops


def quant_model_init(model, device):
    model_ptq = copy.deepcopy(model)
    model_ptq.eval()
    model_ptq.to(device)

    # Per-channel 8-bit weight quantization, histogram-calibrated inputs.
    conv2d_weight_default_desc = tensor_quant.QUANT_DESC_8BIT_CONV2D_WEIGHT_PER_CHANNEL
    conv2d_input_default_desc = QuantDescriptor(num_bits=8, calib_method='histogram')

    convtrans2d_weight_default_desc = tensor_quant.QUANT_DESC_8BIT_CONVTRANSPOSE2D_WEIGHT_PER_CHANNEL
    convtrans2d_input_default_desc = QuantDescriptor(num_bits=8, calib_method='histogram')

    for k, m in model_ptq.named_modules():
        if isinstance(m, nn.Conv2d):
            # Swap each Conv2d for its quantized twin, copying the parameters.
            in_channels = m.in_channels
            out_channels = m.out_channels
            kernel_size = m.kernel_size
            stride = m.stride
            padding = m.padding
            quant_conv = quant_nn.QuantConv2d(in_channels,
                                              out_channels,
                                              kernel_size,
                                              stride,
                                              padding,
                                              quant_desc_input=conv2d_input_default_desc,
                                              quant_desc_weight=conv2d_weight_default_desc)
            quant_conv.weight.data.copy_(m.weight.detach())
            if m.bias is not None:
                quant_conv.bias.data.copy_(m.bias.detach())
            else:
                quant_conv.bias = None
            set_module(model_ptq, k, quant_conv)
        elif isinstance(m, nn.ConvTranspose2d):
            in_channels = m.in_channels
            out_channels = m.out_channels
            kernel_size = m.kernel_size
            stride = m.stride
            padding = m.padding
            quant_convtrans = quant_nn.QuantConvTranspose2d(in_channels,
                                                            out_channels,
                                                            kernel_size,
                                                            stride,
                                                            padding,
                                                            quant_desc_input=convtrans2d_input_default_desc,
                                                            quant_desc_weight=convtrans2d_weight_default_desc)
            quant_convtrans.weight.data.copy_(m.weight.detach())
            if m.bias is not None:
                quant_convtrans.bias.data.copy_(m.bias.detach())
            else:
                quant_convtrans.bias = None
            set_module(model_ptq, k, quant_convtrans)
        elif isinstance(m, nn.MaxPool2d):
            kernel_size = m.kernel_size
            stride = m.stride
            padding = m.padding
            dilation = m.dilation
            ceil_mode = m.ceil_mode
            quant_maxpool2d = quant_nn.QuantMaxPool2d(kernel_size,
                                                      stride,
                                                      padding,
                                                      dilation,
                                                      ceil_mode,
                                                      quant_desc_input=conv2d_input_default_desc)
            set_module(model_ptq, k, quant_maxpool2d)
        else:
            # module cannot be quantized, continue
            continue

    return model_ptq.to(device)


def do_ptq(model, train_loader, batch_number, device):
    model_ptq = quant_model_init(model, device)
    # It is a bit slow since we collect histograms on CPU
    with torch.no_grad():
        collect_stats(model_ptq, train_loader, batch_number)
        compute_amax(model_ptq, method='entropy')
    return model_ptq


def load_ptq(model, calib_path, device):
    model_ptq = quant_model_init(model, device)
    model_ptq.load_state_dict(torch.load(calib_path)['model'].state_dict())
    return model_ptq


def partial_quant(model_ptq, quantable_ops=None):
    # Ops not in quantable_ops stay in full precision.
    for k, m in model_ptq.named_modules():
        if quantable_op_check(k, quantable_ops):
            continue
        # enable full precision
        if isinstance(m, (quant_nn.QuantConv2d,
                          quant_nn.QuantConvTranspose2d,
                          quant_nn.QuantMaxPool2d)):
            module_quant_disable(model_ptq, k)
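`load_ptq` above expects a checkpoint saved as `{'model': ...}`, and `do_ptq` is the function that builds such a calibrated model. A hedged sketch of producing the `--calib-weights` file that `partial_quant.py` consumes (the save layout is inferred from how `load_ptq` reads it; `model` and `train_loader` are assumed to exist):

```python
# Sketch: build and save the calib weights consumed by partial_quant.py.
# Assumes `model` (an FP32 YOLOv6 model) and `train_loader` already exist.
import torch
from tools.partial_quantization.ptq import do_ptq

model_ptq = do_ptq(model, train_loader, batch_number=4, device='cuda')
# The {'model': ...} layout mirrors what load_ptq() reads back.
torch.save({'model': model_ptq}, './yolov6s_reopt_calib.pt')
```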
(The remaining two changed files did not render on this page; from the imports above they are presumably sensitivity_analyse.py and tools/partial_quantization/utils.py.)