-
Notifications
You must be signed in to change notification settings - Fork 168
/
grid_search.m
71 lines (58 loc) · 1.99 KB
/
grid_search.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
% --------------------------------------------------------
% MDP Tracking
% Copyright (c) 2015 CVGL Stanford
% Licensed under The MIT License [see LICENSE for details]
% Written by Yu Xiang
% --------------------------------------------------------
%
% find parameters by grid search
function opt = grid_search(seq_idx)
% GRID_SEARCH  Tune tracking options one at a time on a training sequence.
%
%   opt = grid_search(seq_idx) starts from the options returned by
%   globals() and hands each (field name, candidate list) pair below to
%   search_parameter, in order. Each call returns an updated opt, so a
%   value chosen for an earlier field is in effect while later fields are
%   searched. The final opt is saved to
%   <opt.results>/<sequence name>_opt.mat and returned.
%
%   seq_idx  index into opt.mot2d_train_seqs; also forwarded unchanged to
%            search_parameter.
%
%   Note: functions in this file are not closed with 'end'; keep it that way.

% basic parameters
opt = globals();

% Search schedule, processed top to bottom.
% Column 1: option field name. Column 2: cell array of candidate values.
schedule = {
    'rescale_box',     {[1, 1], [0.8, 1], [0.6, 1], [0.6, 0.8]}   % very important for optical flow tracking
    'enlarge_box',     {2, 3, 4, 5}
    'max_ratio',       {0.6, 0.7, 0.8, 0.9}                       % max ratio change allowed in LK
    'weight_tracking', {1, 3, 5, 7}
    'threshold_ratio', {0.6, 0.7, 0.8, 0.9}
    'threshold_dis',   {1, 3, 5, 7}
};

for k = 1:size(schedule, 1)
    opt = search_parameter(seq_idx, opt, schedule{k, 1}, schedule{k, 2});
end

% Save the tuned options; the variable is stored under the name 'opt'.
sequence = opt.mot2d_train_seqs{seq_idx};
out_path = sprintf('%s/%s_opt.mat', opt.results, sequence);
fprintf('save parameters to %s\n', out_path);
save(out_path, 'opt');
% search a specific parameter
function opt = search_parameter(seq_idx, opt, fname, values)
% SEARCH_PARAMETER  Choose the best-scoring candidate for one option field.
%
%   opt = search_parameter(seq_idx, opt, fname, values) sets opt.(fname)
%   to each entry of the cell array VALUES in turn, trains with
%   MDP_train(seq_idx, opt), evaluates the resulting tracker with
%   MDP_test(seq_idx, 'train', tracker), and records metrics.mets2d.m(12)
%   as that candidate's score. On return opt.(fname) holds the candidate
%   with the highest score (max picks the first one on ties); the other
%   fields of opt are as passed in.
%
%   seq_idx  forwarded unchanged to MDP_train and MDP_test.
%   opt      options struct; only the field named by fname is modified.
%   fname    name of the field to search (char vector).
%   values   cell array of candidate values for opt.(fname).
%
%   If VALUES is empty there is nothing to evaluate, so opt is returned
%   unchanged instead of failing on the final assignment.
%
%   Note: functions in this file are not closed with 'end'; keep it that way.

% Guard: with no candidates, max() yields an empty index and
% opt.(fname) = values{[]} would raise an unhelpful error.
if isempty(values)
    fprintf('No candidate values for parameter %s, keeping current setting\n', fname);
    return;
end

num = numel(values);
mota = zeros(num, 1);

for i = 1:num
    fprintf('Test parameter %s: ', fname);
    disp(values{i});

    % Train and evaluate with this candidate in place.
    opt.(fname) = values{i};
    tracker = MDP_train(seq_idx, opt);
    metrics = MDP_test(seq_idx, 'train', tracker);

    % NOTE(review): element 12 of the metrics vector is presumably MOTA
    % (going by the variable name) -- confirm against the evaluation code.
    mota(i) = metrics.mets2d.m(12);

    fprintf('Test parameter %s, mota %.2f, ', fname, mota(i));
    disp(values{i});
end

% Keep the candidate with the highest score.
[m, ind] = max(mota);
opt.(fname) = values{ind};
fprintf('Final parameter %s with max mota %f, ', fname, m);
disp(opt.(fname));