Skip to content

Commit

Permalink
Merge branch 'meituan:main' into trtmAP
Browse files Browse the repository at this point in the history
  • Loading branch information
triple-Mu authored Jul 13, 2022
2 parents f741283 + a1f5f97 commit 16b929e
Show file tree
Hide file tree
Showing 11 changed files with 483 additions and 36 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ Your can also specify a checkpoint path to `--resume` parameter by

* [Train custom data](./docs/Train_custom_data.md)
* [Test speed](./docs/Test_speed.md)

* [Tutorial of RepOpt for YOLOv6](./docs/tutorial_repopt.md)

## Benchmark

Expand Down
56 changes: 56 additions & 0 deletions configs/repopt/yolov6s_hs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# YOLOv6s model
model = dict(
type='YOLOv6s',
pretrained=None,
depth_multiple=0.33,
width_multiple=0.50,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='siou'
)
)

solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.01,
lrf=0.01,
momentum=0.937,
weight_decay=0.0005,
warmup_epochs=3.0,
warmup_momentum=0.8,
warmup_bias_lr=0.1
)

data_aug = dict(
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
degrees=0.0,
translate=0.1,
scale=0.5,
shear=0.0,
flipud=0.0,
fliplr=0.5,
mosaic=1.0,
mixup=0.0,
)

# Choose Rep-block by the training Mode, choices=["repvgg", "hyper-search", "repopt"]
training_mode='hyper_search'
56 changes: 56 additions & 0 deletions configs/repopt/yolov6s_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# YOLOv6s model
model = dict(
type='YOLOv6s',
pretrained='./assets/yolov6s_scale.pt',
depth_multiple=0.33,
width_multiple=0.50,
backbone=dict(
type='EfficientRep',
num_repeats=[1, 6, 12, 18, 6],
out_channels=[64, 128, 256, 512, 1024],
),
neck=dict(
type='RepPAN',
num_repeats=[12, 12, 12, 12],
out_channels=[256, 128, 128, 256, 256, 512],
),
head=dict(
type='EffiDeHead',
in_channels=[128, 256, 512],
num_layers=3,
begin_indices=24,
anchors=1,
out_indices=[17, 20, 23],
strides=[8, 16, 32],
iou_type='siou'
)
)

solver = dict(
optim='SGD',
lr_scheduler='Cosine',
lr0=0.01,
lrf=0.01,
momentum=0.937,
weight_decay=0.0005,
warmup_epochs=3.0,
warmup_momentum=0.8,
warmup_bias_lr=0.1
)

data_aug = dict(
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
degrees=0.0,
translate=0.1,
scale=0.5,
shear=0.0,
flipud=0.0,
fliplr=0.5,
mosaic=1.0,
mixup=0.0,
)

# Choose Rep-block by the training Mode, choices=["repvgg", "hyper-search", "repopt"]
training_mode='repopt'
29 changes: 29 additions & 0 deletions docs/tutorial_repopt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# RepOpt version implementation of YOLOv6
## Introduction
This is a RepOpt-version implementation of YOLOv6 according to RepOptimizer: https://arxiv.org/pdf/2205.15242.pdf @DingXiaoH \
It shows some advantages:
1. With only minor changes. it is compatible with the original repvgg version, and it is easy to reproduce the precision comparable with original version.
2. No more train/deploy transform. The target network is consistent when training and deploying.
3. A slight training acceleration of about 8%.
4. Last and the most important, It is quantization friendly. Compared to the original version, the mAP decrease of PTQ can be greatly improved. Furthermore, the architecture of RepOptimizer is friendly to wrap quant-models for QAT.

