Skip to content

Commit

Permalink
Merge pull request meituan#609 from meituan/fix_trt_inference_visuali…
Browse files Browse the repository at this point in the history
…ze_bug

Fix the bug of visualization in tensorrt inference
  • Loading branch information
mtjhl authored Nov 11, 2022
2 parents d6d89b6 + 4f15ab9 commit 5bdb263
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
7 changes: 3 additions & 4 deletions deploy/TensorRT/Processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(self, model, num_classes=80, num_layers=3, anchors=1, device=torch.

def detect(self, img):
"""Detect objects in the input image."""
resized, _, _ = self.pre_process(img, self.input_shape)
resized, _ = self.pre_process(img, self.input_shape)
outputs = self.inference(resized)
return outputs

Expand All @@ -133,10 +133,10 @@ def pre_process(self, img_src, input_shape=None,):
image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
image = torch.from_numpy(np.ascontiguousarray(image)).to(self.device).float()
image = image / 255. # 0 - 255 to 0.0 - 1.0
return image, pad, img_src
return image, pad

def inference(self, inputs):
self.binding_addrs['image_arrays'] = int(inputs.data_ptr())
self.binding_addrs['images'] = int(inputs.data_ptr())
#self.binding_addrs['x2paddle_image_arrays'] = int(inputs.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
output = self.bindings['outputs'].data
Expand Down Expand Up @@ -198,7 +198,6 @@ def non_max_suppression(self, prediction, conf_thres=0.25, iou_thres=0.45, class
Returns:
list of detections, echo item is one tensor with shape (num_boxes, 6), 6 is for [xyxy, conf, cls].
"""

num_classes = prediction.shape[2] - 5 # number of classes
pred_candidates = prediction[..., 4] > conf_thres # candidates

Expand Down
7 changes: 4 additions & 3 deletions deploy/TensorRT/eval_yolo_trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def generate_results(processor, imgs_dir, jpgs, results_file, conf_thres, iou_th
image_ids = []
shapes = []
for i in range(batch_size):
idx += 1
if (idx == len(jpgs)): break
img = cv2.imread(os.path.join(imgs_dir, jpgs[idx]))
img_src = img.copy()
# shapes.append(img.shape)
h0, w0 = img.shape[:2]
r = test_load_size / max(h0, w0)
Expand All @@ -95,13 +95,14 @@ def generate_results(processor, imgs_dir, jpgs, results_file, conf_thres, iou_th
if r < 1 else cv2.INTER_LINEAR,
)
h, w = img.shape[:2]
imgs[i], pad, img_src = processor.pre_process(img)
imgs[i], pad = processor.pre_process(img)
source_imgs.append(img_src)
shape = (h0, w0), ((h / h0, w / w0), pad)
shapes.append(shape)
image_ids.append(int(jpgs[idx].split('.')[0].split('_')[-1]))
idx += 1
output = processor.inference(imgs)

for j in range(len(shapes)):
pred = processor.post_process(output[j].unsqueeze(0), shapes[j], conf_thres=conf_thres, iou_thres=iou_thres)
if visualize and num_visualized < num_imgs_to_visualize:
Expand Down
4 changes: 2 additions & 2 deletions tools/qat/qat_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def zero_scale_fix(model, device):
opset_version=13,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
input_names=['image_arrays'],
input_names=['images'],
output_names=['num_dets', 'det_boxes', 'det_scores', 'det_classes']
if args.end2end else ['outputs'],
dynamic_axes=dynamic_axes
Expand All @@ -157,7 +157,7 @@ def zero_scale_fix(model, device):
opset_version=13,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
input_names=['image_arrays'],
input_names=['images'],
output_names=['num_dets', 'det_boxes', 'det_scores', 'det_classes']
if args.end2end else ['outputs'],
)

0 comments on commit 5bdb263

Please sign in to comment.