Skip to content

Commit

Permalink
Changed image preprocessing and CPU_ONLY macros
Browse files Browse the repository at this point in the history
Fixed api "std" parameter
Added api "mean" parameter
The 'get_gpu_ids' function is no longer defined when compiling with CPU_ONLY
  • Loading branch information
Julien CHICHA committed Jun 7, 2018
1 parent 5d8f9fa commit 40d7944
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/backends/caffe2/caffe2inputconns.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ namespace dd

int get_tensor_test(caffe2::TensorCPU &tensor, int num = -1);

float _std = 255.0f;
float _std = 1.0f;
};
}

Expand Down
5 changes: 4 additions & 1 deletion src/backends/caffe2/caffe2inputimg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ namespace dd {
void ImgCaffe2InputFileConn::init(const APIData &ad) {
ImgInputFileConn::init(ad);
if (ad.has("std"))
_std = ad.get("std").get<float>();
_std = ad.get("std").get<double>();
}

void ImgCaffe2InputFileConn::transform(const APIData &ad) {
Expand Down Expand Up @@ -74,6 +74,9 @@ namespace dd {

// Convert from NHWC uint8_t to NCHW float
it->convertTo(*it, CV_32F);
if (_has_mean_scalar) {
*it -= _mean;
}
cv::split(*it / _std, chan);
for (cv::Mat &ch : chan) {
std::memcpy(data, ch.data, channel_size);
Expand Down
25 changes: 13 additions & 12 deletions src/backends/caffe2/caffe2lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ namespace dd {
Caffe2Lib<TInputConnectorStrategy,TOutputConnectorStrategy,TMLModel>::~Caffe2Lib() {
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
std::vector<int> Caffe2Lib<TInputConnectorStrategy,TOutputConnectorStrategy,TMLModel>::
#ifdef CPU_ONLY
get_gpu_ids(const APIData &) const { return {}; }
#define UPDATE_GPU_STATE(ad)
#else

template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
std::vector<int> Caffe2Lib<TInputConnectorStrategy,TOutputConnectorStrategy,TMLModel>::
get_gpu_ids(const APIData &ad) const {
std::vector<int> ids;
try {
Expand All @@ -124,6 +125,14 @@ namespace dd {
}
return ids;
}

#define UPDATE_GPU_STATE(ad) \
if (ad.has("gpu")) { \
_state.set_is_gpu(ad.get("gpu").get<bool>()); \
if (_state.is_gpu() && ad.has("gpuid")) { \
_state.set_gpu_ids(get_gpu_ids(ad)); \
} \
}
#endif

template <class TInputConnectorStrategy, class TOutputConnectorStrategy, class TMLModel>
Expand Down Expand Up @@ -262,15 +271,7 @@ namespace dd {
}
}

// gpu
#ifndef CPU_ONLY
if (ad_mllib.has("gpu")) {
_state.set_is_gpu(ad_mllib.get("gpu").get<bool>());
if (_state.is_gpu() && ad_mllib.has("gpuid")) {
_state.set_gpu_ids(get_gpu_ids(ad_mllib));
}
}
#endif
UPDATE_GPU_STATE(ad_mllib);

if (_state.has_changed()) {
try {
Expand Down
2 changes: 2 additions & 0 deletions src/backends/caffe2/caffe2lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ namespace dd
void init_mllib(const APIData &ad);
void clear_mllib(const APIData &ad);

#ifndef CPU_ONLY
std::vector<int> get_gpu_ids(const APIData &ad) const;
#endif

int train(const APIData &ad, APIData &out);

Expand Down

0 comments on commit 40d7944

Please sign in to comment.