## Training
The training of V6-RepOpt can be divided into two stages, hyperparameter search and target network training.
1. hyperparameter search. This stage is used to get a suitable 'scale' for RepOptimizer, and the result checkpoint can be passed to stage2. Remember to add `training_mode='hyper_search'` in your config.
```
python tools/train.py --batch 32 --conf configs/repopt/yolov6s_hs.py --data data/coco.yaml --device 0
```
Or you can directly use the [pretrained scale](https://github.com/xingyueye/YOLOv6/releases/download/0.1.0/yolov6s_scale.pt) we provided and omit this stage.

2. Training. Add the flag of `training_mode='repopt'` and pretraind model `pretrained='./assets/yolov6s_scale.pt',` in your config
```
python tools/train.py --batch 32 --conf configs/repopt/yolov6s_opt.py --data data/coco.yaml --device 0
```
## Evaluation
Reproduce mAP on COCO val2017 dataset, you can directly test our [pretrained model](https://github.com/xingyueye/YOLOv6/releases/download/0.1.0/yolov6s_opt.pt).
```
python tools/eval.py --data data/coco.yaml --batch 32 --weights yolov6s_opt.pt --task val
```
## Benchmark
We train a yolov6s-repopt with 300epochs, the fp32 mAP is 42.4, while the mAP of PTQ is 40.5. More results is coming soon...

2 changes: 2 additions & 0 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def check_and_init(args):
os.makedirs(args.save_dir)

cfg = Config.fromfile(args.conf_file)
if not hasattr(cfg, 'training_mode'):
setattr(cfg, 'training_mode', 'repvgg')
# check device
device = select_device(args.device)
# set random seed
Expand Down
26 changes: 22 additions & 4 deletions yolov6/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from yolov6.utils.ema import ModelEMA, de_parallel
from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
from yolov6.solver.build import build_optimizer, build_lr_scheduler
from yolov6.utils.RepOptimizer import extract_scales, RepVGGOptimizer


class Trainer:
Expand All @@ -43,7 +44,11 @@ def __init__(self, args, cfg, device):
self.train_loader, self.val_loader = self.get_data_loader(args, cfg, self.data_dict)
# get model and optimizer
model = self.get_model(args, cfg, self.num_classes, device)
self.optimizer = self.get_optimizer(args, cfg, model)
if cfg.training_mode == 'repopt':
scales = self.load_scale_from_pretrained_models(cfg, device)
self.optimizer = RepVGGOptimizer(model, scales, args, cfg)
else:
self.optimizer = self.get_optimizer(args, cfg, model)
self.scheduler, self.lf = self.get_lr_scheduler(args, cfg, self.optimizer)
self.ema = ModelEMA(model) if self.main_process else None
# tensorboard
Expand Down Expand Up @@ -239,12 +244,25 @@ def prepro_data(batch_data, device):
def get_model(self, args, cfg, nc, device):
model = build_model(cfg, nc, device)
weights = cfg.model.pretrained
if weights: # finetune if pretrained model is set
LOGGER.info(f'Loading state_dict from {weights} for fine-tuning...')
model = load_state_dict(weights, model, map_location=device)
if cfg.training_mode == 'repvgg' and weights:
if weights: # finetune if pretrained model is set
LOGGER.info(f'Loading state_dict from {weights} for fine-tuning...')
model = load_state_dict(weights, model, map_location=device)
LOGGER.info('Model: {}'.format(model))
return model

@staticmethod
def load_scale_from_pretrained_models(cfg, device):
weights = cfg.model.pretrained
scales = None
if not weights:
LOGGER.warning("Training RepOpt Architecture without Searched Hyper Scales")
else:
ckpt = torch.load(weights, map_location=device)
scales = extract_scales(ckpt)
return scales


@staticmethod
def parallel_model(args, model, device):
# If DP mode
Expand Down
113 changes: 97 additions & 16 deletions yolov6/layers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.init as init
import torch.nn.functional as F
from yolov6.layers.dbb_transforms import *

Expand Down Expand Up @@ -118,22 +120,6 @@ def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
return result


class RepBlock(nn.Module):
'''
RepBlock is a stage block with rep-style basic block
'''
def __init__(self, in_channels, out_channels, n=1):
super().__init__()
self.conv1 = RepVGGBlock(in_channels, out_channels)
self.block = nn.Sequential(*(RepVGGBlock(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None

def forward(self, x):
x = self.conv1(x)
if self.block is not None:
x = self.block(x)
return x


class RepVGGBlock(nn.Module):
'''RepVGGBlock is a basic rep-style block, including training and deploy status
This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
Expand Down Expand Up @@ -254,6 +240,74 @@ def switch_to_deploy(self):
self.deploy = True


class RealVGGBlock(nn.Module):

def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, padding_mode='zeros', use_se=False,
):
super(RealVGGBlock, self).__init__()
self.relu = nn.ReLU()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(out_channels)

if use_se:
raise NotImplementedError("se block not supported yet")
else:
self.se = nn.Identity()

def forward(self, inputs):
out = self.relu(self.se(self.bn(self.conv(inputs))))
return out

class ScaleLayer(torch.nn.Module):

def __init__(self, num_features, use_bias=True, scale_init=1.0):
super(ScaleLayer, self).__init__()
self.weight = Parameter(torch.Tensor(num_features))
init.constant_(self.weight, scale_init)
self.num_features = num_features
if use_bias:
self.bias = Parameter(torch.Tensor(num_features))
init.zeros_(self.bias)
else:
self.bias = None

def forward(self, inputs):
if self.bias is None:
return inputs * self.weight.view(1, self.num_features, 1, 1)
else:
return inputs * self.weight.view(1, self.num_features, 1, 1) + self.bias.view(1, self.num_features, 1, 1)

# A CSLA block is a LinearAddBlock with is_csla=True
class LinearAddBlock(nn.Module):

def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
dilation=1, groups=1, padding_mode='zeros', use_se=False, is_csla=False, conv_scale_init=1.0):
super(LinearAddBlock, self).__init__()
self.in_channels = in_channels
self.relu = nn.ReLU()
self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
self.scale_conv = ScaleLayer(num_features=out_channels, use_bias=False, scale_init=conv_scale_init)
self.conv_1x1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
self.scale_1x1 = ScaleLayer(num_features=out_channels, use_bias=False, scale_init=conv_scale_init)
if in_channels == out_channels and stride == 1:
self.scale_identity = ScaleLayer(num_features=out_channels, use_bias=False, scale_init=1.0)
self.bn = nn.BatchNorm2d(out_channels)
if is_csla: # Make them constant
self.scale_1x1.requires_grad_(False)
self.scale_conv.requires_grad_(False)
if use_se:
raise NotImplementedError("se block not supported yet")
else:
self.se = nn.Identity()

def forward(self, inputs):
out = self.scale_conv(self.conv(inputs)) + self.scale_1x1(self.conv_1x1(inputs))
if hasattr(self, 'scale_identity'):
out += self.scale_identity(inputs)
out = self.relu(self.se(self.bn(out)))
return out

def conv_bn_v2(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
padding_mode='zeros'):
conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
Expand Down Expand Up @@ -499,3 +553,30 @@ def forward(self, im, val=False):
if isinstance(y, np.ndarray):
y = torch.tensor(y, device=self.device)
return y


class RepBlock(nn.Module):
'''
RepBlock is a stage block with rep-style basic block
'''
def __init__(self, in_channels, out_channels, n=1, block=RepVGGBlock):
super().__init__()
self.conv1 = block(in_channels, out_channels)
self.block = nn.Sequential(*(block(out_channels, out_channels) for _ in range(n - 1))) if n > 1 else None

def forward(self, x):
x = self.conv1(x)
if self.block is not None:
x = self.block(x)
return x


def get_block(mode):
if mode == 'repvgg':
return RepVGGBlock
elif mode == 'hyper_search':
return LinearAddBlock
elif mode == 'repopt':
return RealVGGBlock
else:
raise NotImplementedError("Undefied Repblock choice for mode {}".format(mode))
Loading

0 comments on commit 16b929e

Please sign in to comment.