Skip to content

Commit

Permalink
Improve control/replace API
Browse files Browse the repository at this point in the history
  • Loading branch information
gbaydin committed Mar 8, 2018
1 parent 8b946d7 commit 47d7e4f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
5 changes: 3 additions & 2 deletions src/cpproblight/include/cpproblight.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#define VERSION "0.1.0"
#define GIT_BRANCH "master"
#define GIT_COMMIT_HASH "5987556"
#define GIT_COMMIT_HASH "8b946d7"

namespace cpproblight
{
Expand Down Expand Up @@ -65,7 +65,8 @@ namespace cpproblight
xt::xarray<double> sample(distributions::Distribution& distribution, const bool control, const bool replace, const std::string& address);
void observe(distributions::Distribution& distribution, xt::xarray<double> value, const std::string& address="");

void setDefault(bool control = true, bool replace = false);
void setDefaultControl(bool control = true);
void setDefaultReplace(bool replace = false);

xt::xarray<double> ProtocolTensorToXTensor(const PPLProtocol::ProtocolTensor* protocolTensor);

Expand Down
3 changes: 2 additions & 1 deletion src/cpproblight/include/cpproblight.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ namespace cpproblight
xt::xarray<double> sample(distributions::Distribution& distribution, const bool control, const bool replace, const std::string& address);
void observe(distributions::Distribution& distribution, xt::xarray<double> value, const std::string& address="");

void setDefault(bool control = true, bool replace = false);
void setDefaultControl(bool control = true);
void setDefaultReplace(bool replace = false);

xt::xarray<double> ProtocolTensorToXTensor(const PPLProtocol::ProtocolTensor* protocolTensor);

Expand Down
18 changes: 11 additions & 7 deletions src/cpproblight/src/cpproblight.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ namespace cpproblight
zmq::socket_t zmqSocket = zmq::socket_t(zmqContext, ZMQ_REP);
bool zmqSocketConnected = false;
flatbuffers::FlatBufferBuilder builder;
bool controlDefault = true;
bool replaceDefault = false;
bool defaultControl = true;
bool defaultReplace = false;

namespace distributions
{
Expand Down Expand Up @@ -204,12 +204,12 @@ namespace cpproblight
xt::xarray<double> sample(distributions::Distribution& distribution)
{
auto address = extractAddress();
return distribution.sample(controlDefault, replaceDefault, address);
return distribution.sample(defaultControl, defaultReplace, address);
}

xt::xarray<double> sample(distributions::Distribution& distribution, const std::string& address)
{
return distribution.sample(controlDefault, replaceDefault, address);
return distribution.sample(defaultControl, defaultReplace, address);
}

xt::xarray<double> sample(distributions::Distribution& distribution, const bool control, const bool replace)
Expand All @@ -233,10 +233,14 @@ namespace cpproblight
return distribution.observe(value, addr);
}

void setDefault(bool control, bool replace)
void setDefaultControl(bool control)
{
controlDefault = control;
replaceDefault = replace;
defaultControl = control;
}

void setDefaultReplace(bool replace)
{
defaultReplace = replace;
}

xt::xarray<double> ProtocolTensorToXTensor(const PPLProtocol::ProtocolTensor* protocolTensor)
Expand Down
9 changes: 6 additions & 3 deletions src/cpproblight/test/test_set_defaults_and_addresses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@ xt::xarray<double> forward(xt::xarray<double> observation)
auto prior_stddev = std::sqrt(5);
auto likelihood_stddev = std::sqrt(2);

cpproblight::setDefault(true, false); // control=true, replace=false
cpproblight::setDefaultControl(true);
cpproblight::setDefaultReplace(false);

auto normal1 = cpproblight::distributions::Normal(prior_mean, prior_stddev);
auto mu1 = cpproblight::sample(normal1, "normal1");
mu1 = cpproblight::sample(normal1, "normal1");

cpproblight::setDefault(true, true); // control=true, replace=true
cpproblight::setDefaultControl(true);
cpproblight::setDefaultReplace(true);

auto normal2 = cpproblight::distributions::Normal(mu1, prior_stddev);
auto mu2 = cpproblight::sample(normal2, "normal2");
mu2 = cpproblight::sample(normal2, "normal2");

cpproblight::setDefault(false, false); // control=false, replace=false
cpproblight::setDefaultControl(false);
cpproblight::setDefaultReplace(false);

auto normal3 = cpproblight::distributions::Normal(mu2, prior_stddev);
auto mu3 = cpproblight::sample(normal3, "normal3");
Expand Down

0 comments on commit 47d7e4f

Please sign in to comment.