forked from rbgirshick/fast-rcnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
basic matlab detection functionality
- Loading branch information
1 parent
fb33dba
commit 2bf5297
Showing
7 changed files
with
362 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
A basic demo in MATLAB. | ||
|
||
Detection is also implemented in MATLAB (though missing some bells and whistles | ||
compared to the Python version) via the fast_rcnn_im_detect() function. | ||
|
||
See fast_rcnn_demo.m for example usage. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
% -------------------------------------------------------- | ||
% Fast R-CNN | ||
% Copyright (c) 2015 Microsoft | ||
% Licensed under The MIT License [see LICENSE for details] | ||
% Written by Ross Girshick | ||
% -------------------------------------------------------- | ||
|
||
function fast_rcnn_demo() | ||
% Fast R-CNN demo (in matlab). | ||
|
||
[folder, name, ext] = fileparts(mfilename('fullpath')); | ||
|
||
caffe_path = fullfile(folder, '..', 'caffe-fast-rcnn', 'matlab', 'caffe'); | ||
addpath(caffe_path); | ||
|
||
use_gpu = true; | ||
% You can try other models here: | ||
def = fullfile(folder, '..', 'models', 'VGG16', 'test.prototxt');; | ||
net = fullfile(folder, '..', 'data', 'fast_rcnn_models', ... | ||
'vgg16_fast_rcnn_iter_40000.caffemodel'); | ||
model = fast_rcnn_load_net(def, net, use_gpu); | ||
|
||
car_ind = 7; | ||
sofa_ind = 18; | ||
tv_ind = 20; | ||
|
||
demo(model, '000004', [car_ind], {'car'}); | ||
demo(model, '001551', [sofa_ind, tv_ind], {'sofa', 'tvmonitor'}); | ||
fprintf('\n'); | ||
|
||
% ------------------------------------------------------------------------ | ||
function demo(model, im_id, cls_inds, cls_names) | ||
% ------------------------------------------------------------------------ | ||
[folder, name, ext] = fileparts(mfilename('fullpath')); | ||
box_file = fullfile(folder, '..', 'data', 'demo', [im_id '_boxes.mat']); | ||
% Boxes were saved with 0-based indexing | ||
ld = load(box_file); boxes = single(ld.boxes) + 1; clear ld; | ||
im_file = fullfile(folder, '..', 'data', 'demo', [im_id '.jpg']); | ||
im = imread(im_file); | ||
dets = fast_rcnn_im_detect(model, im, boxes); | ||
|
||
THRESH = 0.8; | ||
for j = 1:length(cls_inds) | ||
cls_ind = cls_inds(j); | ||
cls_name = cls_names{j}; | ||
I = find(dets{cls_ind}(:, end) >= THRESH); | ||
showboxes(im, dets{cls_ind}(I, :)); | ||
title(sprintf('%s detections with p(%s | box) >= %.3f', ... | ||
cls_name, cls_name, THRESH)) | ||
fprintf('\n> Press any key to continue'); | ||
pause; | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
% -------------------------------------------------------- | ||
% Fast R-CNN | ||
% Copyright (c) 2015 Microsoft | ||
% Licensed under The MIT License [see LICENSE for details] | ||
% Written by Ross Girshick | ||
% -------------------------------------------------------- | ||
|
||
function dets = fast_rcnn_im_detect(model, im, boxes) | ||
% Perform detection a Fast R-CNN network given an image and | ||
% object proposals. | ||
|
||
if model.init_key ~= caffe('get_init_key') | ||
error('You probably need call fast_rcnn_load_net() first.'); | ||
end | ||
|
||
[im_batch, scales] = image_pyramid(im, model.pixel_means, false); | ||
|
||
[feat_pyra_boxes, feat_pyra_levels] = project_im_rois(boxes, scales); | ||
rois = cat(2, feat_pyra_levels, feat_pyra_boxes); | ||
% Adjust to 0-based indexing and make roi info the fastest dimension | ||
rois = rois - 1; | ||
rois = permute(rois, [2 1]); | ||
|
||
input_blobs = cell(2, 1); | ||
input_blobs{1} = im_batch; | ||
input_blobs{2} = rois; | ||
th = tic(); | ||
blobs_out = caffe('forward', input_blobs); | ||
fprintf('fwd: %.3fs\n', toc(th)); | ||
|
||
bbox_deltas = squeeze(blobs_out{1})'; | ||
probs = squeeze(blobs_out{2})'; | ||
|
||
num_classes = size(probs, 2); | ||
dets = cell(num_classes - 1, 1); | ||
NMS_THRESH = 0.3; | ||
% class index 1 is __background__, so we don't return it | ||
for j = 2:num_classes | ||
cls_probs = probs(:, j); | ||
cls_deltas = bbox_deltas(:, (1 + (j - 1) * 4):(j * 4)); | ||
pred_boxes = bbox_pred(boxes, cls_deltas); | ||
cls_dets = [pred_boxes cls_probs]; | ||
keep = nms(cls_dets, NMS_THRESH); | ||
cls_dets = cls_dets(keep, :); | ||
dets{j - 1} = cls_dets; | ||
end | ||
|
||
% ------------------------------------------------------------------------ | ||
function [batch, scales] = image_pyramid(im, pixel_means, multiscale) | ||
% ------------------------------------------------------------------------ | ||
% Construct an image pyramid that's ready for feeding directly into caffe | ||
if ~multiscale | ||
SCALES = [600]; | ||
MAX_SIZE = 1000; | ||
else | ||
SCALES = [1200 864 688 576 480]; | ||
MAX_SIZE = 2000; | ||
end | ||
num_levels = length(SCALES); | ||
|
||
im = single(im); | ||
% Convert to BGR | ||
im = im(:, :, [3 2 1]); | ||
% Subtract mean (mean of the image mean--one mean per channel) | ||
im = bsxfun(@minus, im, pixel_means); | ||
|
||
im_orig = im; | ||
im_size = min([size(im_orig, 1) size(im_orig, 2)]); | ||
im_size_big = max([size(im_orig, 1) size(im_orig, 2)]); | ||
scale_factors = SCALES ./ im_size; | ||
|
||
max_size = [0 0 0]; | ||
for i = 1:num_levels | ||
if round(im_size_big * scale_factors(i)) > MAX_SIZE | ||
scale_factors(i) = MAX_SIZE / im_size_big; | ||
end | ||
ims{i} = imresize(im_orig, scale_factors(i), 'bilinear', ... | ||
'antialiasing', false); | ||
max_size = max(cat(1, max_size, size(ims{i})), [], 1); | ||
end | ||
|
||
batch = zeros(max_size(2), max_size(1), 3, num_levels, 'single'); | ||
for i = 1:num_levels | ||
im = ims{i}; | ||
im_sz = size(im); | ||
im_sz = im_sz(1:2); | ||
% Make width the fastest dimension (for caffe) | ||
im = permute(im, [2 1 3]); | ||
batch(1:im_sz(2), 1:im_sz(1), :, i) = im; | ||
end | ||
scales = scale_factors'; | ||
|
||
% ------------------------------------------------------------------------ | ||
function [boxes, levels] = project_im_rois(boxes, scales) | ||
% ------------------------------------------------------------------------ | ||
widths = boxes(:,3) - boxes(:,1) + 1; | ||
heights = boxes(:,4) - boxes(:,2) + 1; | ||
|
||
areas = widths .* heights; | ||
scaled_areas = bsxfun(@times, areas, (scales.^2)'); | ||
diff_areas = abs(scaled_areas - (224 * 224)); | ||
[~, levels] = min(diff_areas, [], 2); | ||
|
||
boxes = boxes - 1; | ||
boxes = bsxfun(@times, boxes, scales(levels)); | ||
boxes = boxes + 1; | ||
|
||
% ------------------------------------------------------------------------ | ||
function pred_boxes = bbox_pred(boxes, bbox_deltas) | ||
% ------------------------------------------------------------------------ | ||
if isempty(boxes) | ||
pred_boxes = []; | ||
return; | ||
end | ||
|
||
Y = bbox_deltas; | ||
|
||
% Read out predictions | ||
dst_ctr_x = Y(:, 1); | ||
dst_ctr_y = Y(:, 2); | ||
dst_scl_x = Y(:, 3); | ||
dst_scl_y = Y(:, 4); | ||
|
||
src_w = boxes(:, 3) - boxes(:, 1) + eps; | ||
src_h = boxes(:, 4) - boxes(:, 2) + eps; | ||
src_ctr_x = boxes(:, 1) + 0.5 * src_w; | ||
src_ctr_y = boxes(:, 2) + 0.5 * src_h; | ||
|
||
pred_ctr_x = (dst_ctr_x .* src_w) + src_ctr_x; | ||
pred_ctr_y = (dst_ctr_y .* src_h) + src_ctr_y; | ||
pred_w = exp(dst_scl_x) .* src_w; | ||
pred_h = exp(dst_scl_y) .* src_h; | ||
pred_boxes = [pred_ctr_x - 0.5 * pred_w, pred_ctr_y - 0.5 * pred_h, ... | ||
pred_ctr_x + 0.5 * pred_w, pred_ctr_y + 0.5 * pred_h]; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
% -------------------------------------------------------- | ||
% Fast R-CNN | ||
% Copyright (c) 2015 Microsoft | ||
% Licensed under The MIT License [see LICENSE for details] | ||
% Written by Ross Girshick | ||
% -------------------------------------------------------- | ||
|
||
function model = fast_rcnn_load_net(def, net, use_gpu) | ||
% Load a Fast R-CNN network. | ||
|
||
init_key = caffe('init', def, net, 'test'); | ||
if exist('use_gpu', 'var') && ~use_gpu | ||
caffe('set_mode_cpu'); | ||
else | ||
caffe('set_mode_gpu'); | ||
end | ||
|
||
model.init_key = init_key; | ||
% model.stride is correct for the included models, but may not be correct | ||
% for other models! | ||
model.stride = 16; | ||
model.pixel_means = reshape([102.9801, 115.9465, 122.7717], [1 1 3]); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
function cnn = init_cnn_model(varargin) | ||
% cnn = init_cnn_model | ||
% Initialize a CNN with caffe | ||
% | ||
% Optional arguments | ||
% net_file network binary file | ||
% def_file network prototxt file | ||
% use_gpu set to false to use CPU (default: true) | ||
% use_caffe set to false to avoid using caffe (default: true) | ||
% useful for running on the cluster (must use cached pyramids!) | ||
|
||
% ------------------------------------------------------------------------ | ||
% Options | ||
ip = inputParser; | ||
|
||
% network binary file | ||
ip.addParamValue('net_file', ... | ||
'./data/caffe_nets/ilsvrc_2012_train_iter_310k', ... | ||
@isstr); | ||
|
||
% network prototxt file | ||
ip.addParamValue('def_file', ... | ||
'./model-defs/pyramid_cnn_output_conv5_scales_7_plane_1713.prototxt', ... | ||
@isstr); | ||
|
||
% Set use_gpu to false to use the CPU | ||
ip.addParamValue('use_gpu', true, @islogical); | ||
|
||
% Set use_caffe to false to avoid using caffe | ||
% (must be used in conjunction with cached features!) | ||
ip.addParamValue('use_caffe', true, @islogical); | ||
|
||
ip.parse(varargin{:}); | ||
opts = ip.Results; | ||
% ------------------------------------------------------------------------ | ||
|
||
cnn.binary_file = opts.net_file; | ||
cnn.definition_file = opts.def_file; | ||
cnn.init_key = -1; | ||
|
||
% load the ilsvrc image mean | ||
data_mean_file = 'ilsvrc_2012_mean.mat'; | ||
assert(exist(data_mean_file, 'file') ~= 0); | ||
% input size business isn't likley necessary, but we're doing it | ||
% to be consistent with previous experiments | ||
ld = load(data_mean_file); | ||
mu = ld.image_mean; clear ld; | ||
input_size = 227; | ||
off = floor((size(mu,1) - input_size)/2)+1; | ||
%mu = mu(off:off+input_size-1, off:off+input_size-1, :); | ||
%mu = sum(sum(mu, 1), 2) / size(mu, 1) / size(mu, 2); | ||
cnn.mu = reshape([102.9801, 115.9465, 122.7717], [1 1 3]); | ||
|
||
if opts.use_caffe | ||
cnn.init_key = ... | ||
caffe('init', cnn.definition_file, cnn.binary_file); | ||
caffe('set_phase_test'); | ||
if opts.use_gpu | ||
caffe('set_mode_gpu'); | ||
else | ||
caffe('set_mode_cpu'); | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
function pick = nms(boxes, overlap) | ||
% top = nms(boxes, overlap) | ||
% Non-maximum suppression. (FAST VERSION) | ||
% Greedily select high-scoring detections and skip detections | ||
% that are significantly covered by a previously selected | ||
% detection. | ||
% | ||
% NOTE: This is adapted from Pedro Felzenszwalb's version (nms.m), | ||
% but an inner loop has been eliminated to significantly speed it | ||
% up in the case of a large number of boxes | ||
|
||
% Copyright (C) 2011-12 by Tomasz Malisiewicz | ||
% All rights reserved. | ||
% | ||
% This file is part of the Exemplar-SVM library and is made | ||
% available under the terms of the MIT license (see COPYING file). | ||
% Project homepage: https://github.com/quantombone/exemplarsvm | ||
|
||
|
||
if isempty(boxes) | ||
pick = []; | ||
return; | ||
end | ||
|
||
x1 = boxes(:,1); | ||
y1 = boxes(:,2); | ||
x2 = boxes(:,3); | ||
y2 = boxes(:,4); | ||
s = boxes(:,end); | ||
|
||
area = (x2-x1+1) .* (y2-y1+1); | ||
[vals, I] = sort(s); | ||
|
||
pick = s*0; | ||
counter = 1; | ||
while ~isempty(I) | ||
last = length(I); | ||
i = I(last); | ||
pick(counter) = i; | ||
counter = counter + 1; | ||
|
||
xx1 = max(x1(i), x1(I(1:last-1))); | ||
yy1 = max(y1(i), y1(I(1:last-1))); | ||
xx2 = min(x2(i), x2(I(1:last-1))); | ||
yy2 = min(y2(i), y2(I(1:last-1))); | ||
|
||
w = max(0.0, xx2-xx1+1); | ||
h = max(0.0, yy2-yy1+1); | ||
|
||
inter = w.*h; | ||
o = inter ./ (area(i) + area(I(1:last-1)) - inter); | ||
|
||
I = I(find(o<=overlap)); | ||
end | ||
|
||
pick = pick(1:(counter-1)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
% -------------------------------------------------------- | ||
% Fast R-CNN | ||
% Copyright (c) 2015 Microsoft | ||
% Licensed under The MIT License [see LICENSE for details] | ||
% Written by Ross Girshick | ||
% -------------------------------------------------------- | ||
|
||
function showboxes(im, boxes) | ||
|
||
image(im); | ||
axis image; | ||
axis off; | ||
set(gcf, 'Color', 'white'); | ||
|
||
if ~isempty(boxes) | ||
x1 = boxes(:, 1); | ||
y1 = boxes(:, 2); | ||
x2 = boxes(:, 3); | ||
y2 = boxes(:, 4); | ||
c = 'r'; | ||
s = '-'; | ||
line([x1 x1 x2 x2 x1]', [y1 y2 y2 y1 y1]', ... | ||
'color', c, 'linewidth', 2, 'linestyle', s); | ||
for i = 1:size(boxes, 1) | ||
text(double(x1(i)), double(y1(i)) - 2, ... | ||
sprintf('%.3f', boxes(i, end)), ... | ||
'backgroundcolor', 'r', 'color', 'w'); | ||
end | ||
end |