
Commit

Fixed input_layer to pass tests, added cat image to data to perform the tests
sguada committed Feb 17, 2014
1 parent 617c016 commit e8e3a1b
Showing 4 changed files with 91 additions and 40 deletions.
Binary file added data/cat.jpg
15 changes: 13 additions & 2 deletions include/caffe/vision_layers.hpp
@@ -339,8 +339,15 @@ class DataLayer : public Layer<Dtype> {
Blob<Dtype> data_mean_;
};

// This function is used to create a pthread that prefetches the data.
template <typename Dtype>
void* InputLayerPrefetch(void* layer_pointer);

template <typename Dtype>
class InputLayer : public Layer<Dtype> {
// The function used to perform prefetching.
friend void* InputLayerPrefetch<Dtype>(void* layer_pointer);

public:
explicit InputLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
@@ -358,12 +365,16 @@ class InputLayer : public Layer<Dtype> {
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);

vector<std::pair<std::string, int> > lines_;
int lines_id_;
int datum_channels_;
int datum_height_;
int datum_width_;
int datum_size_;
bool biasterm_;
bool has_data_mean_;
pthread_t thread_;
shared_ptr<Blob<Dtype> > prefetch_data_;
shared_ptr<Blob<Dtype> > prefetch_label_;
Blob<Dtype> data_mean_;
};
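
The header change above wires InputLayer to the same pthread-based prefetching pattern used by DataLayer: a free function with a void* signature is declared before the class and made a friend so that the worker thread can fill the layer's private prefetch buffers. Below is a minimal standalone sketch of that pattern, not the Caffe code; ToyLayer, ToyPrefetch and batches_ are made-up names standing in for InputLayer, InputLayerPrefetch<Dtype> and the prefetch blobs.

#include <pthread.h>
#include <cstdio>

// Declared before the class, as InputLayerPrefetch is above: the thread
// entry point must have a plain C-compatible signature taking void*.
void* ToyPrefetch(void* layer_pointer);

class ToyLayer {
  // Friendship lets the thread function touch the private members.
  friend void* ToyPrefetch(void* layer_pointer);

 public:
  ToyLayer() : batches_(0) {}
  void StartPrefetch() {
    // Same call shape as in input_layer.cpp: hand the layer over as an
    // opaque pointer (the real code CHECKs the return value).
    pthread_create(&thread_, NULL, ToyPrefetch, static_cast<void*>(this));
  }
  void JoinPrefetch() { pthread_join(thread_, NULL); }
  int batches() const { return batches_; }

 private:
  pthread_t thread_;
  int batches_;
};

// Runs on the worker thread; recovers the typed layer from the void*.
void* ToyPrefetch(void* layer_pointer) {
  ToyLayer* layer = static_cast<ToyLayer*>(layer_pointer);
  layer->batches_ += 1;  // stand-in for filling prefetch_data_ / prefetch_label_
  return NULL;
}

int main() {
  ToyLayer layer;
  layer.StartPrefetch();
  layer.JoinPrefetch();
  std::printf("prefetched batches: %d\n", layer.batches());
  return 0;
}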


114 changes: 76 additions & 38 deletions src/caffe/layers/input_layer.cpp
@@ -1,18 +1,20 @@
// Copyright 2014 Sergio Guadarrama
// Copyright 2013 Yangqing Jia

#include <stdint.h>
#include <leveldb/db.h>
#include <pthread.h>

#include <string>
#include <vector>
#include <iostream>
#include <fstream>

#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/filler.hpp"

using std::string;
using std::pair;

namespace caffe {

@@ -35,16 +37,19 @@ void* InputLayerPrefetch(void* layer_pointer) {
<< "set at the same time.";
}
// datum scales
const int channels = layer->bottom_channels_;
const int height = layer->bottom_height_;
const int width = layer->bottom_width_;
const int size = layer->bottom_size_;
const int channels = layer->datum_channels_;
const int height = layer->datum_height_;
const int width = layer->datum_width_;
const int size = layer->datum_size_;
const int lines_size = layer->lines_.size();
const Dtype* mean = layer->data_mean_.cpu_data();
for (int itemid = 0; itemid < batchsize; ++itemid) {
// get a blob
CHECK(layer->iter_);
CHECK(layer->iter_->Valid());
datum.ParseFromString(layer->iter_->value().ToString());
CHECK_GT(lines_size,layer->lines_id_);
if (!ReadImageToDatum(layer->lines_[layer->lines_id_].first,
layer->lines_[layer->lines_id_].second, &datum)) {
continue;
};
const string& data = datum.data();
if (cropsize) {
CHECK(data.size()) << "Image cropping only support uint8 data";
@@ -88,7 +93,7 @@ void* InputLayerPrefetch(void* layer_pointer) {
}
}
} else {
// we will prefer to use data() first, and then try float_data()
// Just copy the whole data
if (data.size()) {
for (int j = 0; j < size; ++j) {
top_data[itemid * size + j] =
@@ -104,11 +109,11 @@ void* InputLayerPrefetch(void* layer_pointer) {

top_label[itemid] = datum.label();
// go to the next iter
layer->iter_->Next();
if (!layer->iter_->Valid()) {
layer->lines_id_++;
if (layer->lines_id_ >= lines_size) {
// We have reached the end. Restart from the first.
DLOG(INFO) << "Restarting data prefetching from start.";
layer->iter_->SeekToFirst();
layer->lines_id_=0;
}
}

@@ -124,45 +129,78 @@ InputLayer<Dtype>::~InputLayer<Dtype>() {
template <typename Dtype>
void InputLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 1) << "Input Layer takes a single blob as input.";
CHECK_EQ(top->size(), 1) << "Input Layer takes a single blob as output.";
CHECK_EQ(bottom.size(), 0) << "Input Layer takes no input blobs.";
CHECK_EQ(top->size(), 2) << "Input Layer takes two blobs as output.";
// Read the file with filenames and labels
LOG(INFO) << "Opening file " << this->layer_param_.source();
std::ifstream infile(this->layer_param_.source().c_str());
string filename;
int label;
while (infile >> filename >> label) {
lines_.push_back(std::make_pair(filename, label));
}

if (this->layer_param_.shuffle_data()) {
// randomly shuffle data
LOG(INFO) << "Shuffling data";
std::random_shuffle(lines_.begin(), lines_.end());
}
LOG(INFO) << "A total of " << lines_.size() << " images.";

lines_id_ = 0;
// Check if we would need to randomly skip a few data points
if (this->layer_param_.rand_skip()) {
unsigned int skip = rand() % this->layer_param_.rand_skip();
LOG(INFO) << "Skipping first " << skip << " data points.";
CHECK_GT(lines_.size(),skip) << "Not enought points to skip";
lines_id_ = skip;
}
// Read a data point, and use it to initialize the top blob.
Datum datum;
CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
&datum));
// image
int cropsize = this->layer_param_.cropsize();
if (cropsize > 0) {
(*top)[0]->Reshape(
this->layer_param_.batchsize(), bottom.channels(), cropsize, cropsize);
this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize);
prefetch_data_.reset(new Blob<Dtype>(
this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize));
} else {
(*top)[0]->Reshape(
this->layer_param_.batchsize(), bottom.channels(), bottom.height(),
bottom.width());
this->layer_param_.batchsize(), datum.channels(), datum.height(),
datum.width());
prefetch_data_.reset(new Blob<Dtype>(
this->layer_param_.batchsize(), datum.channels(), datum.height(),
datum.width()));
}
LOG(INFO) << "output data size: " << (*top)[0]->num() << ","
<< (*top)[0]->channels() << "," << (*top)[0]->height() << ","
<< (*top)[0]->width();
bottom_channels_ = bottom.channels();
bottom_height_ = bottom.height();
bottom_width_ = bottom.width();
bottom_size_ = bottom.channels() * bottom.height() * bottom.width();
CHECK_GT(bottom_height_, cropsize);
CHECK_GT(boottom_width_, cropsize);
// label
(*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1);
prefetch_label_.reset(
new Blob<Dtype>(this->layer_param_.batchsize(), 1, 1, 1));
// datum size
datum_channels_ = datum.channels();
datum_height_ = datum.height();
datum_width_ = datum.width();
datum_size_ = datum.channels() * datum.height() * datum.width();
CHECK_GT(datum_height_, cropsize);
CHECK_GT(datum_width_, cropsize);
// check if we want to have mean
if (this->layer_param_.has_meanfile()) {
BlobProto blob_proto;
LOG(INFO) << "Loading mean file from" << this->layer_param_.meanfile();
ReadProtoFromBinaryFile(this->layer_param_.meanfile().c_str(), &blob_proto);
data_mean_.FromProto(blob_proto);
CHECK_EQ(data_mean_.num(), 1);
CHECK_EQ(data_mean_.channels(), bottom_channels_);
CHECK_EQ(data_mean_.height(), bottom_height_);
CHECK_EQ(data_mean_.width(), boottom_width_);
CHECK_EQ(data_mean_.channels(), datum_channels_);
CHECK_EQ(data_mean_.height(), datum_height_);
CHECK_EQ(data_mean_.width(), datum_width_);
} else {
// Intialize the data_mean with zeros
data_mean_.Reshape(1, bottom_channels_, bottom_height_, boottom_width_);
// Or if there is a bias_filler use it to initialize the data_mean
if (this->layer_param_.has_bias_filler()) {
shared_ptr<Filler<Dtype> > bias_filler(
GetFiller<Dtype>(this->layer_param_.bias_filler()));
bias_filler->Fill(&this->data_mean_);
}
// Simply initialize an all-empty mean.
data_mean_.Reshape(1, datum_channels_, datum_height_, datum_width_);
}
// Now, start the prefetch thread. Before calling prefetch, we make two
// cpu_data calls so that the prefetch thread does not accidentally make
@@ -172,7 +210,7 @@ void InputLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
prefetch_label_->mutable_cpu_data();
data_mean_.cpu_data();
DLOG(INFO) << "Initializing prefetch";
CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
CHECK(!pthread_create(&thread_, NULL, InputLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
DLOG(INFO) << "Prefetch initialized.";
}
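
SetUp now reads the training set directly from a text file: each line of layer_param_.source() is expected to hold a filename and an integer label, the resulting list is optionally shuffled (the new shuffle_data field added to caffe.proto below), and rand_skip picks a random starting offset. A small self-contained sketch of that flow, with made-up stand-in values for the layer parameters (not the committed code):

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <random>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical listing file, one "filename label" pair per line, e.g.:
  //   data/cat.jpg 0
  //   data/cat.jpg 1
  const char* source = "image_list.txt";  // stand-in for layer_param_.source()
  const bool shuffle_data = true;         // stand-in for layer_param_.shuffle_data()
  const unsigned int rand_skip = 4;       // stand-in for layer_param_.rand_skip()

  std::vector<std::pair<std::string, int> > lines;
  std::ifstream infile(source);
  std::string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  std::printf("A total of %lu images.\n",
              static_cast<unsigned long>(lines.size()));

  if (shuffle_data) {
    // The commit calls std::random_shuffle; std::shuffle is the modern equivalent.
    std::mt19937 rng(1234);
    std::shuffle(lines.begin(), lines.end(), rng);
  }

  // Pick a random starting offset so runs do not always begin at image 0;
  // the committed code CHECKs that the list is larger than the skip.
  size_t lines_id = 0;
  if (rand_skip > 0 && !lines.empty()) {
    lines_id = std::rand() % rand_skip;
    if (lines_id >= lines.size()) lines_id = 0;
  }
  if (!lines.empty()) {
    std::printf("First image: %s (label %d)\n",
                lines[lines_id].first.c_str(), lines[lines_id].second);
  }
  return 0;
}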
@@ -188,7 +226,7 @@ void InputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
memcpy((*top)[1]->mutable_cpu_data(), prefetch_label_->cpu_data(),
sizeof(Dtype) * prefetch_label_->count());
// Start a new prefetch thread
CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
CHECK(!pthread_create(&thread_, NULL, InputLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}

@@ -205,7 +243,7 @@ void InputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
prefetch_label_->cpu_data(), sizeof(Dtype) * prefetch_label_->count(),
cudaMemcpyHostToDevice));
// Start a new prefetch thread
CHECK(!pthread_create(&thread_, NULL, DataLayerPrefetch<Dtype>,
CHECK(!pthread_create(&thread_, NULL, InputLayerPrefetch<Dtype>,
reinterpret_cast<void*>(this))) << "Pthread execution failed.";
}
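
In the prefetch function further up, the old leveldb iterator (iter_->Next() / SeekToFirst()) is replaced by an index lines_id_ into the lines_ vector that wraps back to zero after the last image, while a failed ReadImageToDatum simply skips that batch slot. A rough sketch of the wrap-around cursor follows; the error handling is simplified and LoadImage is a made-up stand-in for ReadImageToDatum:

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Made-up stand-in for ReadImageToDatum: pretend empty filenames are unreadable.
static bool LoadImage(const std::string& filename) {
  return !filename.empty();
}

int main() {
  std::vector<std::pair<std::string, int> > lines;
  lines.push_back(std::make_pair(std::string("data/cat.jpg"), 0));
  lines.push_back(std::make_pair(std::string(""), 1));  // unreadable, skipped
  lines.push_back(std::make_pair(std::string("data/cat.jpg"), 2));

  const int batchsize = 5;  // stand-in for layer_param_.batchsize()
  const int lines_size = static_cast<int>(lines.size());
  int lines_id = 0;
  for (int itemid = 0; itemid < batchsize; ++itemid) {
    if (LoadImage(lines[lines_id].first)) {
      std::printf("item %d <- %s (label %d)\n", itemid,
                  lines[lines_id].first.c_str(), lines[lines_id].second);
    }
    // Advance the cursor and wrap around once the whole list has been used,
    // replacing the old iter_->Next() / SeekToFirst() pair.
    lines_id++;
    if (lines_id >= lines_size) {
      std::printf("Restarting data prefetching from start.\n");
      lines_id = 0;
    }
  }
  return 0;
}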

2 changes: 2 additions & 0 deletions src/caffe/proto/caffe.proto
@@ -91,6 +91,8 @@ message LayerParameter {
// point would be set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of keys in the leveldb.
optional uint32 rand_skip = 53 [ default = 0 ];

optional bool shuffle_data = 61 [default = true];
}

message LayerConnection {
