
Commit

add det v4 mobile
WenmuZhou committed Aug 28, 2023
1 parent 8b8b34a commit 96d88ad
Showing 16 changed files with 556 additions and 94 deletions.
42 changes: 19 additions & 23 deletions configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
@@ -1,22 +1,20 @@
Global:
debug: false
use_gpu: true
device: gpu
epoch_num: &epoch_num 500
log_smooth_window: 20
print_batch_step: 100
output_dir: ./output/ch_PP-OCRv4
save_epoch_step: 10
eval_batch_step:
- 0
- 1500
print_batch_step: 1
output_dir: ./output/ch_PP-OCRv4/
eval_epoch_step: [0, 1]
cal_metric_during_train: false
pretrained_model:
checkpoints:
pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/PPLCNetV3_x0_75_ocr_det.pdparams
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
distributed: true
use_tensorboard: false
infer_img: doc/imgs/1.jpg

Export:
export_dir:
export_shape: [ 1, 3, 640, 640 ]
dynamic_axes: [ 0, 2, 3 ]

Architecture:
model_type: det
@@ -44,15 +42,13 @@ Loss:

Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001 #(8*8c)
warmup_epoch: 2
regularizer:
name: L2
factor: 5.0e-05
lr: 0.001 #(8*8c)
weight_decay: 5.0e-05

LRScheduler:
name: CosineAnnealingLR
warmup_epoch: 2


PostProcess:
name: DBPostProcess
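A note on the config change above: the v4 student config replaces the v3 nested `lr`/`regularizer` blocks with a flat `Optimizer` section (`lr`, `weight_decay`) plus a separate `LRScheduler` section, and adds an `Export` block describing the export shape and dynamic axes. A minimal sketch of how such a split config could be turned into PyTorch objects; `build_optim_from_cfg` is an illustrative helper, not the repository's actual factory, and warmup is assumed to be handled by whatever wrapper the trainer applies:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

def build_optim_from_cfg(model, optim_cfg, sched_cfg, epochs, steps_per_epoch):
    """Hypothetical helper: map the flat Optimizer/LRScheduler config onto torch objects."""
    # Optimizer section, e.g. {'name': 'Adam', 'lr': 0.001, 'weight_decay': 5.0e-05}
    opt_cls = getattr(torch.optim, optim_cfg['name'])
    optimizer = opt_cls(model.parameters(), lr=optim_cfg['lr'],
                        weight_decay=optim_cfg.get('weight_decay', 0.0))
    # LRScheduler section, e.g. {'name': 'CosineAnnealingLR', 'warmup_epoch': 2}
    if sched_cfg['name'] == 'CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs * steps_per_epoch)
    else:
        raise NotImplementedError(sched_cfg['name'])
    return optimizer, scheduler

# usage sketch with a dummy model
model = torch.nn.Linear(8, 2)
optimizer, scheduler = build_optim_from_cfg(
    model,
    {'name': 'Adam', 'lr': 0.001, 'weight_decay': 5.0e-05},
    {'name': 'CosineAnnealingLR', 'warmup_epoch': 2},
    epochs=500, steps_per_epoch=100)
```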
28 changes: 27 additions & 1 deletion convert_params_compute_diff.py
@@ -118,6 +118,32 @@ def conver_params(model_config, paddle_params_path, tmp_dir, show_log=False):
print(f"save convert torch params to {torch_params_path}")
return paddle_params_path, torch_params_path

def torch2paddle(torch_model: torch.nn.Module, paddle_model: paddle.nn.Layer):
paddle_state_dict = paddle_model.state_dict()
# paddle_state_dict = paddle.load(paddle_model)
fc_names = ["classifier"]
torch_state_dict = {}
for k in paddle_state_dict:
v = paddle_state_dict[k].detach().cpu().numpy()
flag = [i in k for i in fc_names]
if any(flag) and "weight" in k: # ignore bias
new_shape = [1, 0] + list(range(2, v.ndim))
print(f"name: {k}, ori shape: {v.shape}, new shape: {v.transpose(new_shape).shape}")
v = v.transpose(new_shape)
k = k.replace("_variance", "running_var")
k = k.replace("_mean", "running_mean")
torch_state_dict[k] = torch.from_numpy(v)

for k in torch_state_dict:
if k not in torch_model.state_dict():
print(f'{k} is not in torch model')
for k in torch_model.state_dict():
if 'num_batches_tracked' in k:
continue
if k not in torch_state_dict:
print(f'{k} is not in torch params')
torch_model.load_state_dict(torch_state_dict)

def get_input(w, h, color=True):
img = cv2.imread("doc/imgs/1.jpg", 1 if color else 0)
img = cv2.resize(img, (w, h))
@@ -175,7 +201,7 @@ def main():

tmp_dir = './tmp'
os.makedirs(tmp_dir, exist_ok=True)
config_path = "configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml"
config_path = "configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml"
paddle_params_path = ''

config = load_config(config_path)
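The `torch2paddle` helper added above (despite its name, it copies a Paddle state dict into a torch model) transposes fully connected weights, renames BatchNorm statistics (`_mean`/`_variance` to `running_mean`/`running_var`), and reports keys that do not line up before calling `load_state_dict`. Below is a minimal standalone sketch of the same renaming and transpose step using plain numpy arrays, so it runs without Paddle installed; the keys in the fake state dict are illustrative:

```python
import numpy as np
import torch

def paddle_keys_to_torch(paddle_np_state, fc_names=("classifier",)):
    """Rename Paddle-style keys and transpose fc weights into torch conventions."""
    torch_state = {}
    for k, v in paddle_np_state.items():
        # Paddle stores Linear weights as (in, out); torch expects (out, in).
        if any(name in k for name in fc_names) and "weight" in k:
            v = v.transpose([1, 0] + list(range(2, v.ndim)))
        k = k.replace("_variance", "running_var").replace("_mean", "running_mean")
        torch_state[k] = torch.from_numpy(v)
    return torch_state

# usage sketch with a fake state dict
fake = {
    "classifier.weight": np.random.rand(16, 10).astype("float32"),
    "bn._mean": np.zeros(16, dtype="float32"),
    "bn._variance": np.ones(16, dtype="float32"),
}
converted = paddle_keys_to_torch(fake)
print(sorted(converted))                      # ['bn.running_mean', 'bn.running_var', 'classifier.weight']
print(converted["classifier.weight"].shape)   # torch.Size([10, 16])
```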
16 changes: 12 additions & 4 deletions readme.md
@@ -5,6 +5,7 @@
- [Model alignment info](#模型对齐信息)
- [Environment](#环境)
- [Alignment list](#对齐列表)
- [TODO](#todo)
- [Usage](#使用方式)
- [Data preparation](#数据准备)
- [train](#train)
@@ -34,6 +35,7 @@

| Model | Aligned | Alignment error | Config |
|---|---|---|---|
| ch_PP-OCRv4_det_student | Y | 0 | [config](configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml) |
| ch_PP-OCRv3_rec | Y | 4.615016e-11 | [config](configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml) |
| ch_PP-OCRv3_rec_distillation.yml | Y | Teacher_head_out_res 7.470646e-10 <br> Student_head_out_res 4.615016e-11 | [config](configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml) |
| ch_PP-OCRv3_det_student | Y | 1.766314e-07 | [config](configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student.yml) |
@@ -56,8 +58,10 @@

Models:

- [ ] PP-OCRv4 det
- [ ] PP-OCRv4 rec
- [x] PP-OCRv4 det mobile
- [ ] PP-OCRv4 det server
- [ ] PP-OCRv4 rec mobile
- [ ] PP-OCRv4 rec server
- [ ] DB
- [ ] DB ++
- [ ] CRNN
@@ -71,14 +75,18 @@
### train

```sh
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
# single GPU
CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

# multi-GPU
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
```


### eval

```sh
python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
```


3 changes: 2 additions & 1 deletion tools/infer_det.py
@@ -32,6 +32,7 @@ def build_det_process(cfg):
op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
return transforms

def main(cfg):
logger = get_logger()
global_config = cfg['Global']
@@ -50,7 +51,7 @@ def main():

save_res_path = global_config.get('output_dir', 'output')
os.makedirs(save_res_path, exist_ok=True)

with open(os.path.join(save_res_path, 'predict_det.txt'), "w") as fout:
for file in get_image_file_list(cfg['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
2 changes: 0 additions & 2 deletions tools/train.py
@@ -20,7 +20,6 @@ def parse_args():
action='store_true',
default=True,
help="Whether to perform evaluation in train")
parser.add_argument('--local_rank', dest='local_rank', default=0, type=int, help='Use distributed training')
args = parser.parse_args()
return args

@@ -32,7 +31,6 @@ def main():
opt = FLAGS.pop('opt')
cfg.merge_dict(FLAGS)
cfg.merge_dict(opt)
cfg.cfg['Global']['local_rank'] = FLAGS['local_rank']
trainer = Trainer(cfg, mode='train_eval' if FLAGS['eval'] else 'train')
trainer.train()

35 changes: 21 additions & 14 deletions torchocr/engine/trainer.py
@@ -30,7 +30,18 @@
class Trainer(object):
def __init__(self, cfg, mode='train'):
self.cfg = cfg.cfg
self.local_rank = self.cfg['Global'].get('local_rank', 0)

self.local_rank = int(os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else 0
self.set_device(self.cfg['Global']['device'])

if torch.cuda.device_count() > 1:
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(self.device)
self.cfg['Global']['distributed'] = True
else:
self.cfg['Global']['distributed'] = False
self.local_rank = 0

self.cfg['Global']['output_dir'] = self.cfg['Global'].get('output_dir', 'output')
os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)

@@ -41,15 +52,6 @@ def __init__(self, cfg, mode='train'):
self.logger = get_logger('torchocr', os.path.join(self.cfg['Global']['output_dir'],
'train.log') if 'train' in mode else None)

self.set_device(self.cfg['Global']['device'])
if torch.cuda.device_count() > 1 and self.device.type == 'cuda':
torch.cuda.set_device(self.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://",
world_size=torch.cuda.device_count(), rank=self.local_rank)
self.cfg['Global']['distributed'] = True
else:
self.cfg['Global']['distributed'] = False

self.set_random_seed(self.cfg.get('seed', 48))

mode = mode.lower()
@@ -72,7 +74,7 @@ def __init__(self, cfg, mode='train'):

# build model
self.model = build_model(self.cfg['Architecture'])

self.model = self.model.to(self.device)
use_sync_bn = self.cfg["Global"].get("use_sync_bn", False)
if use_sync_bn:
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
@@ -96,7 +98,7 @@ def __init__(self, cfg, mode='train'):
self.status = load_ckpt(self.model, self.cfg, self.optimizer, self.lr_scheduler)

if self.cfg['Global']['distributed']:
self.model = torch.nn.parallel.DistributedDataParallel(self.model)
self.model = torch.nn.parallel.DistributedDataParallel(self.model, [self.local_rank], find_unused_parameters=True)

# amp
self.scaler = torch.cuda.amp.GradScaler() if self.cfg['Global'].get('use_amp', False) else None
@@ -114,7 +116,7 @@ def set_random_seed(self, seed):

def set_device(self, device):
if device == 'gpu' and torch.cuda.is_available():
device = torch.device("cuda:0")
device = torch.device(f"cuda:{self.local_rank}")
else:
if device == 'gpu':
self.logger.info('cuda is not available, auto switch to cpu')
@@ -160,6 +162,7 @@ def train(self):
self.train_dataloader = build_dataloader(self.cfg, 'Train', self.logger)

for idx, batch in enumerate(self.train_dataloader):
batch = [t.to(self.device) for t in batch]
self.optimizer.zero_grad()
train_reader_cost += time.time() - reader_start
# use amp
@@ -193,7 +196,7 @@ def train(self):

# logger
stats = {
k: float(v) if v.shape == [] else v.detach().numpy().mean()
k: float(v) if v.shape == [] else v.detach().cpu().numpy().mean()
for k, v in loss.items()
}
stats['lr'] = self.lr_scheduler.get_last_lr()[0]
@@ -253,6 +256,8 @@ def train(self):
self.logger.info(best_str)
if self.writer is not None:
self.writer.close()
if torch.cuda.device_count() > 1:
torch.distributed.destroy_process_group()

def eval(self):
self.model.eval()
@@ -266,7 +271,9 @@ def eval(self):
leave=True)
sum_images = 0
for idx, batch in enumerate(self.valid_dataloader):
batch = [t.to(self.device) for t in batch]
start = time.time()
images = batch[0].to(self.device)
if self.scaler:
with torch.cuda.amp.autocast():
preds = self.model(batch[0], data=batch[1:])
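The trainer changes above drop the `--local_rank` argument in favour of the `LOCAL_RANK` environment variable exported by `torchrun`, bind each process to `cuda:{local_rank}`, wrap the model in `DistributedDataParallel`, move every batch to the device, and destroy the process group when training ends. A condensed sketch of that pattern; `setup_distributed` is an illustrative helper, not a function in the repository:

```python
import os
import torch

def setup_distributed():
    """Minimal torchrun-style setup; returns (device, is_distributed)."""
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        device = torch.device(f"cuda:{local_rank}")
        # torchrun supplies MASTER_ADDR/MASTER_PORT, RANK and WORLD_SIZE via the environment
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(device)
        return device, True
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), False

# How the pieces fit together (mirrors the Trainer changes above):
# device, distributed = setup_distributed()
# model = build_model(cfg['Architecture']).to(device)
# if distributed:
#     model = torch.nn.parallel.DistributedDataParallel(
#         model, device_ids=[int(os.environ.get("LOCAL_RANK", 0))], find_unused_parameters=True)
# for batch in train_dataloader:
#     batch = [t.to(device) for t in batch]
#     ...
# if distributed:
#     torch.distributed.destroy_process_group()
```

Launched this way, the readme's multi-GPU command (`torchrun --nnodes=1 --nproc_per_node=4 tools/train.py -c ...`) spawns one such process per GPU.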
4 changes: 2 additions & 2 deletions torchocr/losses/det_db_loss.py
@@ -48,8 +48,8 @@ def forward(self, predicts, labels):
cbn_loss = self.bce_loss(cbn_maps[:, 0, :, :], label_shrink_map,
label_shrink_mask)
else:
dis_loss = torch.tensor([0.])
cbn_loss = torch.tensor([0.])
dis_loss = torch.tensor([0.], device=shrink_maps.device)
cbn_loss = torch.tensor([0.], device=shrink_maps.device)

loss_all = loss_shrink_maps + loss_threshold_maps \
+ loss_binary_maps
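The loss fix above creates the placeholder distillation/CBN losses on the same device as `shrink_maps`, so adding them to the GPU-resident losses no longer raises a CPU/CUDA mismatch. The pattern in isolation:

```python
import torch

preds = torch.zeros(2, 1, 8, 8)                      # stands in for shrink_maps (may live on GPU)
dis_loss = torch.tensor([0.], device=preds.device)   # placeholder loss created on the same device
cbn_loss = torch.tensor([0.], device=preds.device)
total = preds.mean() + dis_loss + cbn_loss           # no device-mismatch error
```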
4 changes: 2 additions & 2 deletions torchocr/metrics/det_metric.py
@@ -19,8 +19,8 @@ def __call__(self, preds, batch, **kwargs):
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
ignore_tags_batch = batch[3]
gt_polyons_batch = batch[2].cpu().numpy()
ignore_tags_batch = batch[3].cpu().numpy()
for pred, gt_polyons, ignore_tags in zip(preds, gt_polyons_batch,
ignore_tags_batch):
# prepare gt
6 changes: 4 additions & 2 deletions torchocr/modeling/backbones/__init__.py
@@ -5,15 +5,17 @@ def build_backbone(config, model_type):
if model_type == "det" or model_type == "table":
from .det_mobilenet_v3 import MobileNetV3
from .det_resnet_vd import ResNet_vd
from .rec_lcnetv3 import PPLCNetV3
support_dict = [
'MobileNetV3', 'ResNet_vd'
'MobileNetV3', 'ResNet_vd', 'PPLCNetV3'
]
elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_31 import ResNet31
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_lcnetv3 import PPLCNetV3
support_dict = [
'MobileNetV1Enhance','ResNet31', 'MobileNetV3'
'MobileNetV1Enhance','ResNet31', 'MobileNetV3', 'PPLCNetV3'
]
else:
raise NotImplementedError
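With `PPLCNetV3` imported and added to `support_dict` for both det and rec model types, the backbone factory can resolve it from the config's `name` field. A small illustrative sketch of that kind of name-based dispatch; the classes below are placeholders, not the real backbones:

```python
# Placeholder classes standing in for the real backbone implementations.
class MobileNetV3:
    def __init__(self, **kwargs): ...

class PPLCNetV3:
    def __init__(self, **kwargs): ...

SUPPORTED = {"MobileNetV3": MobileNetV3, "PPLCNetV3": PPLCNetV3}

def build_backbone(config: dict):
    cfg = dict(config)            # copy so the caller's config is not mutated
    name = cfg.pop("name")
    if name not in SUPPORTED:
        raise NotImplementedError(f"backbone {name} is not supported")
    return SUPPORTED[name](**cfg)

backbone = build_backbone({"name": "PPLCNetV3", "scale": 0.75})
```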