Skip to content

Commit

Permalink
fix: update minor
Browse files Browse the repository at this point in the history
  • Loading branch information
khwengXU committed Jul 7, 2022
1 parent f3065f4 commit 03be9ff
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 14 deletions.
2 changes: 1 addition & 1 deletion yolov6/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
from yolov6.solver.build import build_optimizer, build_lr_scheduler


class Trainer:
def __init__(self, args, cfg, device):
self.args = args
Expand Down Expand Up @@ -65,7 +66,6 @@ def __init__(self, args, cfg, device):
self.batch_size = args.batch_size
self.img_size = args.img_size


# Training Process

def train(self):
Expand Down
2 changes: 1 addition & 1 deletion yolov6/core/evaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def convert_to_coco_format(self, outputs, imgs, paths, shapes, ids):

@staticmethod
def check_task(task):
if task not in ['train','val','speed']:
if task not in ['train', 'val', 'speed']:
raise Exception("task argument error: only support 'train' / 'val' / 'speed' task.")

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion yolov6/layers/dbb_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def transIII_1x1_kxk(k1, b1, k2, b2, groups):
k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
k_slices.append(F.conv2d(k2_slice, k1_T_slice))
b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
b_slices.append((k2_slice * b1[g * k1_group_width:(g+1) * k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
k, b_hat = transIV_depthconcat(k_slices, b_slices)
return k, b_hat + b2

Expand Down
24 changes: 13 additions & 11 deletions yolov6/models/end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
import random


class ORT_NMS(torch.autograd.Function):

@staticmethod
Expand All @@ -25,6 +26,7 @@ def forward(ctx,
def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)


class TRT_NMS(torch.autograd.Function):
@staticmethod
def forward(
Expand Down Expand Up @@ -59,18 +61,18 @@ def symbolic(g,
score_activation=0,
score_threshold=0.25):
out = g.op("TRT::EfficientNMS_TRT",
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
boxes,
scores,
background_class_i=background_class,
box_coding_i=box_coding,
iou_threshold_f=iou_threshold,
max_output_boxes_i=max_output_boxes,
plugin_version_s=plugin_version,
score_activation_i=score_activation,
score_threshold_f=score_threshold,
outputs=4)
nums, boxes, scores, classes = out
return nums,boxes,scores,classes
return nums, boxes, scores, classes



Expand Down

0 comments on commit 03be9ff

Please sign in to comment.