Commit c03e30f

Add slim deploy docs and demo

nepeplwu committed Nov 19, 2020
1 parent 0c904b4 commit c03e30f

Showing 3 changed files with 265 additions and 0 deletions.
4 changes: 4 additions & 0 deletions slim/quantization/README.md
@@ -148,6 +148,10 @@ TRAIN.SYNC_BATCH_NORM False \
SLIM.PREPROCESS True \
```

## Inference Deployment

Please refer to the [quantized model deployment documentation](./deploy/).

## Quantization Results


67 changes: 67 additions & 0 deletions slim/quantization/deploy/README.md
@@ -0,0 +1,67 @@
# PaddleSeg Quantized Model Deployment

## 1. Overview

This document provides a reference `Python` deployment solution for running PaddleSeg quantized models with TensorRT. With a small amount of configuration and a little code, you can integrate the model into your own service to perform image segmentation.

## 2. Environment Setup

* Refer to the [build and installation guide](../../../deploy/python/docs/compile_paddle_with_tensorrt.md) to compile and install a Paddle package with TensorRT support. A quick sanity check is sketched below.
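
The following is a minimal sanity check, an illustrative sketch rather than part of the original guide, to confirm that the installed Paddle package imports correctly and was built with GPU support:

```python
# illustrative check: the custom-built Paddle wheel should import and report CUDA support
import paddle
import paddle.fluid as fluid

print(paddle.__version__)             # version of the locally built package
print(fluid.is_compiled_with_cuda())  # expected to print True for a TensorRT-enabled build
```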

## 3. Running Inference

### 3.1 Prepare the inference model

Please refer to [model quantization](../) to train and export the corresponding quantized model.

The exported model directory usually contains three files:

```
├── model        # model file
├── params       # parameters file
└── deploy.yaml  # configuration file for C++ or Python inference
```

The main fields of the configuration file and their meanings are as follows:
```yaml
DEPLOY:
    # whether to use GPU for inference
    USE_GPU: 1
    # directory containing the model and parameter files
    MODEL_PATH: "freeze_model"
    # model file name
    MODEL_FILENAME: "model"
    # parameters file name
    PARAMS_FILENAME: "params"
    # standard input size for inference; images of a different size are resized to it
    EVAL_CROP_SIZE: (2049, 1025)
    # mean
    MEAN: [0.5, 0.5, 0.5]
    # standard deviation
    STD: [0.5, 0.5, 0.5]
    # number of classes
    NUM_CLASSES: 19
    # number of image channels
    CHANNELS : 3
    # predictor mode, supports NATIVE and ANALYSIS
    PREDICTOR_MODE: "ANALYSIS"
    # batch size per inference run
    BATCH_SIZE : 3
```
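
For reference, here is a minimal sketch of how these fields are read in Python, mirroring the `DeployConfig` class in `infer.py` below (the `freeze_model/` path is an assumption):

```python
import ast
import yaml

# minimal sketch: read the DEPLOY section of deploy.yaml, as DeployConfig in infer.py does
with open("freeze_model/deploy.yaml") as fp:  # path is illustrative
    deploy_conf = yaml.load(fp, Loader=yaml.FullLoader)["DEPLOY"]

# EVAL_CROP_SIZE is stored as a string like "(2049, 1025)", so it is parsed with ast.literal_eval
eval_w, eval_h = ast.literal_eval(deploy_conf["EVAL_CROP_SIZE"])
print(eval_w, eval_h, deploy_conf["MEAN"], deploy_conf["STD"])
```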
### 3.2 Run the inference program
```bash
python infer.py --conf=/path/to/deploy.yaml --input_dir=/path/to/images_directory
```
The parameters are described below:

| Parameter | Required | Description |
|-------|-------|----------|
| conf | Yes | Path to the model configuration YAML file |
| input_dir | Yes | Directory of images to be predicted |
| save_dir | No | Directory for saving prediction results; defaults to `output` |
| ext | No | Supported image extensions; separate multiple formats with '\|'; defaults to '.jpg\|.jpeg' |
| use_int8 | No | Whether to use Int8 for inference |

After running, the program scans all images with the supported extensions under `input_dir` and saves the visualized segmentation results.
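
For example (paths here are illustrative), `python infer.py --conf=freeze_model/deploy.yaml --input_dir=./test_images --use_int8` runs Int8 TensorRT inference on every supported image under `./test_images` and writes the colorized masks to the default `output` directory.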
194 changes: 194 additions & 0 deletions slim/quantization/deploy/infer.py
@@ -0,0 +1,194 @@
import argparse
import ast
import os

import cv2
import numpy as np
import paddle.fluid as fluid
import yaml


class DeployConfig:
def __init__(self, conf_file):
if not os.path.exists(conf_file):
raise Exception('Config file path [%s] invalid!' % conf_file)

with open(conf_file) as fp:
configs = yaml.load(fp, Loader=yaml.FullLoader)
deploy_conf = configs["DEPLOY"]
# 1. get eval_crop_size
self.eval_crop_size = ast.literal_eval(
deploy_conf["EVAL_CROP_SIZE"])
# 2. get mean
self.mean = deploy_conf["MEAN"]
# 3. get std
self.std = deploy_conf["STD"]
# 4. get class_num
self.class_num = deploy_conf["NUM_CLASSES"]
# 5. get paddle model and params file path
self.model_file = os.path.join(deploy_conf["MODEL_PATH"],
deploy_conf["MODEL_FILENAME"])
self.param_file = os.path.join(deploy_conf["MODEL_PATH"],
deploy_conf["PARAMS_FILENAME"])
# 6. use_gpu
self.use_gpu = deploy_conf["USE_GPU"]
# 7. predictor_mode
self.predictor_mode = deploy_conf["PREDICTOR_MODE"]
# 8. batch_size
self.batch_size = deploy_conf["BATCH_SIZE"]
# 9. channels
self.channels = deploy_conf["CHANNELS"]


def create_predictor(args):
predictor_config = fluid.core.AnalysisConfig(args.conf.model_file,
args.conf.param_file)
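    # reserve 100 MB of initial GPU memory on device 0 and keep IR graph optimizations enabled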
predictor_config.enable_use_gpu(100, 0)
predictor_config.switch_ir_optim(True)
    if args.use_int8:
        precision_type = fluid.core.AnalysisConfig.Precision.Int8
    else:
        precision_type = fluid.core.AnalysisConfig.Precision.Float32
use_calib = False
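    # enable the TensorRT sub-graph engine: Int8 precision when --use_int8 is set, FP32 otherwise;
    # calibration is disabled since the quantization-aware-trained model already carries scale info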
predictor_config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=40,
precision_mode=precision_type,
use_static=False,
use_calib_mode=use_calib)
predictor_config.switch_specify_input_names(True)
predictor_config.enable_memory_optim()
predictor = fluid.core.create_paddle_predictor(predictor_config)

return predictor


def preprocess(conf, image_path):
flag = cv2.IMREAD_UNCHANGED if conf.channels == 4 else cv2.IMREAD_COLOR
im = cv2.imread(image_path, flag)

channels = im.shape[2]
if channels != 3 and channels != 4:
        print('Only RGB or RGBA images are supported.')
return -1

ori_h = im.shape[0]
ori_w = im.shape[1]
eval_w, eval_h = conf.eval_crop_size
if ori_h != eval_h or ori_w != eval_w:
im = cv2.resize(
im, (eval_w, eval_h), fx=0, fy=0, interpolation=cv2.INTER_LINEAR)

im_mean = np.array(conf.mean).reshape((conf.channels, 1, 1))
im_std = np.array(conf.std).reshape((conf.channels, 1, 1))

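    # HWC -> CHW, scale pixel values to [0, 1], then normalize with the configured mean/std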
im = im.swapaxes(1, 2)
im = im.swapaxes(0, 1)
im = im[:, :, :].astype('float32') / 255.0
im -= im_mean
im /= im_std

im = im[np.newaxis, :, :, :]
info = [image_path, im, (ori_w, ori_h)]
return info


# Generate ColorMap for visualization
def generate_colormap(num_classes):
color_map = num_classes * [0, 0, 0]
for i in range(0, num_classes):
j = 0
lab = i
while lab:
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
j += 1
lab >>= 3
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
return color_map


def infer(args):
predictor = create_predictor(args)
colormap = generate_colormap(args.conf.class_num)

images = get_images_from_dir(args.input_dir, args.ext)

for image in images:
        im_info = preprocess(args.conf, image)
        # skip images that preprocess() could not handle (it returns -1 for unsupported channels)
        if im_info == -1:
            continue

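        # wrap the preprocessed array in a PaddleTensor with input name 'image'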
input_tensor = fluid.core.PaddleTensor()
input_tensor.name = 'image'
input_tensor.shape = im_info[1].shape
input_tensor.dtype = fluid.core.PaddleDType.FLOAT32
input_tensor.data = fluid.core.PaddleBuf(
im_info[1].ravel().astype("float32"))
input_tensor = [input_tensor]

output_tensor = predictor.run(input_tensor)[0]
output_data = output_tensor.as_ndarray()

img_name = im_info[0]
ori_shape = im_info[2]

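        # argmax over the class dimension yields an HxW label map; replicate it to 3 channels for coloring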
logit = np.argmax(output_data, axis=1).squeeze()
logit = logit.astype('uint8')[:, :, np.newaxis]
logit = np.concatenate([logit] * 3, axis=2)

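        # map each pixel's class id to its colormap color (simple per-pixel loop)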
for i in range(logit.shape[0]):
for j in range(logit.shape[1]):
logit[i, j] = colormap[logit[i, j, 0]]

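        # resize the colorized mask back to the original image size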
logit = cv2.resize(
logit, ori_shape, fx=0, fy=0, interpolation=cv2.INTER_CUBIC)

save_path = os.path.join(args.save_dir, img_name)
dirname = os.path.dirname(save_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
        cv2.imwrite(save_path, logit)


def get_images_from_dir(img_dir, support_ext='.jpg|.jpeg'):
if (not os.path.exists(img_dir) or not os.path.isdir(img_dir)):
raise Exception('Image Directory [%s] invalid' % img_dir)
imgs = []
for item in os.listdir(img_dir):
ext = os.path.splitext(item)[1][1:].strip().lower()
if (len(ext) > 0 and ext in support_ext):
item_path = os.path.join(img_dir, item)
imgs.append(item_path)
return imgs


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--conf', type=str, default='', help='Configuration File Path.')
parser.add_argument(
'--ext',
type=str,
default='.jpeg|.jpg',
help='Input Image File Extensions.')
parser.add_argument(
'--use_int8',
dest='use_int8',
action='store_true',
help='Whether to use int8 for prediction.')
parser.add_argument(
'--input_dir',
type=str,
        help='Directory that stores images to be predicted.')
parser.add_argument(
'--save_dir',
type=str,
default='output',
        help='Directory for saving the prediction results.')

return parser.parse_args()


if __name__ == '__main__':
args = parse_args()
args.conf = DeployConfig(args.conf)
result = infer(args)
