Skip to content

Commit

Permalink
Minor fixes to sharded distributions and MCMC kernel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 396430940
  • Loading branch information
sharadmv authored and tensorflower-gardener committed Sep 13, 2021
1 parent bd917e0 commit f06522a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def _sample_n(self, n, seed, **kwargs):
seed, self.experimental_shard_axis_names)
return self.distribution.sample(sample_shape=n, seed=seed, **kwargs)

def _variance(self):
return self.distribution.variance()

_log_prob = _implement_sharded_lp_fn('log_prob')
_unnormalized_log_prob = _implement_sharded_lp_fn('unnormalized_log_prob')

Expand Down
4 changes: 4 additions & 0 deletions tensorflow_probability/python/experimental/mcmc/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def one_step(self, current_state, previous_kernel_results, seed=None):
return self.inner_kernel.one_step(
current_state, previous_kernel_results, seed=seed)

@property
def parameters(self):
return self._parameters

@property
def inner_kernel(self):
return self._parameters['inner_kernel']
Expand Down

0 comments on commit f06522a

Please sign in to comment.