forked from ultralytics/ultralytics
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate ByteTracker and BoT-SORT trackers (ultralytics#788)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <[email protected]> Co-authored-by: Ayush Chaurasia <[email protected]>
- Loading branch information
1 parent
d99e04d
commit ed6c54d
Showing
24 changed files
with
1,635 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
## Tracker | ||
|
||
### Trackers | ||
|
||
- [x] ByteTracker | ||
- [x] BoT-SORT | ||
|
||
### Usage | ||
|
||
python interface: | ||
|
||
```python | ||
from ultralytics import YOLO | ||
|
||
model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt | ||
model.track( | ||
source="video/streams", | ||
stream=True, | ||
tracker="botsort.yaml/bytetrack.yaml", | ||
..., | ||
) | ||
``` | ||
|
||
cli: | ||
|
||
```bash | ||
yolo detect track source=... tracker=... | ||
yolo segment track source=... tracker=... | ||
``` | ||
|
||
By default, trackers will use the configuration in `ultralytics/tracker/cfg`. | ||
We also support using a modified tracker config file. Please refer to the tracker config files in `ultralytics/tracker/cfg`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .trackers import BYTETracker, BOTSORT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] | ||
track_high_thresh: 0.5 # threshold for the first association | ||
track_low_thresh: 0.1 # threshold for the second association | ||
new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks | ||
track_buffer: 30 # buffer to calculate the time when to remove tracks | ||
match_thresh: 0.8 # threshold for matching tracks | ||
# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) | ||
# mot20: False # for tracker evaluation(not used for now) | ||
|
||
# Botsort settings | ||
cmc_method: sparseOptFlow # method of global motion compensation | ||
# ReID model related thresh (not supported yet) | ||
proximity_thresh: 0.5 | ||
appearance_thresh: 0.25 | ||
with_reid: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] | ||
track_high_thresh: 0.5 # threshold for the first association | ||
track_low_thresh: 0.1 # threshold for the second association | ||
new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks | ||
track_buffer: 30 # buffer to calculate the time when to remove tracks | ||
match_thresh: 0.8 # threshold for matching tracks | ||
# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now) | ||
# mot20: False # for tracker evaluation(not used for now) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from ultralytics.tracker import BYTETracker, BOTSORT | ||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml | ||
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load | ||
import torch | ||
|
||
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT} | ||
check_requirements('lap') # for linear_assignment | ||
|
||
|
||
def on_predict_start(predictor): | ||
tracker = check_yaml(predictor.args.tracker) | ||
cfg = IterableSimpleNamespace(**yaml_load(tracker)) | ||
assert cfg.tracker_type in ["bytetrack", "botsort"], \ | ||
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'" | ||
trackers = [] | ||
for _ in range(predictor.dataset.bs): | ||
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) | ||
trackers.append(tracker) | ||
predictor.trackers = trackers | ||
|
||
|
||
def on_predict_postprocess_end(predictor): | ||
bs = predictor.dataset.bs | ||
im0s = predictor.batch[2] | ||
im0s = im0s if isinstance(im0s, list) else [im0s] | ||
for i in range(bs): | ||
det = predictor.results[i].boxes.cpu().numpy() | ||
if len(det) == 0: | ||
continue | ||
tracks = predictor.trackers[i].update(det, im0s[i]) | ||
if len(tracks) == 0: | ||
continue | ||
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1])) | ||
if predictor.results[i].masks is not None: | ||
idx = tracks[:, -1].tolist() | ||
predictor.results[i].masks = predictor.results[i].masks[idx] | ||
|
||
|
||
def register_tracker(model): | ||
model.add_callback("on_predict_start", on_predict_start) | ||
model.add_callback("on_predict_postprocess_end", on_predict_postprocess_end) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .byte_tracker import BYTETracker | ||
from .bot_sort import BOTSORT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
from collections import OrderedDict | ||
|
||
|
||
class TrackState: | ||
New = 0 | ||
Tracked = 1 | ||
Lost = 2 | ||
Removed = 3 | ||
|
||
|
||
class BaseTrack: | ||
_count = 0 | ||
|
||
track_id = 0 | ||
is_activated = False | ||
state = TrackState.New | ||
|
||
history = OrderedDict() | ||
features = [] | ||
curr_feature = None | ||
score = 0 | ||
start_frame = 0 | ||
frame_id = 0 | ||
time_since_update = 0 | ||
|
||
# multi-camera | ||
location = (np.inf, np.inf) | ||
|
||
@property | ||
def end_frame(self): | ||
return self.frame_id | ||
|
||
@staticmethod | ||
def next_id(): | ||
BaseTrack._count += 1 | ||
return BaseTrack._count | ||
|
||
def activate(self, *args): | ||
raise NotImplementedError | ||
|
||
def predict(self): | ||
raise NotImplementedError | ||
|
||
def update(self, *args, **kwargs): | ||
raise NotImplementedError | ||
|
||
def mark_lost(self): | ||
self.state = TrackState.Lost | ||
|
||
def mark_removed(self): | ||
self.state = TrackState.Removed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from collections import deque | ||
import numpy as np | ||
from ..utils import matching | ||
from ..utils.gmc import GMC | ||
from ..utils.kalman_filter import KalmanFilterXYWH | ||
from .byte_tracker import STrack, BYTETracker | ||
from .basetrack import TrackState | ||
|
||
|
||
class BOTrack(STrack): | ||
shared_kalman = KalmanFilterXYWH() | ||
|
||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50): | ||
super().__init__(tlwh, score, cls) | ||
|
||
self.smooth_feat = None | ||
self.curr_feat = None | ||
if feat is not None: | ||
self.update_features(feat) | ||
self.features = deque([], maxlen=feat_history) | ||
self.alpha = 0.9 | ||
|
||
def update_features(self, feat): | ||
feat /= np.linalg.norm(feat) | ||
self.curr_feat = feat | ||
if self.smooth_feat is None: | ||
self.smooth_feat = feat | ||
else: | ||
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat | ||
self.features.append(feat) | ||
self.smooth_feat /= np.linalg.norm(self.smooth_feat) | ||
|
||
def predict(self): | ||
mean_state = self.mean.copy() | ||
if self.state != TrackState.Tracked: | ||
mean_state[6] = 0 | ||
mean_state[7] = 0 | ||
|
||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) | ||
|
||
def re_activate(self, new_track, frame_id, new_id=False): | ||
if new_track.curr_feat is not None: | ||
self.update_features(new_track.curr_feat) | ||
super().re_activate(new_track, frame_id, new_id) | ||
|
||
def update(self, new_track, frame_id): | ||
if new_track.curr_feat is not None: | ||
self.update_features(new_track.curr_feat) | ||
super().update(new_track, frame_id) | ||
|
||
@property | ||
def tlwh(self): | ||
"""Get current position in bounding box format `(top left x, top left y, | ||
width, height)`. | ||
""" | ||
if self.mean is None: | ||
return self._tlwh.copy() | ||
ret = self.mean[:4].copy() | ||
ret[:2] -= ret[2:] / 2 | ||
return ret | ||
|
||
@staticmethod | ||
def multi_predict(stracks): | ||
if len(stracks) > 0: | ||
multi_mean = np.asarray([st.mean.copy() for st in stracks]) | ||
multi_covariance = np.asarray([st.covariance for st in stracks]) | ||
for i, st in enumerate(stracks): | ||
if st.state != TrackState.Tracked: | ||
multi_mean[i][6] = 0 | ||
multi_mean[i][7] = 0 | ||
multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance) | ||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): | ||
stracks[i].mean = mean | ||
stracks[i].covariance = cov | ||
|
||
def convert_coords(self, tlwh): | ||
return self.tlwh_to_xywh(tlwh) | ||
|
||
@staticmethod | ||
def tlwh_to_xywh(tlwh): | ||
"""Convert bounding box to format `(center x, center y, width, | ||
height)`. | ||
""" | ||
ret = np.asarray(tlwh).copy() | ||
ret[:2] += ret[2:] / 2 | ||
return ret | ||
|
||
|
||
class BOTSORT(BYTETracker): | ||
|
||
def __init__(self, args, frame_rate=30): | ||
super().__init__(args, frame_rate) | ||
# ReID module | ||
self.proximity_thresh = args.proximity_thresh | ||
self.appearance_thresh = args.appearance_thresh | ||
|
||
if args.with_reid: | ||
# haven't supported bot-sort(reid) yet | ||
self.encoder = None | ||
# self.gmc = GMC(method=args.cmc_method, verbose=[args.name, args.ablation]) | ||
self.gmc = GMC(method=args.cmc_method) | ||
|
||
def get_kalmanfilter(self): | ||
return KalmanFilterXYWH() | ||
|
||
def init_track(self, dets, scores, cls, img=None): | ||
if len(dets) == 0: | ||
return [] | ||
if self.args.with_reid and self.encoder is not None: | ||
features_keep = self.encoder.inference(img, dets) | ||
detections = [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] | ||
else: | ||
detections = [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] | ||
return detections | ||
|
||
def get_dists(self, tracks, detections): | ||
dists = matching.iou_distance(tracks, detections) | ||
dists_mask = (dists > self.proximity_thresh) | ||
|
||
# TODO: mot20 | ||
# if not self.args.mot20: | ||
dists = matching.fuse_score(dists, detections) | ||
|
||
if self.args.with_reid and self.encoder is not None: | ||
emb_dists = matching.embedding_distance(tracks, detections) / 2.0 | ||
emb_dists[emb_dists > self.appearance_thresh] = 1.0 | ||
emb_dists[dists_mask] = 1.0 | ||
dists = np.minimum(dists, emb_dists) | ||
return dists | ||
|
||
def multi_predict(self, tracks): | ||
BOTrack.multi_predict(tracks) |
Oops, something went wrong.