Merge pull request meituan#1 from meituan/main
Pull newest code
xingyueye authored Jul 4, 2022
2 parents d62fbf6 + 5bd6686 commit 8398760
Showing 12 changed files with 289 additions and 33 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -96,10 +96,10 @@ python tools/eval.py --data data/coco.yaml --batch 32 --weights yolov6s.pt --tas
- Comparisons of the mAP and speed of different object detectors are tested on [COCO val2017](https://cocodataset.org/#download) dataset.
- Refer to [Test speed](./docs/Test_speed.md) tutorial to reproduce the speed results of YOLOv6.
- Params and Flops of YOLOv6 are estimated on deployed model.
- Speed results of other methods are tested in our environment using official codebase and model if not found from the corresponding official release.

## Third-party resources
* YOLOv6 NCNN Android app demo: [ncnn-android-yolov6](https://github.com/FeiGeChuanShu/ncnn-android-yolov6) from [FeiGeChuanShu](https://github.com/FeiGeChuanShu)
* YOLOv6 ONNXRuntime/MNN/TNN C++: [YOLOv6-ORT](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/ort/cv/yolov6.cpp), [YOLOv6-MNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_yolov6.cpp) and [YOLOv6-TNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_yolov6.cpp) from [DefTruth](https://github.com/DefTruth)
* YOLOv6 TensorRT Python: [yolov6-tensorrt-python](https://github.com/Linaom1214/tensorrt-python/blob/main/yolov6/trt.py) from [Linaom1214](https://github.com/Linaom1214)
* YOLOv6 TensorRT Windows C++: [yolort](https://github.com/zhiqwang/yolov5-rt-stack/tree/main/deployment/tensorrt-yolov6) from [Wei Zeng](https://github.com/Wulingtian)
93 changes: 87 additions & 6 deletions deploy/ONNX/README.md
@@ -1,17 +1,98 @@
# Export ONNX Model

## Check requirements
```shell
pip install 'onnx>=1.10.0'
```

## Export script
```shell
python ./deploy/ONNX/export_onnx.py \
--weights yolov6s.pt \
--img 640 \
--batch 1
```


#### Description of all arguments

- `--weights` : path of the YOLOv6 model weights.
- `--img` : image size of the model input.
- `--batch` : batch size of the model input.
- `--half` : whether to export a half-precision (FP16) model.
- `--inplace` : whether to set Detect() inplace=True.
- `--simplify` : whether to simplify the ONNX graph. Not supported in end-to-end export.
- `--end2end` : whether to export an end-to-end ONNX model. Only supported for onnxruntime and TensorRT >= 8.0.0.
- `--max-wh` : keep the default None for the TensorRT backend; set an integer (e.g. 7680) for the onnxruntime backend.
- `--topk-all` : keep the top-k objects for every image.
- `--iou-thres` : IoU threshold for the NMS algorithm.
- `--conf-thres` : confidence threshold for the NMS algorithm.
- `--device` : export device: CUDA device, i.e. `0` or `0,1,2,3`, or `cpu`.
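
For example, a half-precision, simplified export on GPU 0 might look like the sketch below (a combination of the flags above; adjust the weights path to your setup):

```shell
python ./deploy/ONNX/export_onnx.py \
    --weights yolov6s.pt \
    --img 640 \
    --batch 1 \
    --half \
    --simplify \
    --device 0
```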

## Download

* [YOLOv6-nano](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6n.onnx)
* [YOLOv6-tiny](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6t.onnx)
* [YOLOv6-s](https://github.com/meituan/YOLOv6/releases/download/0.1.0/yolov6s.onnx)

## End2End export

YOLOv6 now supports end-to-end detection for both onnxruntime and TensorRT!

If you want to deploy with TensorRT, make sure you have installed TensorRT >= 8.0.0.

### onnxruntime backend
#### Usage

```bash
python ./deploy/ONNX/export_onnx.py \
--weights yolov6s.pt \
--img 640 \
--batch 1 \
--end2end \
--max-wh 7680
```

You will get an ONNX model with the **NonMaxSuppression** operator.

The ONNX output has shape `nums x 7`, where `nums` is the total number of detected objects and each row is [`batch_index`, `x0`, `y0`, `x1`, `y1`, `classid`, `score`].
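
As a quick sanity check, here is a minimal onnxruntime sketch (assuming an FP32 export, the usual 0-1 input normalization, and that `onnxruntime` and `opencv-python` are installed; the `image_arrays`/`outputs` names match this exporter):

```python
import cv2
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("yolov6s.onnx")
img = cv2.imread("test.jpg")                            # hypothetical test image
img = cv2.resize(img, (640, 640))[:, :, ::-1]           # BGR -> RGB, match --img 640
x = np.ascontiguousarray(img.transpose(2, 0, 1))[None]  # HWC -> NCHW
x = x.astype(np.float32) / 255.0
dets = sess.run(["outputs"], {"image_arrays": x})[0]    # shape [nums, 7]
for batch_index, x0, y0, x1, y1, classid, score in dets:
    print(int(batch_index), int(classid), float(score), (x0, y0, x1, y1))
```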

### TensorRT backend (TensorRT version >= 8.0.0)

#### Usage

```bash
python ./deploy/ONNX/export_onnx.py \
--weights yolov6s.pt \
--img 640 \
--batch 1 \
--end2end
```

You will get an ONNX model with the **[EfficientNMS_TRT](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin)** plugin.
The ONNX outputs are as shown below:

<img src="https://user-images.githubusercontent.com/92794867/176650971-a4fa3d65-10d4-4b65-b8ef-00a2ff13406c.png" height="300px" />

`num_dets` is the number of detected objects in each image of the batch.

`det_boxes` holds the top-k (100) boxes per image as [`x0`, `y0`, `x1`, `y1`].

`det_scores` holds the confidence score of each of the top-k (100) objects.

`det_classes` holds the category of each of the top-k (100) objects.
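
For illustration, a minimal post-processing sketch (the `decode` helper is hypothetical, not part of this repo) that slices the four outputs down to the valid detections of one image:

```python
def decode(num_dets, det_boxes, det_scores, det_classes, batch_index=0):
    """Keep only the valid detections for one image of the batch."""
    n = int(num_dets[batch_index][0])       # number of valid objects in this image
    boxes = det_boxes[batch_index][:n]      # [n, 4] as x0, y0, x1, y1
    scores = det_scores[batch_index][:n]    # [n]
    classes = det_classes[batch_index][:n]  # [n]
    return boxes, scores, classes
```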


You can build a TensorRT engine from this model with the [trtexec](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#trtexec-ovr) tool.
#### Usage
```shell
/path/to/trtexec \
    --onnx=yolov6s.onnx \
    --saveEngine=yolov6s.engine \
    --fp16  # optional: build an FP16 engine
```
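
Assuming the build succeeds, you can later benchmark the saved engine directly with `trtexec --loadEngine=yolov6s.engine` (see the trtexec documentation for the full flag list).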
28 changes: 24 additions & 4 deletions deploy/ONNX/export_onnx.py
@@ -28,6 +28,11 @@
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--end2end', action='store_true', help='export end2end onnx')
parser.add_argument('--max-wh', type=int, default=None, help='None for trt int for ort')
parser.add_argument('--topk-all', type=int, default=100, help='topk objects for every image')
parser.add_argument('--iou-thres', type=float, default=0.45, help='iou threshold for NMS')
parser.add_argument('--conf-thres', type=float, default=0.25, help='conf threshold for NMS')
parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
args = parser.parse_args()
args.img_size *= 2 if len(args.img_size) == 1 else 1 # expand
@@ -57,6 +62,10 @@
m.act = SiLU()
elif isinstance(m, Detect):
m.inplace = args.inplace
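# The End2End wrapper below appends an NMS head to the detector:
# ONNX NonMaxSuppression when --max-wh is set (onnxruntime backend),
# the EfficientNMS_TRT plugin when --max-wh is None (TensorRT backend).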
if args.end2end:
from yolov6.models.end2end import End2End
model = End2End(model, max_obj=args.topk_all, iou_thres=args.iou_thres,
score_thres=args.conf_thres, max_wh=args.max_wh, device=device)

y = model(img) # dry run

@@ -65,16 +74,23 @@
LOGGER.info('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx') # filename
with BytesIO() as f:
torch.onnx.export(model, img, f, verbose=False, opset_version=12,
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True,
input_names=['image_arrays'],
output_names=['num_dets', 'det_boxes', 'det_scores', 'det_classes']
if args.end2end and args.max_wh is None else ['outputs'],)
f.seek(0)
# Checks
onnx_model = onnx.load(f) # load onnx model
onnx.checker.check_model(onnx_model) # check onnx model
# Fix output shape
if args.end2end and args.max_wh is None:
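# dim order across the four outputs: num_dets [N, 1], det_boxes [N, topk, 4],
# det_scores [N, topk], det_classes [N, topk], where N is the batch size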
shapes = [args.batch_size, 1, args.batch_size, args.topk_all, 4,
args.batch_size, args.topk_all, args.batch_size, args.topk_all]
for i in onnx_model.graph.output:
for j in i.type.tensor_type.shape.dim:
j.dim_param = str(shapes.pop(0))
if args.simplify:
try:
import onnxsim
@@ -83,10 +99,14 @@
assert check, 'assert check failed'
except Exception as e:
LOGGER.info(f'Simplifier failure: {e}')
onnx.save(onnx_model, export_file)
LOGGER.info(f'ONNX export success, saved as {export_file}')
except Exception as e:
LOGGER.info(f'ONNX export failure: {e}')

# Finish
LOGGER.info('\nExport complete (%.2fs)' % (time.time() - t))
if args.end2end:
if args.max_wh is None:
LOGGER.info('\nYou can export a TensorRT engine with the trtexec tool.\nThe command is:')
LOGGER.info(f'trtexec --onnx={export_file} --saveEngine={export_file.replace(".onnx",".engine")}')
12 changes: 6 additions & 6 deletions docs/Test_speed.md
@@ -10,31 +10,31 @@ Download the models you want to test from the latest release.

Refer to the README to install the packages matching your CUDA, cuDNN and TensorRT versions.

Here, we use Torch 1.8.0 for inference on V100, and TensorRT 7.2 with CUDA 10.2 and cuDNN 8.0.2 on T4.

## 2. Reproduce speed

#### 2.1 Torch Inference on V100

To get inference speed without TensorRT on V100, you can run the following command:

```shell
python tools/eval.py --data data/coco.yaml --batch 32 --weights yolov6n.pt --task speed [--half]
```

- Speed results with batchsize = 1 are unstable in multiple runs, thus we do not provide the bs1 speed results.

#### 2.2 TensorRT Inference on T4

To get inference speed with TensorRT in FP16 mode on T4, you can follow the steps below:

First, export the PyTorch model to ONNX format using the following command:

```shell
python deploy/ONNX/export_onnx.py --weights yolov6n.pt --device 0 --batch [1 or 32]
```

Second, generate a TensorRT inference engine and test its speed using `trtexec`:

```
trtexec --onnx=yolov6n.onnx --workspace=1024 --avgRuns=1000 --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw
```
14 changes: 8 additions & 6 deletions tools/train.py
@@ -22,20 +22,20 @@ def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Training', add_help=add_help)
parser.add_argument('--data-path', default='./data/coco.yaml', type=str, help='path of dataset')
parser.add_argument('--conf-file', default='./configs/yolov6s.py', type=str, help='experiments description file')
parser.add_argument('--img-size', default=640, type=int, help='train, val image size (pixels)')
parser.add_argument('--batch-size', default=32, type=int, help='total batch size for all GPUs')
parser.add_argument('--epochs', default=400, type=int, help='number of total epochs to run')
parser.add_argument('--workers', default=8, type=int, help='number of data loading workers (default: 8)')
parser.add_argument('--device', default='0', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--eval-interval', default=20, type=int, help='evaluate at every interval epochs')
parser.add_argument('--eval-final-only', action='store_true', help='only evaluate at the final epoch')
parser.add_argument('--heavy-eval-range', default=50, type=int,
help='evaluating every epoch for last such epochs (can be jointly used with --eval-interval)')
parser.add_argument('--check-images', action='store_true', help='check images when initializing datasets')
parser.add_argument('--check-labels', action='store_true', help='check label files when initializing datasets')
parser.add_argument('--output-dir', default='./runs/train', type=str, help='path to save outputs')
parser.add_argument('--name', default='exp', type=str, help='experiment name, saved to output_dir/name')
parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
parser.add_argument('--gpu_count', type=int, default=0)
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter')
parser.add_argument('--resume', type=str, default=None, help='resume the corresponding ckpt')
@@ -47,8 +47,8 @@ def check_and_init(args):
'''check config files and device, and initialize '''

# check files
master_process = args.rank == 0 if args.world_size > 1 else args.rank == -1
args.save_dir = str(increment_name(osp.join(args.output_dir, args.name)))
cfg = Config.fromfile(args.conf_file)

# check device
@@ -58,7 +58,9 @@
set_random_seed(1+args.rank, deterministic=(args.rank == -1))

# save args
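# only the master process (rank 0, or -1 when not distributed) writes outputs,
# so DDP workers do not create directories or files concurrently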
if master_process:
os.makedirs(args.save_dir)
save_yaml(vars(args), osp.join(args.save_dir, 'args.yaml'))

return cfg, device

2 changes: 1 addition & 1 deletion yolov6/core/engine.py
@@ -55,7 +55,7 @@ def __init__(self, args, cfg, device):
assert os.path.isfile(args.resume), 'ERROR: --resume checkpoint does not exist'
self.ckpt = torch.load(args.resume, map_location='cpu')
self.start_epoch = self.ckpt['epoch'] + 1

self.max_epoch = args.epochs
self.max_stepnum = len(self.train_loader)
self.batch_size = args.batch_size
2 changes: 1 addition & 1 deletion yolov6/core/evaler.py
@@ -103,7 +103,7 @@ def predict_model(self, model, dataloader, task):
def eval_model(self, pred_results, model, dataloader, task):
'''Evaluate models
For task speed, this function only evaluates the speed of model and outputs inference time.
For task val, this function evaluates the speed and mAP by pycocotools, and returns
inference time and mAP value.
'''
LOGGER.info(f'\nEvaluating speed.')
2 changes: 1 addition & 1 deletion yolov6/core/inferer.py
@@ -102,7 +102,7 @@ def precess_image(path, img_size, stride, half):
img_src = cv2.imread(path)
assert img_src is not None, f'Invalid image: {path}'
except Exception as e:
LOGGER.warning(e)
image = letterbox(img_src, img_size, stride=stride)[0]

# Convert
4 changes: 2 additions & 2 deletions yolov6/data/datasets.py
@@ -284,7 +284,7 @@ def get_imgs_labels(self, img_dir):
ne_per_file,
msg,
) in pbar:
if nc_per_file == 0:
img_info[img_path]["labels"] = labels_per_file
else:
img_info.pop(img_path)
@@ -484,7 +484,7 @@ def check_label_files(args):
except Exception as e:
nc = 1
msg = f"WARNING: {lb_path}: ignoring invalid labels: {e}"
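# return img_path (not None) so the caller can pop this image from
# img_info when its labels are invalid (nc != 0)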
return img_path, None, nc, nm, nf, ne, msg

@staticmethod
def generate_coco_format_labels(img_info, class_names, save_path):
Expand Down
2 changes: 1 addition & 1 deletion yolov6/layers/dbb_transforms.py
@@ -47,4 +47,4 @@ def transV_avg(channels, kernel_size, groups):
def transVI_multiscale(kernel, target_kernel_size):
H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])