Skip to content

Commit

Permalink
Add post-processing operations for model export and inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
LutaoChu authored Apr 19, 2021
1 parent 7b9650e commit 1e27e45
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 8 deletions.
1 change: 1 addition & 0 deletions deploy/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ python deploy/python/infer.py --config /path/to/deploy.yaml --image_path
|use_int8|启动TensorRT预测时,是否以int8模式运行|||
|batch_size|单卡batch size||配置文件中指定值|
|save_dir|保存预测结果的目录||output|
|with_argmax|对预测结果进行argmax操作|||

*测试样例和预测结果如下*
![cityscape_predict_demo.png](../../docs/images/cityscapes_predict_demo.png)
Expand Down
11 changes: 9 additions & 2 deletions deploy/python/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ def postprocess(self, results, imgs):

results = np.concatenate(results, axis=0)
for i in range(results.shape[0]):
result = np.argmax(results[i], axis=0)
if self.args.with_argmax:
result = np.argmax(results[i], axis=0)
else:
result = results[i]
result = get_pseudo_color_map(result)
basename = os.path.basename(imgs[i])
basename, _ = os.path.splitext(basename)
Expand Down Expand Up @@ -158,6 +161,11 @@ def parse_args():
dest='use_int8',
help='Whether to use Int8 prediction when using TensorRT prediction.',
action='store_true')
parser.add_argument(
'--with_argmax',
dest='with_argmax',
help='Perform argmax operation on the predict result.',
action='store_true')

return parser.parse_args()

Expand All @@ -182,7 +190,6 @@ def main(args):
env_info = get_sys_env()
args.use_gpu = True if env_info['Paddle compiled with cuda'] and env_info[
'GPUs used'] else False

predictor = Predictor(args)
predictor.run(get_images(args.image_path))

Expand Down
2 changes: 2 additions & 0 deletions docs/model_export.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ python export.py \
|config|配置文件||-|
|save_dir|模型和visualdl日志文件的保存根路径||output|
|model_path|预训练模型参数的路径||配置文件中指定值|
|with_softmax|在网络末端添加softmax算子。由于PaddleSeg组网默认返回logits,如果想要部署模型获取概率值,可以置为True||False|
|without_argmax|是否不在网络末端添加argmax算子。由于PaddleSeg组网默认返回logits,为部署模型可以直接获取预测结果,我们默认在网络末端添加argmax算子||False|

## 结果文件

Expand Down
57 changes: 51 additions & 6 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,58 @@ def parse_args():
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
help='The directory for saving the exported model',
type=str,
default='./output')
parser.add_argument(
'--model_path',
dest='model_path',
help='The path of model for evaluation',
help='The path of model for export',
type=str,
default=None)
parser.add_argument(
'--without_argmax',
dest='without_argmax',
help='Do not add the argmax operation at the end of the network',
action='store_true')
parser.add_argument(
'--with_softmax',
dest='with_softmax',
help='Add the softmax operation at the end of the network',
action='store_true')

return parser.parse_args()


class SavedSegmentationNet(paddle.nn.Layer):
def __init__(self, net, without_argmax=False, with_softmax=False):
super().__init__()
self.net = net
self.post_processer = PostPorcesser(without_argmax, with_softmax)

def forward(self, x):
outs = self.net(x)
outs = self.post_processer(outs)
return outs


class PostPorcesser(paddle.nn.Layer):
def __init__(self, without_argmax, with_softmax):
super().__init__()
self.without_argmax = without_argmax
self.with_softmax = with_softmax

def forward(self, outs):
new_outs = []
for out in outs:
if self.with_softmax:
out = paddle.nn.functional.softmax(out, axis=1)
if not self.without_argmax:
out = paddle.argmax(out, axis=1)
new_outs.append(out)
return new_outs


def main(args):
os.environ['PADDLESEG_EXPORT_STAGE'] = 'True'
cfg = Config(args.cfg)
Expand All @@ -58,15 +97,21 @@ def main(args):
net.set_dict(para_state_dict)
logger.info('Loaded trained params of model successfully.')

net.eval()
net = paddle.jit.to_static(
net,
if not args.without_argmax or args.with_softmax:
new_net = SavedSegmentationNet(net, args.without_argmax,
args.with_softmax)
else:
new_net = net

new_net.eval()
new_net = paddle.jit.to_static(
new_net,
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, None, None], dtype='float32')
])
save_path = os.path.join(args.save_dir, 'model')
paddle.jit.save(net, save_path)
paddle.jit.save(new_net, save_path)

yml_file = os.path.join(args.save_dir, 'deploy.yaml')
with open(yml_file, 'w') as file:
Expand Down

0 comments on commit 1e27e45

Please sign in to comment.