Skip to content

Commit

Permalink
[Codes] Update the function of calculating Params and Flops.
Browse files Browse the repository at this point in the history
  • Loading branch information
Fafa-DL committed Sep 9, 2022
1 parent 4e61e72 commit ae67de8
Show file tree
Hide file tree
Showing 4 changed files with 690 additions and 259 deletions.
30 changes: 30 additions & 0 deletions datas/docs/Calculate_Flops.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
获取模型Flops&Param
===========================

[![BILIBILI](https://raw.githubusercontent.com/Fafa-DL/readme-data/main/Bilibili.png)](https://space.bilibili.com/46880349)
- 提供 `tools/vis_lr.py` 工具来可视化学习率。

**命令行**

```bash
python tools/vis_lr.py \
${CONFIG_FILE} \
[--shape ${Shape}] \
```

**所有参数的说明**

- `config` : 模型配置文件的路径。
- `--shape` : 输入图片尺寸,默认224


**示例Step**

```bash
python tools/get_flops.py models/mobilenet/mobilenet_v3_small.py
```

**注意**

- 官方给出的参数量与浮点运算量是基于ImageNet,也就是说默认分类数为`1000`,所以当你评估自己模型时请在配置文件中将`num_classes`修改为对应数量,因为将很大程度上影响结果
- 如果你有新增任何`基类`卷积/池化/采样功能,请在`utils/flops_counter.py/get_modules_mapping()`进行增加注册
259 changes: 0 additions & 259 deletions datas/docs/Configs_description.md

This file was deleted.

56 changes: 56 additions & 0 deletions tools/get_flops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys
sys.path.insert(0,os.getcwd())

from utils.train_utils import file2dict
from utils.flops_counter import get_model_complexity_info
from models.build import BuildNet

def parse_args():
parser = argparse.ArgumentParser(description='Get model flops and params')
parser.add_argument('config', help='config file path')
parser.add_argument(
'--shape',
type=int,
nargs='+',
default=[224, 224],
help='input image size')
args = parser.parse_args()
return args


def main():

args = parse_args()

if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
else:
raise ValueError('invalid input shape')

model_cfg,train_pipeline,val_pipeline,data_cfg,lr_config,optimizer_cfg = file2dict(args.config)
model = BuildNet(model_cfg)
model.eval()

if hasattr(model, 'extract_feat'):
model.forward = model.extract_feat
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.
format(model.__class__.__name__))

flops, params = get_model_complexity_info(model, input_shape)
split_line = '=' * 30
print(f'{split_line}\nInput shape: {input_shape}\n'
f'Flops: {flops}\nParams: {params}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')


if __name__ == '__main__':
main()
Loading

0 comments on commit ae67de8

Please sign in to comment.