Skip to content

Commit

Permalink
interim
Browse files Browse the repository at this point in the history
  • Loading branch information
Moritz Freidank committed Oct 19, 2017
1 parent 7b88b1f commit d47192d
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions docs/source/experiments/compute_ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
sys.path.insert(0, path_join(SCRIPT_PATH, "..", "..", ".."))

from pysgmcmc.samplers.relativistic_sghmc import RelativisticSGHMCSampler
from pysgmcmc.samplers.sghmc import SGHMCSampler
from pysgmcmc.samplers.sgld import SGLDSampler

from pysgmcmc.diagnostics.sample_chains import PYSGMCMCTrace
from pysgmcmc.diagnostics.objective_functions import (
banana_log_likelihood,
Expand All @@ -26,7 +29,7 @@
def main():
parser = argparse.ArgumentParser(
description="Small script to study the relationship between stepsize "
"of relativistic sghmc and effective sample sizes (ESS) on "
"of a sampler and effective sample sizes (ESS) on "
"on four different benchmarks."
)

Expand All @@ -37,6 +40,13 @@ def main():
"For reference, see: http://proceedings.mlr.press/v54/lu17b/lu17b.pdf.",
)

parser.add_argument(
"--sampler",
help="Sampler to study.",
default="RelativisticSGHMC",
action="store", dest="sampler"
)

parser.add_argument(
"--n-iterations",
help="Number of total iterations to perform for each stepsize",
Expand Down Expand Up @@ -177,6 +187,16 @@ def extract_samples(sampler, n_samples=1000, keep_every=10):
assert args.stepsize_min >= 0.0, "--stepsize-min must be >= 0.0"
assert args.stepsize_step > 0, "--stepsize-increment must be > 0.0"

samplers = {
"RelativisticSGHMC": RelativisticSGHMCSampler,
"SGHMC": SGHMCSampler,
"SGLD": SGLDSampler,
}

assert args.sampler in samplers

sampler_fun = samplers[args.sampler]

if args.stepsize is None:
stepsizes = np.arange(
args.stepsize_min, args.stepsize_max, args.stepsize_step
Expand All @@ -202,7 +222,7 @@ def extract_samples(sampler, n_samples=1000, keep_every=10):
params = [tf.Variable(0., dtype=tf.float32, name="x")]
varnames = ["x"]

sampler = RelativisticSGHMCSampler(
sampler = sampler_fun(
epsilon=stepsize,
params=params,
cost_fun=cost_function(function),
Expand Down

0 comments on commit d47192d

Please sign in to comment.