forked from Megvii-BaseDetection/YOLOX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrt.py
77 lines (62 loc) · 2.18 KB
/
trt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import argparse
import os
import shutil
from loguru import logger
import tensorrt as trt
import torch
from torch2trt import torch2trt
from yolox.exp import get_exp
def make_parser():
    """Build the command-line parser for the YOLOX -> TensorRT converter.

    Options:
        -expn / --experiment-name: sub-directory of ``exp.output_dir`` used for
            inputs and outputs; ``main`` falls back to ``exp.exp_name`` when
            it is not given.
        -n / --name: model name forwarded to ``get_exp``.
        -f / --exp_file: path to an experiment description file, forwarded
            to ``get_exp``.
        -c / --ckpt: checkpoint to convert; ``main`` falls back to
            ``best_ckpt.pth`` inside the experiment directory when omitted.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    # The first positional argument of ArgumentParser is ``prog``, shown in the
    # usage line; it previously (wrongly) advertised an ncnn deploy tool.
    parser = argparse.ArgumentParser(prog="YOLOX TensorRT deploy")
    parser.add_argument(
        "-expn",
        "--experiment-name",
        type=str,
        default=None,
        help="experiment name (defaults to the exp's exp_name)",
    )
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    return parser
@logger.catch
@torch.no_grad()
def main():
    """Convert a trained YOLOX checkpoint into a TensorRT engine.

    Steps:
      1. Parse the CLI options from ``make_parser`` and load the experiment
         with ``get_exp``.
      2. Load the checkpoint given by ``--ckpt`` (default:
         ``<exp.output_dir>/<experiment_name>/best_ckpt.pth``) on CPU, put its
         ``"model"`` weights into the model, then switch it to eval mode on
         CUDA.
      3. Trace the model with ``torch2trt`` on a dummy ``1 x 3 x H x W``
         input, where ``(H, W) = exp.test_size``, in FP16 mode with a
         ``1 << 32``-byte (4 GiB) builder workspace.
      4. Write ``model_trt.pth`` (torch2trt state dict) and
         ``model_trt.engine`` (serialized engine) into the experiment
         directory. If ``demo/TensorRT/cpp`` exists relative to the current
         working directory, also copy the engine there for the C++ demo.

    The whole function runs under ``torch.no_grad()``. Conversion only needs
    forward passes, so recording autograd history would just waste GPU
    memory. Exceptions are intercepted and logged by ``logger.catch``.
    """
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)
    if args.ckpt is None:
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # Map to CPU so the checkpoint loads independently of the device it was
    # saved from; the model is moved to CUDA afterwards.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    # load the model state dict
    model.load_state_dict(ckpt["model"])
    logger.info("loaded checkpoint done.")

    model.eval()
    model.cuda()
    # NOTE(review): presumably makes the head return raw (undecoded)
    # predictions so decoding is done by the engine's consumer (e.g. the C++
    # demo) -- confirm against the YOLOX head implementation.
    model.head.decode_in_inference = False

    # Dummy input used only for tracing. torch2trt builds the engine for this
    # example input's shape, i.e. batch 1 at exp.test_size.
    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=True,
        log_level=trt.Logger.INFO,
        max_workspace_size=(1 << 32),  # bytes, i.e. 4 GiB
    )
    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
    logger.info("Converted TensorRT model done.")

    engine_file = os.path.join(file_name, "model_trt.engine")
    with open(engine_file, "wb") as f:
        f.write(model_trt.engine.serialize())

    # The C++ demo location is relative to the current working directory, so
    # it only exists when the script is launched from the repository root.
    # Copying into a missing directory would raise FileNotFoundError after
    # the (slow) conversion already succeeded, so warn instead.
    demo_dir = os.path.join("demo", "TensorRT", "cpp")
    if os.path.isdir(demo_dir):
        shutil.copyfile(engine_file, os.path.join(demo_dir, "model_trt.engine"))
        logger.info("Converted TensorRT model engine file is saved for C++ inference.")
    else:
        logger.warning(
            f"C++ demo directory '{demo_dir}' not found (run from the repo root "
            f"to copy it there); TensorRT engine saved only to '{engine_file}'."
        )
if __name__ == "__main__":
    # Script entry point: importing this module must not start a conversion.
    main()