Make sample_size argument mandatory in monte_carlo_variational_loss.
The current behavior is actually broken -- it defaults to None, but None is not a valid sample size. This was an oversight in the previous CL.

PiperOrigin-RevId: 257815557
davmre authored and tensorflower-gardener committed Jul 12, 2019
1 parent 4fe34ae commit 9b90a00
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions tensorflow_probability/python/vi/csiszar_divergence.py
@@ -790,8 +790,8 @@ def symmetrized_csiszar_function(logu, csiszar_function, name=None):
 
 def monte_carlo_variational_loss(target_log_prob_fn,
                                  surrogate_posterior,
+                                 sample_size=1,
                                  discrepancy_fn=kl_reverse,
-                                 sample_size=None,
                                  use_reparametrization=None,
                                  seed=None,
                                  name=None):
@@ -831,12 +831,14 @@ def monte_carlo_variational_loss(target_log_prob_fn,
       this is to use `tfp.util.DeferredTensor` to represent any parameters
       defined as transformations of unconstrained variables, so that the
       transformations execute at runtime instead of at distribution creation.
+    sample_size: Integer scalar number of Monte Carlo samples used to
+      approximate the variational divergence. Larger values may stabilize
+      the optimization, but at higher cost per step in time and memory.
+      Default value: `1`.
     discrepancy_fn: Python `callable` representing a Csiszar `f` function
       in log-space. That is, `discrepancy_fn(log(u)) = f(u)`, where `f` is
       convex in `u`.
       Default value: `tfp.vi.kl_reverse`.
-    sample_size: Integer scalar number of Monte Carlo samples used to
-      approximate the variational divergence.
     use_reparametrization: Python `bool`. When `None` (the default),
       automatically set to:
       `surrogate_posterior.reparameterization_type ==
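The docstring above describes `sample_size` as the number of Monte Carlo samples averaged to approximate the variational divergence, and the commit message notes that `None` was never a valid value. As a rough, self-contained illustration of that estimator shape (a hypothetical pure-Python sketch of a reverse-KL estimate, not the TFP implementation):

```python
import math
import random

def monte_carlo_reverse_kl(target_log_prob_fn, sample_fn, surrogate_log_prob_fn,
                           sample_size=1, seed=None):
    """Sketch of a reverse-KL Monte Carlo estimate (hypothetical helper).

    Averages log q(z) - log p(z) over `sample_size` draws z ~ q.
    """
    if sample_size is None:
        # This is the failure mode the commit fixes: None is not a
        # usable sample count, so the default becomes an explicit 1.
        raise TypeError("sample_size must be a positive integer, not None")
    rng = random.Random(seed)
    total = 0.0
    for _ in range(sample_size):
        z = sample_fn(rng)
        total += surrogate_log_prob_fn(z) - target_log_prob_fn(z)
    return total / sample_size

# Toy example: surrogate q = N(0, 1), target p = N(1, 1).
def unit_normal_log_prob(z, mean=0.0):
    return -0.5 * ((z - mean) ** 2 + math.log(2 * math.pi))

loss = monte_carlo_reverse_kl(
    target_log_prob_fn=lambda z: unit_normal_log_prob(z, mean=1.0),
    sample_fn=lambda rng: rng.gauss(0.0, 1.0),
    surrogate_log_prob_fn=unit_normal_log_prob,
    sample_size=100,
    seed=0)
# The true KL(q || p) here is 0.5; the estimate lands in that vicinity.
```

With `sample_size=1` (the new default) the same call still works, just with a noisier single-sample estimate.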
4 changes: 2 additions & 2 deletions tensorflow_probability/python/vi/optimization.py
@@ -40,7 +40,7 @@ def fit_surrogate_posterior(target_log_prob_fn,
                             num_steps,
                             trace_fn=_trace_loss,
                             variational_loss_fn=_reparameterized_elbo,
-                            sample_size=10,
+                            sample_size=1,
                             trainable_variables=None,
                             seed=None,
                             name='fit_surrogate_posterior'):
@@ -108,7 +108,7 @@ def fit_surrogate_posterior(target_log_prob_fn,
     sample_size: Python `int` number of Monte Carlo samples to use
       in estimating the variational divergence. Larger values may stabilize
       the optimization, but at higher cost per step in time and memory.
-      Default value: 10.
+      Default value: `1`.
     trainable_variables: Optional list of `tf.Variable` instances to optimize
       with respect to. If `None`, defaults to the set of all variables accessed
       during the computation of the variational bound, i.e., those defining
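The docstring's claim that larger `sample_size` values "may stabilize the optimization" is a standard Monte Carlo variance argument: averaging more samples per step shrinks the spread of the loss estimate by roughly the square root of the sample count. A toy stand-in (not the TFP estimator) makes this concrete:

```python
import random
import statistics

def noisy_loss(sample_size, rng):
    """Toy stand-in for a Monte Carlo variational loss: averages
    `sample_size` noisy draws around a true loss of 0.5."""
    return sum(0.5 + rng.gauss(0.0, 1.0) for _ in range(sample_size)) / sample_size

rng = random.Random(42)
# Empirical spread of the per-step loss estimate at two sample sizes.
spread_1 = statistics.stdev(noisy_loss(1, rng) for _ in range(200))
spread_100 = statistics.stdev(noisy_loss(100, rng) for _ in range(200))
# spread_100 is roughly 10x smaller than spread_1 (sqrt(100) reduction),
# at 100x the per-step sampling cost.
```

This is the trade-off behind the default change: `sample_size=1` is the cheapest valid choice, and callers who want smoother loss curves can raise it explicitly.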