diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000..359bb5307e
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000000..28c863b6ff
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,46 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000..105ce2da2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000..d1e22ecb89
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000..ecffae6dc6
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000..94a25f7f4c
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/yolov7.iml b/.idea/yolov7.iml
new file mode 100644
index 0000000000..8b8c395472
--- /dev/null
+++ b/.idea/yolov7.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
index 4d1f48c9b6..1892b1b067 100644
--- a/README.md
+++ b/README.md
@@ -151,6 +151,13 @@ python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inferen
+
+## Export
+Use the `--include-nms` argument to export an end-to-end ONNX model that includes the `EfficientNMS` plugin.
+```shell
+python export.py --weights yolov7.pt --grid --include-nms
+```
+
## Citation
```
diff --git a/export.py b/export.py
index 06dfc942c3..fd4975f6f9 100644
--- a/export.py
+++ b/export.py
@@ -12,6 +12,7 @@
from utils.activations import Hardswish, SiLU
from utils.general import set_logging, check_img_size
from utils.torch_utils import select_device
+from utils.add_nms import RegisterNMS
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -22,6 +23,7 @@
parser.add_argument('--grid', action='store_true', help='export Detect() layer grid')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
+ parser.add_argument('--include-nms', action='store_true', help='export end2end onnx')
opt = parser.parse_args()
opt.img_size *= 2 if len(opt.img_size) == 1 else 1 # expand
print(opt)
@@ -52,7 +54,9 @@
# m.forward = m.forward_export # assign forward (optional)
model.model[-1].export = not opt.grid # set Detect() layer grid export
y = model(img) # dry run
-
+ if opt.include_nms:
+ model.model[-1].include_nms = True
+ y = None
# TorchScript export
try:
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
@@ -75,16 +79,23 @@
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # size(1,3,640,640)
'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
- # Checks
- onnx_model = onnx.load(f) # load onnx model
- onnx.checker.check_model(onnx_model) # check onnx model
-
- # # Metadata
- # d = {'stride': int(max(model.stride))}
- # for k, v in d.items():
- # meta = onnx_model.metadata_props.add()
- # meta.key, meta.value = k, str(v)
- # onnx.save(onnx_model, f)
+ if opt.include_nms:
+ print('Registering NMS plugin...')
+ mo = RegisterNMS(f)
+ mo.register_nms()
+ mo.save(f)
+ else:
+ # Checks
+ onnx_model = onnx.load(f) # load onnx model
+ onnx.checker.check_model(onnx_model) # check onnx model
+ # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
+
+ # # Metadata
+ # d = {'stride': int(max(model.stride))}
+ # for k, v in d.items():
+ # meta = onnx_model.metadata_props.add()
+ # meta.key, meta.value = k, str(v)
+ # onnx.save(onnx_model, f)
if opt.simplify:
try:
@@ -95,11 +106,9 @@
assert check, 'assert check failed'
except Exception as e:
print(f'Simplifier failure: {e}')
- # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
print('ONNX export success, saved as %s' % f)
except Exception as e:
print('ONNX export failure: %s' % e)
-
# CoreML export
try:
import coremltools as ct
diff --git a/models/common.py b/models/common.py
index 53e3f87193..111af708de 100644
--- a/models/common.py
+++ b/models/common.py
@@ -236,7 +236,7 @@ def forward(self, x):
class ResX(Res):
# ResNet bottleneck
def __init__(self, c1, c2, shortcut=True, g=32, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
- super().__init__(c1, c2, shortcu, g, e)
+ super().__init__(c1, c2, shortcut, g, e)
c_ = int(c2 * e) # hidden channels
diff --git a/models/yolo.py b/models/yolo.py
index 5f45aad886..5d2845f1f0 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -5,7 +5,7 @@
sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)
-
+import torch
from models.common import *
from models.experimental import *
from utils.autoanchor import check_anchor_order
@@ -23,7 +23,7 @@
class Detect(nn.Module):
stride = None # strides computed during build
export = False # onnx export
-
+ include_nms = False
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
super(Detect, self).__init__()
self.nc = nc # number of classes
@@ -48,7 +48,6 @@ def forward(self, x):
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
-
y = x[i].sigmoid()
if not torch.onnx.is_in_onnx_export():
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy
@@ -59,13 +58,28 @@ def forward(self, x):
y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))
- return x if self.training else (torch.cat(z, 1), x)
+ if self.include_nms:
+ z = self.convert(z)
+
+ return x if self.training else (z, ) if self.include_nms else (torch.cat(z, 1), x)
@staticmethod
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
+    def convert(self, z):
+        """Build the ``(box, score)`` pair returned when ``include_nms`` is set.
+
+        Args:
+            z: sequence of tensors that are concatenated along dim 1. The last
+                dim is split as 4 box values, 1 confidence value, then the
+                remaining per-class values.
+        Returns:
+            Tuple ``(box, score)``: the 4 box values multiplied by the
+            conversion matrix, and the class values scaled by the confidence.
+        """
+        merged = torch.cat(z, 1)
+        box, objectness, score = merged[:, :, :4], merged[:, :, 4:5], merged[:, :, 5:]
+        # Scale every class value by the confidence of its prediction.
+        score *= objectness
+        # For box values (b0, b1, b2, b3) the product is
+        # (b0 - b2/2, b1 - b3/2, b0 + b2/2, b1 + b3/2)
+        # -- presumably center/size -> corner format; confirm against the NMS box_coding.
+        xywh_to_xyxy = torch.tensor(
+            [[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
+            dtype=torch.float32,
+            device=merged.device,
+        )
+        box @= xywh_to_xyxy
+        return (box, score)
+
class IDetect(nn.Module):
stride = None # strides computed during build
diff --git a/utils/add_nms.py b/utils/add_nms.py
new file mode 100644
index 0000000000..8cfa23919e
--- /dev/null
+++ b/utils/add_nms.py
@@ -0,0 +1,157 @@
+import numpy as np
+import onnx
+from onnx import shape_inference
+import onnx_graphsurgeon as gs
+import logging
+
+LOGGER = logging.getLogger(__name__)
+
+
+class RegisterNMS(object):
+    """Append an ``EfficientNMS_TRT`` plugin node to an exported ONNX model.
+
+    The model is loaded into an ONNX GraphSurgeon graph; ``register_nms``
+    feeds the graph's current outputs into the plugin node and makes the
+    plugin's outputs the new graph outputs; ``save`` writes the result.
+    """
+
+    def __init__(
+        self,
+        onnx_model_path: str,
+        precision: str = "fp32",
+    ):
+        """
+        Load the ONNX model into a GraphSurgeon graph and fold its constants.
+        Args:
+            onnx_model_path (str): Path of the ONNX model to load.
+            precision (str): "fp32" or "fp16"; selects the dtype declared for the
+                box and score outputs of the NMS node in ``register_nms``.
+        """
+        self.graph = gs.import_onnx(onnx.load(onnx_model_path))
+        assert self.graph
+        LOGGER.info("ONNX graph created successfully")
+        # Fold constants via ONNX-GS that PyTorch2ONNX may have missed
+        self.graph.fold_constants()
+        self.precision = precision
+        # Leading dimension declared on every NMS output tensor.
+        self.batch_size = 1
+
+    def infer(self):
+        """
+        Sanitize the graph by cleaning any unconnected nodes, do a topological resort,
+        and fold constant inputs values. When possible, run shape inference on the
+        ONNX graph to determine tensor shapes.
+        """
+        for _ in range(3):
+            count_before = len(self.graph.nodes)
+
+            self.graph.cleanup().toposort()
+            try:
+                for node in self.graph.nodes:
+                    for o in node.outputs:
+                        o.shape = None
+                model = gs.export_onnx(self.graph)
+                model = shape_inference.infer_shapes(model)
+                self.graph = gs.import_onnx(model)
+            except Exception as e:
+                LOGGER.info(f"Shape inference could not be performed at this time:\n{e}")
+            try:
+                self.graph.fold_constants(fold_shapes=True)
+            except TypeError as e:
+                LOGGER.error(
+                    "This version of ONNX GraphSurgeon does not support folding shapes, "
+                    f"please upgrade your onnx_graphsurgeon module. Error:\n{e}"
+                )
+                raise
+
+            count_after = len(self.graph.nodes)
+            if count_before == count_after:
+                # No new folding occurred in this iteration, so we can stop for now.
+                break
+
+    def save(self, output_path):
+        """
+        Save the ONNX model to the given location.
+        Args:
+            output_path: Path pointing to the location where to write
+                out the updated ONNX model.
+        """
+        self.graph.cleanup().toposort()
+        model = gs.export_onnx(self.graph)
+        onnx.save(model, output_path)
+        LOGGER.info(f"Saved ONNX model to {output_path}")
+
+    def register_nms(
+        self,
+        *,
+        score_thresh: float = 0.25,
+        nms_thresh: float = 0.45,
+        detections_per_img: int = 100,
+    ):
+        """
+        Register the ``EfficientNMS_TRT`` plugin node.
+        NMS expects these shapes for its input tensors:
+            - box_net: [batch_size, number_boxes, 4]
+            - class_net: [batch_size, number_boxes, number_labels]
+        Args:
+            score_thresh (float): The scalar threshold for score (low scoring boxes are removed).
+            nms_thresh (float): The scalar threshold for IOU (new boxes that have high IOU
+                overlap with previously selected boxes are removed).
+            detections_per_img (int): Number of best detections to keep after NMS.
+        Raises:
+            NotImplementedError: If ``self.precision`` is neither "fp32" nor "fp16".
+        """
+
+        self.infer()
+        # The current graph outputs are used as the inputs of the NMS node.
+        op_inputs = self.graph.outputs
+        op = "EfficientNMS_TRT"
+        attrs = {
+            "plugin_version": "1",
+            "background_class": -1, # no background class
+            "max_output_boxes": detections_per_img,
+            "score_threshold": score_thresh,
+            "iou_threshold": nms_thresh,
+            "score_activation": False,
+            "box_coding": 0,
+        }
+
+        if self.precision == "fp32":
+            dtype_output = np.float32
+        elif self.precision == "fp16":
+            dtype_output = np.float16
+        else:
+            raise NotImplementedError(f"Currently not supports precision: {self.precision}")
+
+        # NMS Outputs
+        output_num_detections = gs.Variable(
+            name="num_detections",
+            dtype=np.int32,
+            shape=[self.batch_size, 1],
+        ) # A scalar indicating the number of valid detections per batch image.
+        output_boxes = gs.Variable(
+            name="detection_boxes",
+            dtype=dtype_output,
+            shape=[self.batch_size, detections_per_img, 4],
+        )
+        output_scores = gs.Variable(
+            name="detection_scores",
+            dtype=dtype_output,
+            shape=[self.batch_size, detections_per_img],
+        )
+        output_labels = gs.Variable(
+            name="detection_classes",
+            dtype=np.int32,
+            shape=[self.batch_size, detections_per_img],
+        )
+
+        op_outputs = [output_num_detections, output_boxes, output_scores, output_labels]
+
+        # Create the NMS Plugin node with the selected inputs. The outputs of the node will also
+        # become the final outputs of the graph.
+        self.graph.layer(op=op, name="batched_nms", inputs=op_inputs, outputs=op_outputs, attrs=attrs)
+        LOGGER.info(f"Created NMS plugin '{op}' with attributes: {attrs}")
+
+        self.graph.outputs = op_outputs
+
+        self.infer()
\ No newline at end of file