forked from meituan/YOLOv6
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request meituan#740 from meituan/feature/make_onnx2trt_sup…
…port_rectangle_input Fix some bugs about evaluation codes of TensorRT model.
- Loading branch information
Showing
7 changed files
with
74 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,11 +51,12 @@ def parse_args(): | |
parser.add_argument('--force_no_pad', type=bool, default=True, help='for no extra pad in letterbox') | ||
parser.add_argument('--visualize', '-v', action="store_true", default=False, help='visualize demo') | ||
parser.add_argument('--num_imgs_to_visualize', type=int, default=10, help='number of images to visualize') | ||
parser.add_argument('--do_pr_metric', type=bool, default=False, help='use pr_metric to evaluate models') | ||
parser.add_argument('--do_pr_metric', action='store_true', help='use pr_metric to evaluate models') | ||
parser.add_argument('--plot_curve', type=bool, default=True, help='plot curve for pr_metric') | ||
parser.add_argument('--plot_confusion_matrix', type=bool, default=False, help='plot confusion matrix ') | ||
parser.add_argument('--verbose', type=bool, default=False, help='report mAP by class') | ||
parser.add_argument('--plot_confusion_matrix', action='store_true', help='plot confusion matrix ') | ||
parser.add_argument('--verbose', action='store_true', help='report mAP by class') | ||
parser.add_argument('--save_dir',default='', help='whether use pr_metric') | ||
parser.add_argument('--is_end2end', action='store_true', help='whether the model is end2end (build with NMS)') | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
@@ -99,15 +100,15 @@ def check_args(args): | |
sys.exit('%s is not a valid file' % args.annotations) | ||
|
||
|
||
def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matrix, processor, imgs_dir, labels_dir, valid_images, results_file, conf_thres, iou_thres, is_coco, batch_size=1, test_load_size=640, visualize=False, num_imgs_to_visualize=0): | ||
def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matrix, processor, imgs_dir, labels_dir, valid_images, results_file, conf_thres, iou_thres, is_coco, batch_size=1, test_load_size=640, visualize=False, num_imgs_to_visualize=0, imgname2id={}): | ||
"""Run detection on each jpg and write results to file.""" | ||
results = [] | ||
pbar = tqdm(range(math.ceil(len(valid_images)/batch_size)), desc="TRT-Model test in val datasets.") | ||
idx = 0 | ||
num_visualized = 0 | ||
stats= [] | ||
seen = 0 | ||
if do_pr_metric: | ||
stats= [] | ||
seen = 0 | ||
iouv = torch.linspace(0.5, 0.95, 10) # iou vector for [email protected]:0.95 | ||
niou = iouv.numel() | ||
if plot_confusion_matrix: | ||
|
@@ -148,13 +149,10 @@ def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matri | |
source_imgs.append(img_src) | ||
shape = (h0, w0), ((h / h0, w / w0), pad) | ||
shapes.append(shape) | ||
if is_coco: | ||
image_ids.append(int(valid_images[idx].split('.')[0].split('_')[-1])) | ||
else: | ||
image_ids.append(valid_images[idx].split('.')[0].split('_')[-1]) | ||
assert valid_images[idx] in imgname2id.keys(), f'valid_images[idx] not in annotations you provided.' | ||
image_ids.append(imgname2id[valid_images[idx]]) | ||
idx += 1 | ||
output = processor.inference(torch.stack(preprocessed_imgs, axis=0)) | ||
|
||
for j in range(len(shapes)): | ||
pred = processor.post_process(output[j].unsqueeze(0), shapes[j], conf_thres = conf_thres, iou_thres = iou_thres) | ||
|
||
|
@@ -167,6 +165,8 @@ def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matri | |
w = float(p[2] - p[0]) | ||
h = float(p[3] - p[1]) | ||
s = float(p[4]) | ||
# Warning, some dataset, the category id is start from 1, so that the category id must add 1. | ||
# For example, change the line bellow to: 'category_id': data_class[int(p[5])] if is_coco else int(p[5]) + 1, | ||
results.append({'image_id': image_ids[j], | ||
'category_id': data_class[int(p[5])] if is_coco else int(p[5]), | ||
'bbox': [round(x, 3) for x in [x, y, w, h]], | ||
|
@@ -215,6 +215,7 @@ def generate_results(data_class, model_names, do_pr_metric, plot_confusion_matri | |
num_visualized += 1 | ||
|
||
with open(results_file, 'w') as f: | ||
LOGGER.info(f'saving coco format detection resuslt to {results_file}') | ||
f.write(json.dumps(results, indent=4)) | ||
return stats, seen | ||
|
||
|
@@ -256,13 +257,17 @@ def main(): | |
model_names = list(range(0, args.class_num)) | ||
|
||
# setup processor | ||
processor = Processor(model=args.model, scale_exact=args.scale_exact, return_int=args.letterbox_return_int, force_no_pad=args.force_no_pad) | ||
processor = Processor(model=args.model, scale_exact=args.scale_exact, return_int=args.letterbox_return_int, force_no_pad=args.force_no_pad, is_end2end=args.is_end2end) | ||
image_names = [p for p in os.listdir(args.imgs_dir) if p.split(".")[-1].lower() in IMG_FORMATS] | ||
# Eliminate data with missing labels. | ||
with open(args.annotations) as f: | ||
coco_format_annotation = json.load(f) | ||
# Get image names from coco format annotations. | ||
coco_format_imgs = [x['file_name'] for x in coco_format_annotation['images']] | ||
# make a projection of image names and ids. | ||
imgname2id = {} | ||
for item in coco_format_annotation['images']: | ||
imgname2id[item['file_name']] = item['id'] | ||
valid_images = [] | ||
for img_name in image_names: | ||
img_name_wo_ext = os.path.splitext(img_name)[0] | ||
|
@@ -274,7 +279,7 @@ def main(): | |
assert len(valid_images) > 0, 'No valid images are found. Please check you image format or whether annotation file is match.' | ||
#targets=[j for j in os.listdir(args.labels_dir) if j.endswith('.txt')] | ||
stats, seen = generate_results(data_class, model_names, args.do_pr_metric, args.plot_confusion_matrix, processor, args.imgs_dir, args.labels_dir, valid_images, results_file, args.conf_thres, args.iou_thres, args.is_coco, batch_size=args.batch_size, test_load_size=args.test_load_size, | ||
visualize=args.visualize, num_imgs_to_visualize=args.num_imgs_to_visualize) | ||
visualize=args.visualize, num_imgs_to_visualize=args.num_imgs_to_visualize, imgname2id=imgname2id) | ||
|
||
# Run COCO mAP evaluation | ||
# Reference: https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters