Skip to content

Commit

Permalink
Fix bugs in the feature extraction example
Browse files Browse the repository at this point in the history
  • Loading branch information
kloudkl committed Mar 19, 2014
1 parent 4de8280 commit dfe6380
Showing 1 changed file with 53 additions and 64 deletions.
117 changes: 53 additions & 64 deletions examples/demo_extract_features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ int feature_extraction_pipeline(int argc, char** argv) {
if (argc < num_required_args) {
LOG(ERROR)<<
"This program takes in a trained network and an input data layer, and then"
" extract features of the input data produced by the net."
" extract features of the input data produced by the net.\n"
"Usage: demo_extract_features pretrained_net_param"
" extract_feature_blob_name data_prototxt data_layer_name"
" save_feature_leveldb_name [CPU/GPU] [DEVICE_ID=0]";
" feature_extraction_proto_file extract_feature_blob_name"
" save_feature_leveldb_name num_mini_batches [CPU/GPU] [DEVICE_ID=0]";
return 1;
}
int arg_pos = num_required_args;
Expand All @@ -58,86 +58,78 @@ int feature_extraction_pipeline(int argc, char** argv) {
NetParameter pretrained_net_param;

arg_pos = 0; // the name of the executable
// We directly load the net param from trained file
string pretrained_binary_proto(argv[++arg_pos]);
ReadProtoFromBinaryFile(pretrained_binary_proto.c_str(),
&pretrained_net_param);

// Expected prototxt contains at least one data layer such as
// the layer data_layer_name and one feature blob such as the
// fc7 top blob to extract features.
/*
layers {
layer {
name: "data_layer_name"
type: "data"
source: "/path/to/your/images/to/extract/feature/images_leveldb"
meanfile: "/path/to/your/image_mean.binaryproto"
batchsize: 128
cropsize: 227
mirror: false
}
top: "data_blob_name"
top: "label_blob_name"
}
layers {
layer {
name: "drop7"
type: "dropout"
dropout_ratio: 0.5
}
bottom: "fc7"
top: "fc7"
}
*/
NetParameter feature_extraction_net_param;;
string feature_extraction_proto(argv[++arg_pos]);
ReadProtoFromTextFile(feature_extraction_proto,
&feature_extraction_net_param);
shared_ptr<Net<Dtype> > feature_extraction_net(
new Net<Dtype>(pretrained_net_param));
new Net<Dtype>(feature_extraction_net_param));
feature_extraction_net->CopyTrainedLayersFrom(pretrained_net_param);

string extract_feature_blob_name(argv[++arg_pos]);
if (!feature_extraction_net->HasBlob(extract_feature_blob_name)) {
LOG(ERROR)<< "Unknown feature blob name " << extract_feature_blob_name <<
" in trained network " << pretrained_binary_proto;
" in the network " << feature_extraction_proto;
return 1;
}

// Expected prototxt contains at least one data layer to extract features.
/*
layers {
layer {
name: "data_layer_name"
type: "data"
source: "/path/to/your/images/to/extract/feature/images_leveldb"
meanfile: "/path/to/your/image_mean.binaryproto"
batchsize: 128
cropsize: 227
mirror: false
}
top: "data_blob_name"
top: "label_blob_name"
}
*/
string data_prototxt(argv[++arg_pos]);
string data_layer_name(argv[++arg_pos]);
NetParameter data_net_param;
ReadProtoFromTextFile(data_prototxt.c_str(), &data_net_param);
LayerParameter data_layer_param;
int num_layer;
for (num_layer = 0; num_layer < data_net_param.layers_size(); ++num_layer) {
if (data_layer_name == data_net_param.layers(num_layer).layer().name()) {
data_layer_param = data_net_param.layers(num_layer).layer();
break;
}
}
if (num_layer = data_net_param.layers_size()) {
LOG(ERROR) << "Unknown data layer name " << data_layer_name <<
" in prototxt " << data_prototxt;
}

string save_feature_leveldb_name(argv[++arg_pos]);
leveldb::DB* db;
leveldb::Options options;
options.error_if_exists = true;
options.create_if_missing = true;
options.write_buffer_size = 268435456;
LOG(INFO) << "Opening leveldb " << argv[3];
LOG(INFO) << "Opening leveldb " << save_feature_leveldb_name;
leveldb::Status status = leveldb::DB::Open(
options, save_feature_leveldb_name.c_str(), &db);
CHECK(status.ok()) << "Failed to open leveldb " << save_feature_leveldb_name;

int num_mini_batches = atoi(argv[++arg_pos]);

LOG(ERROR)<< "Extacting Features";
DataLayer<Dtype> data_layer(data_layer_param);
vector<Blob<Dtype>*> bottom_vec_that_data_layer_does_not_need_;
vector<Blob<Dtype>*> top_vec;
data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
int batch_index = 0;
int image_index = 0;

Datum datum;
leveldb::WriteBatch* batch = new leveldb::WriteBatch();
const int max_key_str_length = 100;
char key_str[max_key_str_length];
int num_bytes_of_binary_code = sizeof(Dtype);
// TODO: DataLayer seem to rotate from the last record to the first
// how to judge that all the data record have been enumerated?
while (top_vec.size()) { // data_layer still outputs data
LOG(ERROR)<< "Batch " << batch_index << " feature extraction";
feature_extraction_net->Forward(top_vec);
vector<Blob<float>* > input_vec;
int image_index = 0;
for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
feature_extraction_net->Forward(input_vec);
const shared_ptr<Blob<Dtype> > feature_blob =
feature_extraction_net->GetBlob(extract_feature_blob_name);

LOG(ERROR) << "Batch " << batch_index << " save extracted features";
int num_features = feature_blob->num();
int dim_features = feature_blob->count() / num_features;
for (int n = 0; n < num_features; ++n) {
Expand Down Expand Up @@ -165,17 +157,14 @@ int feature_extraction_pipeline(int argc, char** argv) {
batch = new leveldb::WriteBatch();
}
}
// write the last batch
if (image_index % 1000 != 0) {
db->Write(leveldb::WriteOptions(), batch);
LOG(ERROR) << "Extracted features of " << image_index << " query images.";
delete batch;
batch = new leveldb::WriteBatch();
}

data_layer.Forward(bottom_vec_that_data_layer_does_not_need_, &top_vec);
++batch_index;
} // while (top_vec.size()) {
} // for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index)
// write the last batch
if (image_index % 1000 != 0) {
db->Write(leveldb::WriteOptions(), batch);
LOG(ERROR) << "Extracted features of " << image_index << " query images.";
delete batch;
batch = new leveldb::WriteBatch();
}

delete batch;
delete db;
Expand Down

0 comments on commit dfe6380

Please sign in to comment.