Skip to content

Commit

Permalink
update HMC with preconditioning
Browse files Browse the repository at this point in the history
  • Loading branch information
kthohr committed Oct 23, 2017
1 parent 2eaf5a3 commit 64fd61f
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/hmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
################################################################################*/

/*
* Hamiltonian Monte Carlo
* Hamiltonian Monte Carlo (HMC)
*/

#include "mcmc.hpp"
Expand All @@ -31,7 +31,7 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function
const int n_vals = initial_vals.n_elem;

//
// MALA settings
// HMC settings

algo_settings settings;

Expand All @@ -44,6 +44,10 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function

const double step_size = settings.hmc_step_size;

const arma::mat precond_matrix = (settings.hmc_precond_mat.n_elem == n_vals*n_vals) ? settings.hmc_precond_mat : arma::eye(n_vals,n_vals);
const arma::mat inv_precond_matrix = arma::inv(precond_matrix);
const arma::mat sqrt_precond_matrix = arma::chol(precond_matrix,"lower");

const bool vals_bound = settings.vals_bound;

const arma::vec lower_bounds = settings.lower_bounds;
Expand Down Expand Up @@ -118,7 +122,7 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function
arma::vec prev_draw = first_draw;
arma::vec new_draw = first_draw;

arma::vec new_mntm = arma::randn(n_vals,1);
arma::vec new_mntm = arma::randn(n_vals,1);

//

Expand All @@ -127,14 +131,14 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function

for (int jj = 0; jj < n_draws_keep + n_draws_burnin; jj++) {

krand.randn();
krand = sqrt_precond_matrix*arma::randn(n_vals,1);
prev_K = arma::dot(krand,inv_precond_matrix*krand) / 2.0;

new_mntm = mntm_update_fn(prev_draw,krand,target_data,step_size,nullptr); // half-step
prev_K = arma::dot(krand,krand) / 2.0;

//

new_draw = prev_draw + step_size*new_mntm;
new_draw = prev_draw + step_size*inv_precond_matrix*new_mntm;

prop_U = - box_log_kernel(new_draw, nullptr, target_data);

Expand All @@ -146,7 +150,7 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function

new_mntm = mntm_update_fn(new_draw,new_mntm,target_data,step_size,nullptr); // half-step

prop_K = arma::dot(new_mntm,new_mntm) / 2.0;
prop_K = arma::dot(new_mntm,inv_precond_matrix*new_mntm) / 2.0;

//

Expand Down Expand Up @@ -183,7 +187,7 @@ mcmc::hmc_int(const arma::vec& initial_vals, arma::mat& draws_out, std::function
}

if (settings_inp) {
settings_inp->hmc_accept_rate = (double) n_accept / (double) n_draws_keep;
settings_inp->hmc_accept_rate = static_cast<double>(n_accept) / static_cast<double>(n_draws_keep);
}

//
Expand Down

0 comments on commit 64fd61f

Please sign in to comment.