forked from STVIR/pysot
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
libo
committed
Jun 14, 2019
1 parent
b899d80
commit 835481c
Showing
4 changed files
with
220 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
|
||
% error('Tracker not configured! Please edit the tracker_test.m file.'); % Remove this line after proper configuration | ||
|
||
% The human readable label for the tracker, used to identify the tracker in reports | ||
% If not set, it will be set to the same value as the identifier. | ||
% It does not have to be unique, but it is best that it is. | ||
tracker_label = ['SiamRPNpp']; | ||
|
||
% For Python implementations we have created a handy function that generates the appropritate | ||
% command that will run the python executable and execute the given script that includes your | ||
% tracker implementation. | ||
% | ||
% Please customize the line below by substituting the first argument with the name of the | ||
% script of your tracker (not the .py file but just the name of the script) and also provide the | ||
% path (or multiple paths) where the tracker sources % are found as the elements of the cell | ||
% array (second argument). | ||
setenv('MKL_NUM_THREADS','1'); | ||
pysot_root = 'path/to/pysot'; | ||
track_build_path = 'path/to/track/build'; | ||
tracker_command = generate_python_command('vot_iter.vot_iter', {pysot_root; [track_build_path '/python/lib']}) | ||
|
||
tracker_interpreter = 'python'; | ||
|
||
tracker_linkpath = {track_build_path}; | ||
|
||
% tracker_linkpath = {}; % A cell array of custom library directories used by the tracker executable (optional) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
""" | ||
\file vot.py | ||
@brief Python utility functions for VOT integration | ||
@author Luka Cehovin, Alessio Dore | ||
@date 2016 | ||
""" | ||
|
||
import sys | ||
import copy | ||
import collections | ||
|
||
try: | ||
import trax | ||
except ImportError: | ||
raise Exception('TraX support not found. Please add trax module to Python path.') | ||
|
||
Rectangle = collections.namedtuple('Rectangle', ['x', 'y', 'width', 'height']) | ||
Point = collections.namedtuple('Point', ['x', 'y']) | ||
Polygon = collections.namedtuple('Polygon', ['points']) | ||
|
||
class VOT(object): | ||
""" Base class for Python VOT integration """ | ||
def __init__(self, region_format, channels=None): | ||
""" Constructor | ||
Args: | ||
region_format: Region format options | ||
""" | ||
assert(region_format in [trax.Region.RECTANGLE, trax.Region.POLYGON]) | ||
|
||
if channels is None: | ||
channels = ['color'] | ||
elif channels == 'rgbd': | ||
channels = ['color', 'depth'] | ||
elif channels == 'rgbt': | ||
channels = ['color', 'ir'] | ||
elif channels == 'ir': | ||
channels = ['ir'] | ||
else: | ||
raise Exception('Illegal configuration {}.'.format(channels)) | ||
|
||
self._trax = trax.Server([region_format], [trax.Image.PATH], channels) | ||
|
||
request = self._trax.wait() | ||
assert(request.type == 'initialize') | ||
if isinstance(request.region, trax.Polygon): | ||
self._region = Polygon([Point(x[0], x[1]) for x in request.region]) | ||
else: | ||
self._region = Rectangle(*request.region.bounds()) | ||
self._image = [x.path() for k, x in request.image.items()] | ||
if len(self._image) == 1: | ||
self._image = self._image[0] | ||
|
||
self._trax.status(request.region) | ||
|
||
def region(self): | ||
""" | ||
Send configuration message to the client and receive the initialization | ||
region and the path of the first image | ||
Returns: | ||
initialization region | ||
""" | ||
|
||
return self._region | ||
|
||
def report(self, region, confidence = None): | ||
""" | ||
Report the tracking results to the client | ||
Arguments: | ||
region: region for the frame | ||
""" | ||
assert(isinstance(region, Rectangle) or isinstance(region, Polygon)) | ||
if isinstance(region, Polygon): | ||
tregion = trax.Polygon.create([(x.x, x.y) for x in region.points]) | ||
else: | ||
tregion = trax.Rectangle.create(region.x, region.y, region.width, region.height) | ||
properties = {} | ||
if not confidence is None: | ||
properties['confidence'] = confidence | ||
self._trax.status(tregion, properties) | ||
|
||
def frame(self): | ||
""" | ||
Get a frame (image path) from client | ||
Returns: | ||
absolute path of the image | ||
""" | ||
if hasattr(self, "_image"): | ||
image = self._image | ||
del self._image | ||
return image | ||
|
||
request = self._trax.wait() | ||
|
||
if request.type == 'frame': | ||
image = [x.path() for k, x in request.image.items()] | ||
if len(image) == 1: | ||
return image[0] | ||
return image | ||
else: | ||
return None | ||
|
||
|
||
def quit(self): | ||
if hasattr(self, '_trax'): | ||
self._trax.quit() | ||
|
||
def __del__(self): | ||
self.quit() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import sys | ||
import cv2 | ||
import torch | ||
import numpy as np | ||
import os | ||
from os.path import join | ||
|
||
from pysot.core.config import cfg | ||
from pysot.models.model_builder import ModelBuilder | ||
from pysot.tracker.tracker_builder import build_tracker | ||
from pysot.utils.bbox import get_axis_aligned_bbox | ||
from pysot.utils.model_load import load_pretrain | ||
from toolkit.datasets import DatasetFactory | ||
from toolkit.utils.region import vot_overlap, vot_float2str | ||
|
||
from . import vot | ||
from .vot import Rectangle, Polygon, Point | ||
|
||
|
||
# modify root | ||
|
||
cfg_root = "path/to/expr" | ||
model_file = join(cfg_root, 'model.pth') | ||
cfg_file = join(cfg_root, 'config.yaml') | ||
|
||
def warmup(model): | ||
for i in range(10): | ||
model.template(torch.FloatTensor(1,3,127,127).cuda()) | ||
|
||
def setup_tracker(): | ||
cfg.merge_from_file(cfg_file) | ||
|
||
model = ModelBuilder() | ||
model = load_pretrain(model, model_file).cuda().eval() | ||
|
||
tracker = build_tracker(model) | ||
warmup(model) | ||
return tracker | ||
|
||
|
||
tracker = setup_tracker() | ||
|
||
handle = vot.VOT("polygon") | ||
region = handle.region() | ||
try: | ||
region = np.array([region[0][0][0], region[0][0][1], region[0][1][0], region[0][1][1], | ||
region[0][2][0], region[0][2][1], region[0][3][0], region[0][3][1]]) | ||
except: | ||
region = np.array(region) | ||
|
||
cx, cy, w, h = get_axis_aligned_bbox(region) | ||
|
||
image_file = handle.frame() | ||
if not image_file: | ||
sys.exit(0) | ||
|
||
im = cv2.imread(image_file) # HxWxC | ||
# init | ||
target_pos, target_sz = np.array([cx, cy]), np.array([w, h]) | ||
gt_bbox_ = [cx-(w-1)/2, cy-(h-1)/2, w, h] | ||
tracker.init(im, gt_bbox_) | ||
|
||
while True: | ||
img_file = handle.frame() | ||
if not img_file: | ||
break | ||
im = cv2.imread(img_file) | ||
outputs = tracker.track(im) | ||
pred_bbox = outputs['bbox'] | ||
result = Rectangle(*pred_bbox) | ||
score = outputs['best_score'] | ||
if cfg.MASK.MASK: | ||
pred_bbox = outputs['polygon'] | ||
result = Polygon(Point(x[0], x[1]) for x in pred_bbox) | ||
|
||
handle.report(result, score) |