Skip to content

Commit

Permalink
Update blackbox predictor with new constructor (pytorch#10920)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#10920

Update the black box predictor and the related code to use the
constructor with PredictorConfig.

Reviewed By: highker

Differential Revision: D9516972

fbshipit-source-id: fbd7ece934d527e17dc6bcc740b4e67e778afa1d
  • Loading branch information
Yi Cheng authored and facebook-github-bot committed Aug 29, 2018
1 parent 56539f5 commit c99a143
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 13 deletions.
58 changes: 46 additions & 12 deletions caffe2/predictor/predictor_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace {
// We don't use the getNet() from predictor_utils.cc here because that file
// has additional dependencies that we want to avoid bringing in, to keep the
// binary size as small as possible.
const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
static const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
for (const auto& n : def.nets()) {
if (n.key() == name) {
return n.value();
Expand All @@ -19,7 +19,7 @@ const NetDef& getNet(const MetaNetDef& def, const std::string& name) {
CAFFE_THROW("Net not found: ", name);
}

const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs(
static const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs(
const MetaNetDef& def,
const std::string& name) {
for (const auto& b : def.blobs()) {
Expand All @@ -30,26 +30,60 @@ const ::google::protobuf::RepeatedPtrField<::std::string>& getBlobs(
CAFFE_THROW("Blob not found: ", name);
}

static std::string combine(const std::string& str, const std::string& name) {
if (name.empty()) {
return std::string(str);
}
return str + "_" + name;
}

static std::string getNamedPredictNet(const string& name) {
return combine(PredictorConsts::default_instance().predict_net_type(), name);
}

static std::string getNamedInitNet(const string& name) {
return combine(
PredictorConsts::default_instance().predict_init_net_type(), name);
}

static std::string getNamedInputs(const string& name) {
return combine(PredictorConsts::default_instance().inputs_blob_type(), name);
}

static std::string getNamedOutputs(const string& name) {
return combine(PredictorConsts::default_instance().outputs_blob_type(), name);
}

static std::string getNamedParams(const string& name) {
return combine(
PredictorConsts::default_instance().parameters_blob_type(), name);
}

} // namespace

PredictorConfig
makePredictorConfig(const MetaNetDef& def, Workspace* parent, bool run_init) {
const auto& init_net =
getNet(def, PredictorConsts::default_instance().global_init_net_type());
const auto& run_net =
getNet(def, PredictorConsts::default_instance().predict_net_type());
PredictorConfig makePredictorConfig(
const MetaNetDef& def,
Workspace* parent,
bool run_init,
const std::string& net_name) {
const auto& init_net = getNet(def, getNamedInitNet(net_name));
const auto& run_net = getNet(def, getNamedPredictNet(net_name));
auto config = makePredictorConfig(init_net, run_net, parent, run_init);
const auto& inputs =
getBlobs(def, PredictorConsts::default_instance().inputs_blob_type());
const auto& inputs = getBlobs(def, getNamedInputs(net_name));
for (const auto& input : inputs) {
config.input_names.emplace_back(input);
}

const auto& outputs =
getBlobs(def, PredictorConsts::default_instance().outputs_blob_type());
const auto& outputs = getBlobs(def, getNamedOutputs(net_name));
for (const auto& output : outputs) {
config.output_names.emplace_back(output);
}

const auto& params = getBlobs(def, getNamedParams(net_name));
for (const auto& param : params) {
config.parameter_names.emplace_back(param);
}

return config;
}

Expand Down
3 changes: 2 additions & 1 deletion caffe2/predictor/predictor_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ CAFFE2_API Workspace makeWorkspace(std::shared_ptr<PredictorParameters> paramete
CAFFE2_API PredictorConfig makePredictorConfig(
const MetaNetDef& net,
Workspace* parent = nullptr,
bool run_init = true);
bool run_init = true,
const std::string& net_name = "");

CAFFE2_API PredictorConfig makePredictorConfig(
const NetDef& init_net,
Expand Down
51 changes: 51 additions & 0 deletions caffe2/predictor/predictor_utils.cc
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
#include "caffe2/predictor/predictor_utils.h"
#include "caffe2/predictor/predictor_config.h"

#include "caffe2/core/blob.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/proto/predictor_consts.pb.h"
#include "caffe2/utils/proto_utils.h"

CAFFE2_DEFINE_bool(
caffe2_predictor_claim_tensor_memory,
true,
"If false, then predictor will not claim tensor memory"
"otherwise when tensor is shrinked to a size smaller than current size "
"by FLAGS_caffe2_max_keep_on_shrink_memory, the memory will be claimed.");

namespace caffe2 {
namespace predictor_utils {

Expand Down Expand Up @@ -79,4 +87,47 @@ std::unique_ptr<MetaNetDef> runGlobalInitialization(
}

} // namespace predictor_utils

void removeExternalBlobs(
const std::vector<std::string>& input_blobs,
const std::vector<std::string>& output_blobs,
Workspace* ws) {
for (const auto& blob : input_blobs) {
ws->RemoveBlob(blob);
}
for (const auto& blob : output_blobs) {
ws->RemoveBlob(blob);
}
}

PredictorConfig makePredictorConfig(
const string& db_type,
const string& db_path) {
// TODO: Remove this flags once Predictor accept PredictorConfig as
// constructors. These comes are copied temporarly from the Predictor.
if (FLAGS_caffe2_predictor_claim_tensor_memory) {
if (FLAGS_caffe2_max_keep_on_shrink_memory == LLONG_MAX) {
FLAGS_caffe2_max_keep_on_shrink_memory = 8 * 1024 * 1024;
}
}
auto dbReader =
make_unique<db::DBReader>(db::CreateDB(db_type, db_path, db::READ));
auto ws = std::make_shared<Workspace>();
auto net_def =
predictor_utils::runGlobalInitialization(std::move(dbReader), ws.get());
auto config = makePredictorConfig(*net_def, ws.get());
config.ws = ws;
const auto& init_net = predictor_utils::getNet(
*net_def, PredictorConsts::default_instance().predict_init_net_type());
CAFFE_ENFORCE(config.ws->RunNetOnce(init_net));
config.ws->RemoveBlob(
PredictorConsts::default_instance().predictor_dbreader());
// Input and output blobs should never be allocated in the master workspace
// since we'll end up with race-conditions due to these being shared among
// predictor threads / TL workspaces. Safely handle against globalInitNet
// creating them in the master.
removeExternalBlobs(config.input_names, config.output_names, config.ws.get());
return config;
}

} // namespace caffe2
10 changes: 10 additions & 0 deletions caffe2/predictor/predictor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,14 @@ CAFFE2_API std::unique_ptr<MetaNetDef> runGlobalInitialization(
Workspace* master);

} // namespace predictor_utils

PredictorConfig makePredictorConfig(
const string& db_type,
const string& db_path);

void removeExternalBlobs(
const std::vector<std::string>& input_blobs,
const std::vector<std::string>& output_blobs,
Workspace* ws);

} // namespace caffe2

0 comments on commit c99a143

Please sign in to comment.