diff --git a/deploy/ONNX/export_onnx.py b/deploy/ONNX/export_onnx.py
index e178cf38..56c50f7c 100644
--- a/deploy/ONNX/export_onnx.py
+++ b/deploy/ONNX/export_onnx.py
@@ -23,7 +23,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--weights', type=str, default='./yolov6s.pt', help='weights path')
-    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size')  # height, width
+    parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='image size in the order: height, width')  # height, width
     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
     parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
    parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
diff --git a/deploy/TensorRT/Processor.py b/deploy/TensorRT/Processor.py
index 7b8c3587..fb98f7bb 100644
--- a/deploy/TensorRT/Processor.py
+++ b/deploy/TensorRT/Processor.py
@@ -77,13 +77,15 @@ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleu
 
 class Processor():
-    def __init__(self, model, num_classes=80, num_layers=3, anchors=1, device=torch.device('cuda:0'), return_int=False, scale_exact=False, force_no_pad=False):
+    def __init__(self, model, num_classes=80, num_layers=3, anchors=1, device=torch.device('cuda:0'), return_int=False, scale_exact=False, force_no_pad=False, is_end2end=False):
         # load tensorrt engine)
         self.return_int = return_int
         self.scale_exact = scale_exact
         self.force_no_pad = force_no_pad
+        self.is_end2end = is_end2end
         Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
         self.logger = trt.Logger(trt.Logger.INFO)
+        trt.init_libnvinfer_plugins(self.logger, namespace="")
         self.runtime = trt.Runtime(self.logger)
         with open(model, "rb") as f:
             self.engine = self.runtime.deserialize_cuda_engine(f.read())
@@ -133,7 +135,7 @@ def pre_process(self, img_src, input_shape=None,):
         """Preprocess an image before TRT YOLO inferencing.
""" input_shape = input_shape if input_shape is not None else self.input_shape - image, ratio, pad = letterbox(img_src, input_shape, auto=False, return_int=self.return_int) + image, ratio, pad = letterbox(img_src, input_shape, auto=False, return_int=self.return_int, scaleup=True) # Convert image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB image = torch.from_numpy(np.ascontiguousarray(image)).to(self.device).float() @@ -144,7 +146,14 @@ def inference(self, inputs): self.binding_addrs[self.input_names[0]] = int(inputs.data_ptr()) #self.binding_addrs['x2paddle_image_arrays'] = int(inputs.data_ptr()) self.context.execute_v2(list(self.binding_addrs.values())) - output = self.bindings[self.output_names[0]].data + if self.is_end2end: + nums = self.bindings['num_dets'].data + boxes = self.bindings['det_boxes'].data + scores = self.bindings['det_scores'].data + classes = self.bindings['det_classes'].data + output = torch.cat((boxes, scores[:,:,None], classes[:,:,None]), axis=-1) + else: + output = self.bindings[self.output_names[0]].data #output = self.bindings['save_infer_model/scale_0.tmp_0'].data return output @@ -174,7 +183,10 @@ def output_reformate(self, outputs): return torch.cat(z, 1) def post_process(self, outputs, img_shape, conf_thres=0.5, iou_thres=0.6): - det_t = self.non_max_suppression(outputs, conf_thres, iou_thres, multi_label=True) + if self.is_end2end: + det_t = outputs + else: + det_t = self.non_max_suppression(outputs, conf_thres, iou_thres, multi_label=True) self.scale_coords(self.input_shape, det_t[0][:, :4], img_shape[0], img_shape[1]) return det_t[0] diff --git a/deploy/TensorRT/README.md b/deploy/TensorRT/README.md index 0d491689..7c0e731d 100644 --- a/deploy/TensorRT/README.md +++ b/deploy/TensorRT/README.md @@ -70,13 +70,31 @@ Then run the demo: ```shell ./yolov6 ../you.engine -i image_path ``` -# Testing on image -You can do testing on images using .trt weights, just give path of image directory & its annotation path +# Evaluate the performace + You can evaluate the performace of the TensorRT model. + ``` + python deploy/TensorRT/eval_yolo_trt.py \ + --imgs_dir /path/to/images/val \ + --labels_dir /path/to/labels/val\ + --annotations /path/to/coco/format/annotation/file \ --batch 1 \ + --img_size 640 \ + --model /path/to/tensorrt/model \ + --do_pr_metric --is_coco + ``` +Tips: +`--is_coco`: if you are evaluating the COCO dataset, add this, if not, do not add this parameter. +`--do_pr_metric`: If you want to get PR metric, add this. 
+
+For example:
 ```
-python3 deploy/TensorRT/eval_yolo_trt.py -v -m model.trt \
---imgs-dir /workdir/datasets/coco/images/val2017 \
---annotations /workdir/datasets/coco/annotations/instances_val2017.json \
---conf-thres 0.40 --iou-thres 0.45 \
---is_coco
-```
+python deploy/TensorRT/eval_yolo_trt.py \
+ --imgs_dir /workdir/datasets/coco/images/val2017/ \
+ --labels_dir /workdir/datasets/coco/labels/val2017 \
+ --annotations /workdir/datasets/coco/annotations/instances_val2017.json \
+ --batch 1 \
+ --img_size 640 \
+ --model weights/yolov6n.trt \
+ --do_pr_metric --is_coco
+
+```
\ No newline at end of file
diff --git a/deploy/TensorRT/calibrator.py b/deploy/TensorRT/calibrator.py
index 12c0cc84..fec895f1 100644
--- a/deploy/TensorRT/calibrator.py
+++ b/deploy/TensorRT/calibrator.py
@@ -22,6 +22,9 @@
 trt.IInt8MinMaxCalibrator
 """
 
+IMG_FORMATS = [".bmp", ".jpg", ".jpeg", ".png", ".tif", ".tiff", ".dng", ".webp", ".mpo"]
+IMG_FORMATS.extend([f.upper() for f in IMG_FORMATS])
+
 class Calibrator(trt.IInt8MinMaxCalibrator):
     def __init__(self, stream, cache_file=""):
         trt.IInt8MinMaxCalibrator.__init__(self)
@@ -74,7 +77,7 @@ def __init__(self, batch_size, batch_num, calib_img_dir, input_w, input_h):
         self.input_h = input_h
         self.input_w = input_w
         # self.img_list = [i.strip() for i in open('calib.txt').readlines()]
-        self.img_list = glob.glob(os.path.join(calib_img_dir, "*.jpg"))
+        self.img_list = [os.path.join(calib_img_dir, x) for x in os.listdir(calib_img_dir) if os.path.splitext(x)[-1] in IMG_FORMATS]
         assert len(self.img_list) > self.batch_size * self.length, \
             '{} must contains more than '.format(calib_img_dir) + str(self.batch_size * self.length) + ' images to calib'
         print('found all {} images to calib.'.format(len(self.img_list)))
@@ -86,9 +89,10 @@ def reset(self):
     def next_batch(self):
         if self.index < self.length:
             for i in range(self.batch_size):
-                assert os.path.exists(self.img_list[i + self.index * self.batch_size]), 'not found!!'
+                assert os.path.exists(self.img_list[i + self.index * self.batch_size]), f'{self.img_list[i + self.index * self.batch_size]} not found!!'
                img = cv2.imread(self.img_list[i + self.index * self.batch_size])
-                img = precess_image(img, self.input_h, 32)
+                img = precess_image(img, [self.input_h, self.input_w], 32)
+                print(img.shape)
                 self.calibration_data[i] = img
 
             self.index += 1
diff --git a/deploy/TensorRT/eval_yolo_trt.py b/deploy/TensorRT/eval_yolo_trt.py
index efa7fe2a..ee28a454 100644
--- a/deploy/TensorRT/eval_yolo_trt.py
+++ b/deploy/TensorRT/eval_yolo_trt.py
@@ -51,11 +51,12 @@ def parse_args():
     parser.add_argument('--force_no_pad', type=bool, default=True, help='for no extra pad in letterbox')
     parser.add_argument('--visualize', '-v', action="store_true", default=False, help='visualize demo')
     parser.add_argument('--num_imgs_to_visualize', type=int, default=10, help='number of images to visualize')
-    parser.add_argument('--do_pr_metric', type=bool, default=False, help='use pr_metric to evaluate models')
+    parser.add_argument('--do_pr_metric', action='store_true', help='use pr_metric to evaluate models')
     parser.add_argument('--plot_curve', type=bool, default=True, help='plot curve for pr_metric')
-    parser.add_argument('--plot_confusion_matrix', type=bool, default=False, help='plot confusion matrix ')
-    parser.add_argument('--verbose', type=bool, default=False, help='report mAP by class')
+    parser.add_argument('--plot_confusion_matrix', action='store_true', help='plot confusion matrix')
+    parser.add_argument('--verbose', action='store_true', help='report mAP by class')
     parser.add_argument('--save_dir',default='', help='whether use pr_metric')
+    parser.add_argument('--is_end2end', action='store_true', help='whether the model is end2end (built with NMS)')
     args = parser.parse_args()
     return args
@@ -99,15 +100,15 @@ def check_args(args):
         sys.exit('%s is not a valid file' % args.annotations)
 
 
-def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matrix, processor, imgs_dir, labels_dir, valid_images, results_file, conf_thres, iou_thres, is_coco, batch_size=1, test_load_size=640, visualize=False, num_imgs_to_visualize=0):
+def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matrix, processor, imgs_dir, labels_dir, valid_images, results_file, conf_thres, iou_thres, is_coco, batch_size=1, test_load_size=640, visualize=False, num_imgs_to_visualize=0, imgname2id={}):
     """Run detection on each jpg and write results to file."""
     results = []
     pbar = tqdm(range(math.ceil(len(valid_images)/batch_size)), desc="TRT-Model test in val datasets.")
     idx = 0
     num_visualized = 0
+    stats = []
+    seen = 0
     if do_pr_metric:
-        stats = []
-        seen = 0
         iouv = torch.linspace(0.5, 0.95, 10)  # iou vector for mAP@0.5:0.95
         niou = iouv.numel()
         if plot_confusion_matrix:
@@ -148,13 +149,10 @@ def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matri
             source_imgs.append(img_src)
             shape = (h0, w0), ((h / h0, w / w0), pad)
             shapes.append(shape)
-            if is_coco:
-                image_ids.append(int(valid_images[idx].split('.')[0].split('_')[-1]))
-            else:
-                image_ids.append(valid_images[idx].split('.')[0].split('_')[-1])
+            assert valid_images[idx] in imgname2id.keys(), f'{valid_images[idx]} not in the annotations you provided.'
+            image_ids.append(imgname2id[valid_images[idx]])
             idx += 1
         output = processor.inference(torch.stack(preprocessed_imgs, axis=0))
-
         for j in range(len(shapes)):
             pred = processor.post_process(output[j].unsqueeze(0), shapes[j], conf_thres = conf_thres, iou_thres = iou_thres)
@@ -167,6 +165,8 @@ def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matri
                 w = float(p[2] - p[0])
                 h = float(p[3] - p[1])
                 s = float(p[4])
+                # Warning: in some datasets the category ids start from 1, in which case the id must be shifted by 1.
+                # For example, change the line below to: 'category_id': data_class[int(p[5])] if is_coco else int(p[5]) + 1,
                 results.append({'image_id': image_ids[j],
                                 'category_id': data_class[int(p[5])] if is_coco else int(p[5]),
                                 'bbox': [round(x, 3) for x in [x, y, w, h]],
@@ -215,6 +215,7 @@ def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matri
             num_visualized += 1
 
     with open(results_file, 'w') as f:
+        LOGGER.info(f'Saving COCO format detection results to {results_file}')
         f.write(json.dumps(results, indent=4))
     return stats, seen
@@ -256,13 +257,17 @@ def main():
         model_names = list(range(0, args.class_num))
 
     # setup processor
-    processor = Processor(model=args.model, scale_exact=args.scale_exact, return_int=args.letterbox_return_int, force_no_pad=args.force_no_pad)
+    processor = Processor(model=args.model, scale_exact=args.scale_exact, return_int=args.letterbox_return_int, force_no_pad=args.force_no_pad, is_end2end=args.is_end2end)
     image_names = [p for p in os.listdir(args.imgs_dir) if p.split(".")[-1].lower() in IMG_FORMATS]
     # Eliminate data with missing labels.
     with open(args.annotations) as f:
         coco_format_annotation = json.load(f)
     # Get image names from coco format annotations.
     coco_format_imgs = [x['file_name'] for x in coco_format_annotation['images']]
+    # Build a mapping from image file names to image ids.
+    imgname2id = {}
+    for item in coco_format_annotation['images']:
+        imgname2id[item['file_name']] = item['id']
     valid_images = []
     for img_name in image_names:
         img_name_wo_ext = os.path.splitext(img_name)[0]
@@ -274,7 +279,7 @@ def main():
     assert len(valid_images) > 0, 'No valid images are found. Please check you image format or whether annotation file is match.'
    #targets=[j for j in os.listdir(args.labels_dir) if j.endswith('.txt')]
     stats, seen = generate_results(data_class, model_names, args.do_pr_metric, args.plot_confusion_matrix, processor, args.imgs_dir,
                                    args.labels_dir, valid_images, results_file, args.conf_thres, args.iou_thres, args.is_coco, batch_size=args.batch_size, test_load_size=args.test_load_size,
-                                   visualize=args.visualize, num_imgs_to_visualize=args.num_imgs_to_visualize)
+                                   visualize=args.visualize, num_imgs_to_visualize=args.num_imgs_to_visualize, imgname2id=imgname2id)
 
     # Run COCO mAP evaluation
     # Reference: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
diff --git a/deploy/TensorRT/onnx_to_trt.py b/deploy/TensorRT/onnx_to_trt.py
index 93a8e568..5d220783 100644
--- a/deploy/TensorRT/onnx_to_trt.py
+++ b/deploy/TensorRT/onnx_to_trt.py
@@ -143,8 +143,8 @@ def main():
                         '--qat', action='store_true', help='whether the onnx model is qat; if it is, the int8 calibrator is not needed')
     # If enable int8(not post-QAT model), then set the following
-    parser.add_argument('--img-size', type=int,
-                        default=640, help='image size of model input')
+    parser.add_argument('--img-size', nargs='+', type=int,
+                        default=[640, 640], help='image size of model input in the order: height, width')
     parser.add_argument('--batch-size', type=int,
                         default=128, help='batch size for training: default 64')
     parser.add_argument('--num-calib-batch', default=6, type=int,
@@ -159,8 +159,10 @@ def main():
 
     if args.dtype == "int8" and not args.qat:
         from calibrator import DataLoader, Calibrator
+        if len(args.img_size) == 1:
+            args.img_size = [args.img_size[0], args.img_size[0]]
         calib_loader = DataLoader(args.batch_size, args.num_calib_batch, args.calib_img_dir,
-                                  args.img_size, args.img_size)
+                                  args.img_size[1], args.img_size[0])
         engine = build_engine_from_onnx(args.model, args.dtype, args.verbose,
                                         int8_calib=True, calib_loader=calib_loader, calib_cache=args.calib_cache)
     else:
diff --git a/yolov6/core/inferer.py b/yolov6/core/inferer.py
index aa35ed64..7b186675 100644
--- a/yolov6/core/inferer.py
+++ b/yolov6/core/inferer.py
@@ -87,7 +87,7 @@ def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir,
             # Create output files in nested dirs that mirrors the structure of the images' dirs
             rel_path = osp.relpath(osp.dirname(img_path), osp.dirname(self.source))
             save_path = osp.join(save_dir, rel_path, osp.basename(img_path))  # im.jpg
-            txt_path = osp.join(save_dir, rel_path, osp.splitext(osp.basename(img_path))[0])
+            txt_path = osp.join(save_dir, rel_path, 'labels', osp.splitext(osp.basename(img_path))[0])
             os.makedirs(osp.join(save_dir, rel_path), exist_ok=True)
             gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]]  # normalization gain whwh
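
For reviewers who want to exercise the new `is_end2end` path outside of `eval_yolo_trt.py`, here is a minimal sketch. It only uses names this patch touches (`Processor`, `pre_process`, `inference`, `post_process`); the engine path, image path, the defensive handling of `pre_process`'s return value (its return statement is outside the hunks above), and the identity ratio/pad values are assumptions, not part of the patch.

```python
# Hypothetical usage sketch for an end-to-end engine (built with the NMS plugin).
import cv2
import torch

from Processor import Processor  # deploy/TensorRT/Processor.py

# is_end2end=True makes inference() read the num_dets/det_boxes/det_scores/
# det_classes bindings and makes post_process() skip non_max_suppression().
processor = Processor(model='yolov6n_end2end.trt', is_end2end=True)  # engine path is hypothetical

img_src = cv2.imread('demo.jpg')  # BGR, HWC; image path is hypothetical
h0, w0 = img_src.shape[:2]

# pre_process() letterboxes to the engine input shape and returns a CHW float
# tensor on the GPU; unpack defensively since its full signature is not shown.
out = processor.pre_process(img_src)
image = out[0] if isinstance(out, tuple) else out

output = processor.inference(image.unsqueeze(0))  # add a batch dimension

# generate_results() builds shape entries as (h0, w0), ((h / h0, w / w0), pad);
# ratio and pad are simplified to identity values in this sketch.
shape = (h0, w0), ((1.0, 1.0), (0.0, 0.0))
det = processor.post_process(output[0].unsqueeze(0), shape, conf_thres=0.5, iou_thres=0.6)
print(det)  # one row per detection: x1, y1, x2, y2, score, class
```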