forked from IrisRainbowNeko/genshin_auto_fish
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredictor.py
114 lines (97 loc) · 3.34 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import time
from loguru import logger
import os
import torch
import cv2
from yolox.data.data_augment import ValTransform
from yolox.data.datasets import FISH_CLASSES
from yolox.utils import postprocess, vis
class Predictor(object):
    """Wrap a YOLOX-style detection model for single-image inference.

    Handles preprocessing (via ``ValTransform``), the forward pass, optional
    decoding, post-processing (via ``postprocess``) and mapping detections
    back to original-image coordinates.
    """

    def __init__(
        self,
        model,
        exp,
        cls_names=FISH_CLASSES,
        trt_file=None,
        decoder=None,
        device="cpu",
        fp16=False,
        legacy=False,
    ):
        """Store the model and the inference settings read from ``exp``.

        Args:
            model: Callable network; invoked as ``model(img_tensor)``.
            exp: Experiment config; ``num_classes``, ``test_conf``,
                ``nmsthre`` and ``test_size`` are read from it.
            cls_names: Sequence mapping class index -> class name.
            trt_file: Optional path to a torch2trt ``TRTModule`` state dict.
                When given, the TensorRT module replaces ``model``.
            decoder: Optional callable applied to the raw model outputs
                before post-processing.
            device: ``"gpu"`` moves inputs to CUDA; any other value keeps
                them on the CPU.
            fp16: Cast inputs to half precision (only applied when
                ``device == "gpu"``).
            legacy: Forwarded to ``ValTransform``.
        """
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        self.fp16 = fp16
        self.preproc = ValTransform(legacy=legacy)
        if trt_file is not None:
            # Imported here so torch2trt is only needed when a TRT file is given.
            from torch2trt import TRTModule

            model_trt = TRTModule()
            # NOTE: torch.load unpickles the file; only load trusted engine files.
            model_trt.load_state_dict(torch.load(trt_file))

            # One dummy forward pass through the original model (on a CUDA
            # input) before it is swapped out. Kept from the upstream demo;
            # presumably a warm-up/initialisation step — confirm.
            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt

    def inference(self, img):
        """Run the model on a single image.

        Args:
            img: Either a path to an image file (read with ``cv2.imread``)
                or an already-loaded image array (presumably the same layout
                ``cv2.imread`` returns — confirm for array callers).

        Returns:
            ``(outputs, img_info)``. ``outputs`` is the result of
            ``postprocess``; ``outputs[0]`` is either ``None`` or a tensor of
            detection rows (columns 0-3 box, 4 and 5 score factors, 6 class
            index). ``img_info`` holds ``id``, ``file_name``, ``height``,
            ``width``, ``raw_img`` and the resize ``ratio`` used to map boxes
            back to the original image.

        Raises:
            OSError: If ``img`` is a path that ``cv2.imread`` cannot read.
        """
        img_info = {"id": 0}
        if isinstance(img, str):
            path = img
            img_info["file_name"] = os.path.basename(path)
            img = cv2.imread(path)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise OSError("Failed to read image: {}".format(path))
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        # Scale factor of the preprocessing resize; dividing boxes by it maps
        # them back to original-image coordinates.
        ratio = min(self.test_size[0] / height, self.test_size[1] / width)
        img_info["ratio"] = ratio

        img, _ = self.preproc(img, None, self.test_size)
        img = torch.from_numpy(img).unsqueeze(0).float()
        if self.device == "gpu":
            img = img.cuda()
            if self.fp16:
                img = img.half()  # to FP16

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(
                outputs, self.num_classes, self.confthre,
                self.nmsthre, class_agnostic=True
            )
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def image_det(self, img, with_info=False):
        """Detect objects in ``img`` and return them as simple records.

        Args:
            img: Image path or loaded image array; see ``inference``.
            with_info: Also return the raw ``outputs`` and ``img_info``.

        Returns:
            ``None`` if ``outputs[0]`` is ``None``. Otherwise a list of
            ``[class_name, score, [b0, b1, b2, b3]]`` entries, where
            ``score`` is column 4 times column 5 of the detection row and the
            box values (presumably ``x1, y1, x2, y2``) are in original-image
            coordinates. With ``with_info=True`` the result is
            ``(obj_list, outputs, img_info)``; ``outputs`` is left in
            network-input coordinates, so it can still be passed to
            ``visual``.
        """
        outputs, img_info = self.inference(img)
        if outputs[0] is None:
            return None
        ratio = img_info["ratio"]
        obj_list = []
        for item in outputs[0].cpu():
            # Out-of-place division: ``.cpu()`` returns the same tensor when it
            # is already on the CPU and ``item[:4]`` is a view, so ``/=`` would
            # rewrite ``outputs`` and make ``visual`` rescale the boxes twice.
            bboxes = item[:4] / ratio
            score = item[4] * item[5]
            obj_list.append([
                self.cls_names[int(item[6])],
                score,
                [bboxes[0], bboxes[1], bboxes[2], bboxes[3]],
            ])
        if with_info:
            return obj_list, outputs, img_info
        return obj_list

    def visual(self, output, img_info, cls_conf=0.35):
        """Draw the detections in ``output`` onto the original image.

        Args:
            output: One per-image entry of ``inference``'s ``outputs``
                (e.g. ``outputs[0]``), or ``None``.
            img_info: The ``img_info`` dict returned by ``inference``.
            cls_conf: Forwarded to ``vis`` (presumably the minimum score a
                detection needs to be drawn).

        Returns:
            Whatever ``vis`` returns, or ``img_info["raw_img"]`` unchanged
            when ``output`` is ``None``.
        """
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()
        # Rescale out of place so the caller's tensor is not modified (``.cpu()``
        # is a no-op for CPU tensors) and repeated calls give the same result.
        bboxes = output[:, 0:4] / ratio
        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]
        return vis(img, bboxes, scores, cls, cls_conf, self.cls_names)