Skip to content

Commit ef88dfa

Browse files
committed
update custom video testing
1 parent dfd011e commit ef88dfa

File tree

1 file changed

+230
-0
lines changed

1 file changed

+230
-0
lines changed

tracking/run_video.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
# ------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft
3+
# Licensed under the MIT License.
4+
5+
# Detail: test on a specific video (provide init bbox [optional] and video file)
6+
# ------------------------------------------------------------------------------
7+
8+
import _init_paths
9+
import os
10+
import cv2
11+
import torch
12+
import random
13+
import argparse
14+
import numpy as np
15+
16+
try:
17+
from torch2trt import TRTModule
18+
except:
19+
print('Warning: TensorRT is not successfully imported')
20+
21+
import models.models as models
22+
23+
from os.path import exists, join, dirname, realpath
24+
from tracker.ocean import Ocean
25+
from tracker.online import ONLINE
26+
from easydict import EasyDict as edict
27+
from utils.utils import load_pretrain, cxy_wh_2_rect, get_axis_aligned_bbox, load_dataset, poly_iou
28+
29+
from eval_toolkit.pysot.datasets import VOTDataset
30+
from eval_toolkit.pysot.evaluation import EAOBenchmark
31+
from tqdm import tqdm
32+
33+
34+
def parse_args():
35+
"""
36+
args for fc testing.
37+
"""
38+
parser = argparse.ArgumentParser(description='PyTorch SiamFC Tracking Test')
39+
parser.add_argument('--arch', default='Ocean', type=str, help='backbone architecture')
40+
parser.add_argument('--resume', default='snapshot/OceanV19on.pth', type=str, help='pretrained model')
41+
parser.add_argument('--video', default='./dataset/soccer1.mp4', type=str, help='video file path')
42+
parser.add_argument('--online', default=True, type=bool, help='use online or offline model')
43+
parser.add_argument('--save', default=True, type=bool, help='save pictures')
44+
parser.add_argument('--init_bbox', default=None, help='bbox in the first frame None or [lx, ly, w, h]')
45+
args = parser.parse_args()
46+
47+
return args
48+
49+
50+
def track_video(siam_tracker, online_tracker, siam_net, video_path, init_box=None, args=None):
51+
52+
assert os.path.isfile(video_path), "please provide a valid video file"
53+
54+
video_name = video_path.split('/')[-1]
55+
video_name = video_name.split('.')[0]
56+
save_path = os.path.join('vis', video_name)
57+
if not os.path.exists(save_path):
58+
os.makedirs(save_path)
59+
60+
cap = cv2.VideoCapture(video_path)
61+
display_name = 'Video: {}'.format(video_path.split('/')[-1])
62+
cv2.namedWindow(display_name, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)
63+
cv2.resizeWindow(display_name, 960, 720)
64+
success, frame = cap.read()
65+
cv2.imshow(display_name, frame)
66+
67+
if success is not True:
68+
print("Read failed.")
69+
exit(-1)
70+
71+
# init
72+
count = 0
73+
74+
if init_box is not None:
75+
lx, ly, w, h = init_box
76+
target_pos = np.array([lx + w/2, ly + h/2])
77+
target_sz = np.array([w, h])
78+
79+
state = siam_tracker.init(frame, target_pos, target_sz, siam_net) # init tracker
80+
rgb_im = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
81+
82+
if args.online:
83+
online_tracker.init(frame, rgb_im, siam_net, target_pos, target_sz, True, dataname='VOT2019', resume=args.resume)
84+
85+
else:
86+
while True:
87+
88+
frame_disp = frame.copy()
89+
90+
cv2.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL,
91+
1, (0, 0, 255), 1)
92+
93+
lx, ly, w, h = cv2.selectROI(display_name, frame_disp, fromCenter=False)
94+
target_pos = np.array([lx + w / 2, ly + h / 2])
95+
target_sz = np.array([w, h])
96+
97+
state = siam_tracker.init(frame_disp, target_pos, target_sz, siam_net) # init tracker
98+
rgb_im = cv2.cvtColor(frame_disp, cv2.COLOR_BGR2RGB)
99+
100+
if args.online:
101+
online_tracker.init(frame_disp, rgb_im, siam_net, target_pos, target_sz, True, dataname='VOT2019', resume=args.resume)
102+
103+
break
104+
105+
while True:
106+
ret, frame = cap.read()
107+
108+
if frame is None:
109+
return
110+
111+
frame_disp = frame.copy()
112+
rgb_im = cv2.cvtColor(frame_disp, cv2.COLOR_BGR2RGB)
113+
114+
# Draw box
115+
if args.online:
116+
state = online_tracker.track(frame_disp, rgb_im, siam_tracker, state)
117+
else:
118+
state = siam_tracker.track(state, frame_disp)
119+
120+
location = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
121+
x1, y1, x2, y2 = int(location[0]), int(location[1]), int(location[0] + location[2]), int(location[1] + location[3])
122+
123+
cv2.rectangle(frame_disp, (x1, y1), (x2, y2), (0, 255, 0), 5)
124+
125+
font_color = (0, 0, 0)
126+
cv2.putText(frame_disp, 'Tracking!', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
127+
font_color, 1)
128+
cv2.putText(frame_disp, 'Press r to reset', (20, 55), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
129+
font_color, 1)
130+
cv2.putText(frame_disp, 'Press q to quit', (20, 80), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
131+
font_color, 1)
132+
133+
# Display the resulting frame
134+
cv2.imshow(display_name, frame_disp)
135+
136+
if args.save:
137+
save_name = os.path.join(save_path, '{:04d}.jpg'.format(count))
138+
cv2.imwrite(save_name, frame_disp)
139+
count += 1
140+
141+
key = cv2.waitKey(1)
142+
# key = None
143+
if key == ord('q'):
144+
break
145+
elif key == ord('r'):
146+
ret, frame = cap.read()
147+
frame_disp = frame.copy()
148+
149+
cv2.putText(frame_disp, 'Select target ROI and press ENTER', (20, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL,
150+
1.5,
151+
(0, 0, 0), 1)
152+
153+
cv2.imshow(display_name, frame_disp)
154+
lx, ly, w, h = cv2.selectROI(display_name, frame_disp, fromCenter=False)
155+
target_pos = np.array([lx + w / 2, ly + h / 2])
156+
target_sz = np.array([w, h])
157+
158+
state = siam_tracker.init(frame_disp, target_pos, target_sz, siam_net) # init tracker
159+
rgb_im = cv2.cvtColor(frame_disp, cv2.COLOR_BGR2RGB)
160+
161+
if args.online:
162+
online_tracker.init(frame_disp, rgb_im, siam_net, target_pos, target_sz, True, dataname='VOT2019', resume=args.resume)
163+
164+
# When everything done, release the capture
165+
cap.release()
166+
cv2.destroyAllWindows()
167+
168+
169+
def main():
170+
args = parse_args()
171+
172+
# prepare model (SiamRPN or SiamFC)
173+
174+
# prepare tracker
175+
info = edict()
176+
info.arch = args.arch
177+
info.dataset = 'VOT2019'
178+
info.TRT = 'TRT' in args.arch
179+
info.epoch_test = False
180+
181+
siam_info = edict()
182+
siam_info.arch = args.arch
183+
siam_info.dataset = 'VOT2019'
184+
siam_info.online = args.online
185+
siam_info.epoch_test = False
186+
siam_info.TRT = 'TRT' in args.arch
187+
188+
siam_info.align = False
189+
190+
if siam_info.TRT:
191+
siam_info.align = False
192+
193+
siam_tracker = Ocean(siam_info)
194+
siam_net = models.__dict__[args.arch](align=siam_info.align, online=args.online)
195+
print('===> init Siamese <====')
196+
197+
if not siam_info.TRT:
198+
siam_net = load_pretrain(siam_net, args.resume)
199+
else:
200+
print("tensorrt toy model: not loading checkpoint")
201+
siam_net.eval()
202+
siam_net = siam_net.cuda()
203+
204+
if siam_info.TRT:
205+
print('===> load model from TRT <===')
206+
print('===> please ignore the warning information of TRT <===')
207+
print('===> We only provide a toy demo for TensorRT. There are some operations are not supported well.<===')
208+
print('===> If you wang to test on benchmark, please us Pytorch version. <===')
209+
print('===> The tensorrt code will be contingously optimized (with the updating of official TensorRT.)<===')
210+
trtNet = reloadTRT()
211+
siam_net.tensorrt_init(trtNet)
212+
213+
if args.online:
214+
online_tracker = ONLINE(info)
215+
else:
216+
online_tracker = None
217+
218+
print('[*] ======= Track video with {} ======='.format(args.arch))
219+
220+
# check init box is list or not
221+
if not isinstance(args.init_bbox, list) and args.init_bbox is not None:
222+
args.init_bbox = list(eval(args.init_bbox))
223+
else:
224+
args.init_bbox = None
225+
print('===> please draw a box with your mouse <====')
226+
227+
track_video(siam_tracker, online_tracker, siam_net, args.video, init_box=args.init_bbox, args=args)
228+
229+
if __name__ == '__main__':
230+
main()

0 commit comments

Comments
 (0)