Skip to content

Commit

Permalink
minor code simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
kthohr committed Oct 3, 2017
1 parent c0ab66f commit c010b58
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 59 deletions.
7 changes: 4 additions & 3 deletions src/de.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,12 @@ mcmc::de_int(const arma::vec& initial_vals, arma::cube& draws_out, std::function
prop_kernel_val = BIG_NEG_VAL;
}

double comp_val = prop_kernel_val - target_vals(i);

//

double comp_val = prop_kernel_val - target_vals(i);
double z = arma::as_scalar(arma::randu(1,1));

if (comp_val > temperature_j * std::log(arma::as_scalar(arma::randu(1)))) {
if (comp_val > temperature_j * std::log(z)) {
X.row(i) = X_prop;

target_vals(i) = prop_kernel_val;
Expand Down
27 changes: 7 additions & 20 deletions src/hmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function

//

int n_accept = 0;
double comp_val, rand_val;
int n_accept = 0;
arma::vec krand(n_vals);

for (int jj = 0; jj < n_draws_keep + n_draws_burnin; jj++) {
Expand Down Expand Up @@ -151,9 +150,10 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function

//

comp_val = - prop_U - prop_K + prev_U + prev_K;

if (comp_val > 0.0) { // the '> exp(0)' case; works around taking exp of big values and receiving an error
double comp_val = std::min(0.0,- prop_U - prop_K + prev_U + prev_K);
double z = arma::as_scalar(arma::randu(1));

if (z < std::exp(comp_val)) {
prev_draw = new_draw;
prev_U = prop_U;
prev_K = prop_K;
Expand All @@ -163,21 +163,8 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function
n_accept++;
}
} else {
rand_val = arma::as_scalar(arma::randu(1));

if (rand_val < std::exp(comp_val)) {
prev_draw = new_draw;
prev_U = prop_U;
prev_K = prop_K;

if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = new_draw.t();
n_accept++;
}
} else {
if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = prev_draw.t();
}
if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = prev_draw.t();
}
}
}
Expand Down
24 changes: 6 additions & 18 deletions src/mala.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ mcmc::mala_int(const arma::vec& initial_vals, arma::mat& draws_out, std::functio
//

int n_accept = 0;
double comp_val, rand_val;
arma::vec krand(n_vals);

for (int jj = 0; jj < n_draws_keep + n_draws_burnin; jj++) {
Expand All @@ -132,9 +131,10 @@ mcmc::mala_int(const arma::vec& initial_vals, arma::mat& draws_out, std::functio

//

comp_val = prop_LP - prev_LP + mala_prop_adjustment(new_draw, prev_draw, step_size, vals_bound, mala_mean_fn, target_data);

if (comp_val > 0.0) { // the '> exp(0)' case; works around taking exp of big values and receiving an error
double comp_val = std::min(0.0, prop_LP - prev_LP + mala_prop_adjustment(new_draw, prev_draw, step_size, vals_bound, mala_mean_fn, target_data));
double z = arma::as_scalar(arma::randu(1));

if (z < std::exp(comp_val)) {
prev_draw = new_draw;
prev_LP = prop_LP;

Expand All @@ -143,20 +143,8 @@ mcmc::mala_int(const arma::vec& initial_vals, arma::mat& draws_out, std::functio
n_accept++;
}
} else {
rand_val = arma::as_scalar(arma::randu(1));

if (rand_val < std::exp(comp_val)) {
prev_draw = new_draw;
prev_LP = prop_LP;

if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = new_draw.t();
n_accept++;
}
} else {
if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = prev_draw.t();
}
if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = prev_draw.t();
}
}
}
Expand Down
24 changes: 6 additions & 18 deletions src/rwmh.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ mcmc::rwmh_int(const arma::vec& initial_vals, arma::mat& draws_out, std::functio
//

int n_accept = 0;
double comp_val, rand_val;
arma::vec krand(n_vals);

for (int jj = 0; jj < n_draws_keep + n_draws_burnin; jj++) {
Expand All @@ -104,9 +103,10 @@ mcmc::rwmh_int(const arma::vec& initial_vals, arma::mat& draws_out, std::functio

//

comp_val = prop_LP - prev_LP;

if (comp_val > 0.0) { // the '> exp(0)' case; works around taking exp of big values and receiving an error
double comp_val = std::min(0.0,prop_LP - prev_LP);
double z = arma::as_scalar(arma::randu(1));

if (z < std::exp(comp_val)) {
prev_draw = new_draw;
prev_LP = prop_LP;

Expand All @@ -115,20 +115,8 @@ mcmc::rwmh_int(const arma::vec& initial_vals, arma::mat& draws_out, std::functio
n_accept++;
}
} else {
rand_val = arma::as_scalar(arma::randu(1));

if (rand_val < std::exp(comp_val)) {
prev_draw = new_draw;
prev_LP = prop_LP;

if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = new_draw.t();
n_accept++;
}
} else {
if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = prev_draw.t();
}
if (jj >= n_draws_burnin) {
draws_out.row(jj - n_draws_burnin) = prev_draw.t();
}
}
}
Expand Down

0 comments on commit c010b58

Please sign in to comment.