Merge pull request meituan#493 from lippman1125/main
update tiny & nano qat config & add zero scale fix
Chilicyy authored Sep 22, 2022
2 parents b1d543c + ef15ff2 commit 2dce6fc
Showing 7 changed files with 37 additions and 15 deletions.
4 changes: 2 additions & 2 deletions configs/repopt/yolov6_tiny_opt_qat.py
@@ -63,7 +63,7 @@
     num_bits = 8,
     calib_batches = 4,
     # 'max', 'histogram'
-    calib_method = 'histogram',
+    calib_method = 'max',
     # 'entropy', 'percentile', 'mse'
     histogram_amax_method='entropy',
     histogram_amax_percentile=99.99,
@@ -73,7 +73,7 @@
 )
 
 qat = dict(
-    calib_pt = './assets/v6s_t_calib_histogram.pt',
+    calib_pt = './assets/v6s_t_calib_max.pt',
     sensitive_layers_skip = False,
     sensitive_layers_list=[],
 )
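
Note: the tiny config above and the nano config below make the same pair of changes: PTQ calibration switches from 'histogram' to 'max', and calib_pt points at a matching max-calibrated checkpoint. As a rough sketch of what the two calib_method settings compute in pytorch-quantization (the calibrator classes are the library's; the input tensor is made up):

# Illustrative sketch only, not part of this commit: how 'max' and
# 'histogram' calibration each derive an amax in pytorch-quantization.
import torch
from pytorch_quantization import calib

x = torch.randn(32, 64, 80, 80)  # stand-in batch of activations

# calib_method = 'max': amax is the largest absolute value observed.
max_cal = calib.MaxCalibrator(num_bits=8, axis=None, unsigned=False)
max_cal.collect(x)
print(max_cal.compute_amax())

# calib_method = 'histogram': amax is derived from a histogram of the
# collected values, e.g. with the 'entropy' method named in the config.
hist_cal = calib.HistogramCalibrator(num_bits=8, axis=None, unsigned=False)
hist_cal.collect(x)
print(hist_cal.compute_amax(method='entropy'))

With 'max' calibration, the histogram_amax_method and histogram_amax_percentile keys stay in the config but should only be consulted for histogram calibration.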
4 changes: 2 additions & 2 deletions configs/repopt/yolov6n_opt_qat.py
@@ -63,7 +63,7 @@
     num_bits = 8,
     calib_batches = 4,
     # 'max', 'histogram'
-    calib_method = 'histogram',
+    calib_method = 'max',
     # 'entropy', 'percentile', 'mse'
     histogram_amax_method='entropy',
     histogram_amax_percentile=99.99,
@@ -73,7 +73,7 @@
 )
 
 qat = dict(
-    calib_pt = './assets/v6s_n_calib_histogram.pt',
+    calib_pt = './assets/v6s_n_calib_max.pt',
     sensitive_layers_skip = False,
     sensitive_layers_list=[],
 )
2 changes: 1 addition & 1 deletion tools/partial_quantization/eval.py
@@ -18,7 +18,7 @@ def __init__(self, eval_cfg):
             os.makedirs(save_dir)
 
         # reload thres/device/half/data according task
-        conf_thres = 0.001
+        conf_thres = 0.03
         iou_thres = 0.65
         device = Evaler.reload_device(device, None, task)
         data = Evaler.reload_dataset(data) if isinstance(data, str) else data
2 changes: 1 addition & 1 deletion tools/partial_quantization/eval.yaml
@@ -1,6 +1,6 @@
 task: 'val'
 save_dir: 'runs/val/exp'
-half: True
+half: False
 data: '../../data/coco.yaml'
 batch_size: 32
 img_size: 640
34 changes: 28 additions & 6 deletions tools/qat/qat_export.py
@@ -28,19 +28,39 @@
     ('detect.reg_convs.2.conv', 'detect.cls_convs.2.conv'),
 ]
 
-# python3 qat_export.py --weights yolov6s_v2_reopt.pt --quant-weights yolov6s_v2_reopt_qat_43.0.pt --export-batch-size 1
+def zero_scale_fix(model, device):
+
+    for k, m in model.named_modules():
+        # print(k, m)
+        if isinstance(m, quant_nn.QuantConv2d) or \
+           isinstance(m, quant_nn.QuantConvTranspose2d):
+            # print(m)
+            # print(m._weight_quantizer._amax)
+            weight_amax = m._weight_quantizer._amax.detach().cpu().numpy()
+            # print(weight_amax)
+            print(k)
+            ones = np.ones_like(weight_amax)
+            print("zero scale number = {}".format(np.sum(weight_amax == 0.0)))
+            weight_amax = np.where(weight_amax == 0.0, ones, weight_amax)
+            m._weight_quantizer._amax.copy_(torch.from_numpy(weight_amax).to(device))
+        else:
+            # module can not be quantized, continue
+            continue
+
+# python3 qat_export.py --weights yolov6s_v2_reopt.pt --quant-weights yolov6s_v2_reopt_qat_43.0.pt --export-batch-size 1 --conf ../../configs/repopt/yolov6s_opt_qat.py
+# python3 qat_export.py --weights v6s_t.pt --quant-weights yolov6t_v2_reopt_qat_40.1.pt --export-batch-size 1 --conf ../../configs/repopt/yolov6_tiny_opt_qat.py
+# python3 qat_export.py --weights v6s_n.pt --quant-weights yolov6n_v2_reopt_qat_34.9.pt --export-batch-size 1 --conf ../../configs/repopt/yolov6n_opt_qat.py
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--weights', type=str, default='./yolov6s_v2_reopt.pt', help='weights path')
     parser.add_argument('--quant-weights', type=str, default='./yolov6s_v2_reopt_qat_43.0.pt', help='calib weights path')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
     parser.add_argument('--conf', type=str, default='../../configs/repopt/yolov6s_opt_qat.py', help='model config')
     parser.add_argument('--export-batch-size', type=int, default=None, help='export batch size')
-    parser.add_argument('--calib', action='store_true', help='indicate calibrated model')
+    parser.add_argument('--calib', action='store_true', default=False, help='calibrated model')
+    parser.add_argument('--scale-fix', action='store_true', help='enable scale fix')
     parser.add_argument('--fuse-bn', action='store_true', help='fuse bn')
     parser.add_argument('--graph-opt', action='store_true', help='enable graph optimizer')
     parser.add_argument('--skip-qat-sensitive', action='store_true', help='skip qat sensitive layers')
     parser.add_argument('--skip-ptq-sensitive', action='store_true', help='skip ptq sensitive layers')
     parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
     parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--eval-yaml', type=str, default='../partial_quantization/eval.yaml', help='evaluation config')
@@ -68,10 +88,12 @@
     cfg = Config.fromfile(args.conf)
     # init qat model
     qat_init_model_manu(model, cfg, args)
-    model.neck.upsample_enable_quant()
+    model.neck.upsample_enable_quant(cfg.ptq.num_bits, cfg.ptq.calib_method)
     ckpt = torch.load(args.quant_weights)
     model.load_state_dict(ckpt['model'].float().state_dict())
     model.to(device)
+    if args.scale_fix:
+        zero_scale_fix(model, device)
     if args.graph_opt:
         # concat amax fusion
         for sub_fusion_list in op_concat_fusion_list:
@@ -112,4 +134,4 @@
                       do_constant_folding=True,
                       input_names=['image_arrays'],
                       output_names=['outputs']
-                      )
\ No newline at end of file
+                      )
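
Note: zero_scale_fix() targets quantized conv layers whose weight amax is exactly 0 (e.g. a channel whose weights are all zero). The INT8 scale is derived from amax (amax / 127 for 8 bits), so a zero amax yields a zero scale that breaks the exported engine; rewriting those entries to 1.0 is harmless for an all-zero channel. The pass runs only when the script is invoked with the new --scale-fix flag. A torch-only sketch of the same idea (zero_scale_fix_torch is a hypothetical name, not part of this commit):

import torch
from pytorch_quantization import nn as quant_nn

def zero_scale_fix_torch(model, device):
    # Clamp zero weight-quantizer amax entries to 1.0, in place, so the
    # derived quantization scale can never be zero.
    for name, m in model.named_modules():
        if isinstance(m, (quant_nn.QuantConv2d, quant_nn.QuantConvTranspose2d)):
            amax = m._weight_quantizer._amax
            m._weight_quantizer._amax.copy_(
                torch.where(amax == 0.0, torch.ones_like(amax), amax).to(device)
            )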
2 changes: 1 addition & 1 deletion yolov6/core/engine.py
@@ -506,7 +506,7 @@ def quant_setup(self, model, cfg, device):
             from tools.qat.qat_utils import qat_init_model_manu, skip_sensitive_layers
             qat_init_model_manu(model, cfg, self.args)
             # workaround
-            model.neck.upsample_enable_quant()
+            model.neck.upsample_enable_quant(cfg.ptq.num_bits, cfg.ptq.calib_method)
             # if self.main_process:
             #     print(model)
             # QAT
4 changes: 2 additions & 2 deletions yolov6/models/reppan.py
@@ -86,12 +86,12 @@ def __init__(
             stride=2
         )
 
-    def upsample_enable_quant(self):
+    def upsample_enable_quant(self, num_bits, calib_method):
         print("Insert fakequant after upsample")
         # Insert fakequant after upsample op to build TensorRT engine
         from pytorch_quantization import nn as quant_nn
         from pytorch_quantization.tensor_quant import QuantDescriptor
-        conv2d_input_default_desc = QuantDescriptor(num_bits=8, calib_method='histogram')
+        conv2d_input_default_desc = QuantDescriptor(num_bits=num_bits, calib_method=calib_method)
         self.upsample_feat0_quant = quant_nn.TensorQuantizer(conv2d_input_default_desc)
         self.upsample_feat1_quant = quant_nn.TensorQuantizer(conv2d_input_default_desc)
         # global _QUANT
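
Note: the new signature threads cfg.ptq.num_bits and cfg.ptq.calib_method through to the fake-quant nodes inserted after the neck upsamples, so they follow the PTQ config instead of a hard-coded 8-bit histogram descriptor; both call sites (tools/qat/qat_export.py and yolov6/core/engine.py above) are updated to match. A minimal sketch of the wiring, with illustrative values:

from pytorch_quantization import nn as quant_nn
from pytorch_quantization.tensor_quant import QuantDescriptor

def make_upsample_quantizer(num_bits, calib_method):
    # Same construction as upsample_enable_quant(): one TensorQuantizer
    # (fake-quant) node to place after each upsample output.
    desc = QuantDescriptor(num_bits=num_bits, calib_method=calib_method)
    return quant_nn.TensorQuantizer(desc)

# e.g. with the updated configs: 8-bit, 'max' calibration
quantizer = make_upsample_quantizer(8, 'max')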
