From ed8edb8f3cb72a056463921fee8cff5bd3f65a96 Mon Sep 17 00:00:00 2001
From: Dave Moore
Date: Thu, 9 Apr 2020 11:49:58 -0700
Subject: [PATCH] Add tests for SampleParticles and fix some issues around batch shape.

PiperOrigin-RevId: 305728146
---
 .../experimental/mcmc/particle_filter.py      | 201 ++++++++++--------
 .../experimental/mcmc/particle_filter_test.py |  89 +++++++-
 2 files changed, 187 insertions(+), 103 deletions(-)

diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter.py b/tensorflow_probability/python/experimental/mcmc/particle_filter.py
index 424ec0a7a4..631198c898 100644
--- a/tensorflow_probability/python/experimental/mcmc/particle_filter.py
+++ b/tensorflow_probability/python/experimental/mcmc/particle_filter.py
@@ -17,6 +17,7 @@
 import collections
 import functools
+import numpy as np
 import tensorflow.compat.v2 as tf
 
 from tensorflow_probability.python.distributions import categorical
 from tensorflow_probability.python.distributions import distribution as distribution_lib
@@ -36,7 +37,7 @@
 ]
 
 
-# TODO(davmre): add unit tests for SampleParticles.
+# TODO(b/153467570): Move SampleParticles into `tfp.distributions`.
 class SampleParticles(distribution_lib.Distribution):
-  """Like tfd.Sample, but inserts new rightmost batch (vs event) dim."""
+  """Like tfd.Sample, but inserts a new leftmost batch (vs event) dim."""
 
@@ -70,38 +71,78 @@ def _event_shape_tensor(self, **kwargs):
 
   def _batch_shape(self):
     return tf.nest.map_structure(
-        lambda b: tensorshape_util.concatenate(b, [self.num_particles]),
+        lambda b: tensorshape_util.concatenate(  # pylint: disable=g-long-lambda
+            [tf.get_static_value(self.num_particles)], b),
         self.distribution.batch_shape)
 
   def _batch_shape_tensor(self, **kwargs):
     return tf.nest.map_structure(
-        lambda b: prefer_static.concat([b, [self.num_particles]], axis=0),
+        lambda b: prefer_static.concat([[self.num_particles], b], axis=0),
         self.distribution.batch_shape_tensor(**kwargs))
 
   def _log_prob(self, x, **kwargs):
-    return self.distribution.log_prob(x, **kwargs)
+    return self._call_log_measure('_log_prob', x, kwargs)
 
   def _log_cdf(self, x, **kwargs):
-    return self.distribution.log_cdf(x, **kwargs)
+    return self._call_log_measure('_log_cdf', x, kwargs)
 
   def _log_sf(self, x, **kwargs):
-    return self.distribution.log_sf(x, **kwargs)
+    return self._call_log_measure('_log_sf', x, kwargs)
+
+  def _call_log_measure(self, attr, x, kwargs):
+    return getattr(self.distribution, attr)(x, **kwargs)
 
   # TODO(b/152797117): Override _sample_n, once it supports joint distributions.
   def sample(self, sample_shape=(), seed=None, name=None):
     with tf.name_scope(name or 'sample_particles'):
       sample_shape = prefer_static.concat([
-          [self.num_particles],
-          dist_util.expand_to_vector(sample_shape)], axis=0)
-      x = self.distribution.sample(sample_shape, seed=seed)
+          dist_util.expand_to_vector(sample_shape),
+          [self.num_particles]], axis=0)
+      return self.distribution.sample(sample_shape, seed=seed)
+
+
+# TODO(davmre): Replace this hack with a more efficient TF builtin.
+def _batch_gather(params, indices, axis=0):
+  """Gathers a batch of indices from `params` along the given axis.
 
-    def move_particles_to_rightmost_batch_dim(x, event_shape):
-      ndims = prefer_static.rank_from_shape(prefer_static.shape(x))
-      event_ndims = prefer_static.rank_from_shape(event_shape)
-      return dist_util.move_dimension(x, 0, ndims - event_ndims - 1)
-    return tf.nest.map_structure(
-        move_particles_to_rightmost_batch_dim,
-        x, self.distribution.event_shape_tensor())
+  Args:
+    params: `Tensor` of shape `[d[0], d[1], ..., d[N - 1]]`.
+    indices: int `Tensor` of shape broadcastable to that of `params`.
+    axis: int `Tensor` dimension of `params` (and of the broadcast indices) to
+      gather over.
+  Returns:
+    result: `Tensor` of the same type and shape as `params`.
+  """
+  params_rank = prefer_static.rank_from_shape(prefer_static.shape(params))
+  indices_rank = prefer_static.rank_from_shape(prefer_static.shape(indices))
+  params_with_axis_on_right = dist_util.move_dimension(
+      params, source_idx=axis, dest_idx=-1)
+  indices_with_axis_on_right = prefer_static.broadcast_to(
+      dist_util.move_dimension(indices,
+                               source_idx=axis - (params_rank - indices_rank),
+                               dest_idx=-1),
+      prefer_static.shape(params_with_axis_on_right))
+
+  result = tf.gather(params_with_axis_on_right,
+                     indices_with_axis_on_right,
+                     axis=params_rank - 1,
+                     batch_dims=params_rank - 1)
+  return dist_util.move_dimension(result, source_idx=-1, dest_idx=axis)
+
+
+def _dummy_indices_like(indices):
+  """Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
+  indices_shape = prefer_static.shape(indices)
+  num_particles = indices_shape[0]
+  return tf.broadcast_to(
+      prefer_static.reshape(
+          prefer_static.range(num_particles),
+          prefer_static.concat([[num_particles],
+                                prefer_static.ones([
+                                    prefer_static.rank_from_shape(
+                                        indices_shape) - 1], dtype=np.int32)],
+                               axis=0)),
+      indices_shape)
 
 
 def _gather_history(structure, step, num_steps):
@@ -115,9 +156,9 @@
 def ess_below_threshold(unnormalized_log_weights, threshold=0.5):
   """Determines if the effective sample size is much less than num_particles."""
   with tf.name_scope('ess_below_threshold'):
-    num_particles = prefer_static.shape(unnormalized_log_weights)[-1]
-    log_weights = tf.math.log_softmax(unnormalized_log_weights, axis=-1)
-    log_ess = -tf.math.reduce_logsumexp(2 * log_weights, axis=-1)
+    num_particles = prefer_static.shape(unnormalized_log_weights)[0]
+    log_weights = tf.math.log_softmax(unnormalized_log_weights, axis=0)
+    log_ess = -tf.math.reduce_logsumexp(2 * log_weights, axis=0)
     return log_ess < (prefer_static.log(num_particles) +
                       prefer_static.log(threshold))
@@ -253,7 +294,7 @@ def infer_trajectories(observations,
   Returns:
     trajectories: a (structure of) Tensor(s) matching the latent state, each
       of shape
-      `concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
+      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing unbiased samples from the posterior distribution
      `p(latent_states | observations)`.
    step_log_marginal_likelihoods: float `Tensor` of shape
@@ -369,18 +410,13 @@ def observation_fn(_, state):
   weighted_trajectories = reconstruct_trajectories(particles, parent_indices)
 
   # Resample all steps of the trajectories using the final weights.
-  final_log_weights = log_weights[-1, ...]
-  batch_rank = prefer_static.rank_from_shape(
-      prefer_static.shape(final_log_weights)[:-1])
-  final_indices = dist_util.move_dimension(
-      categorical.Categorical(final_log_weights).sample(
-          num_particles, seed=seed), source_idx=0, dest_idx=-1)
-  final_indices_tiled_over_time = tf.broadcast_to(
-      final_indices, prefer_static.shape(parent_indices))
+  resample_indices = categorical.Categorical(
+      dist_util.move_dimension(
+          log_weights[-1, ...],
+          source_idx=0,
+          dest_idx=-1)).sample(num_particles, seed=seed)
   trajectories = tf.nest.map_structure(
-      lambda x: tf.gather(  # pylint: disable=g-long-lambda
-          x, final_indices_tiled_over_time,
-          axis=batch_rank + 1, batch_dims=batch_rank + 1),
+      lambda x: _batch_gather(x, resample_indices, axis=1),
       weighted_trajectories)
 
   return trajectories, step_log_marginal_likelihoods
@@ -416,18 +452,18 @@ def particle_filter(observations,
   Returns:
     particles: a (structure of) Tensor(s) matching the latent state, each
       of shape
-      `concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
+      `concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
      representing (possibly weighted) samples from the series of filtering
      distributions `p(latent_states[t] | observations[:t])`.
    log_weights: `float` `Tensor` of shape
-      `[num_timesteps, b1, ..., bN, num_particles]`, such that
+      `[num_timesteps, num_particles, b1, ..., bN]`, such that
      `log_weights[t, :]` are the logarithms of normalized importance weights
-      (such that `exp(reduce_logsumexp(log_weights), axis=-1) == 1.`) of the
+      (such that `exp(reduce_logsumexp(log_weights, axis=1)) == 1.`) of the
      particles at time `t`. These may be used in conjunction with
      `particles` to compute expectations under the series of filtering
      distributions.
    parent_indices: `int` `Tensor` of shape
-      `[num_timesteps, b1, ..., bN, num_particles]`,
+      `[num_timesteps, num_particles, b1, ..., bN]`,
      such that `parent_indices[t, k]` gives the index of the particle at
      time `t - 1` that the `k`th particle at time `t` is immediately
      descended from. See also
@@ -460,9 +496,18 @@ def particle_filter(observations,
       None if initial_state_proposal is None
       else lambda _1, _2: SampleParticles(  # pylint: disable=g-long-lambda
           initial_state_proposal, num_particles))
-  log_uniform_weights = (
-      prefer_static.zeros([num_particles], dtype=tf.float32) -
-      prefer_static.log(num_particles))
+
+  # Initially the particles all have the same weight, `1. / num_particles`.
+  broadcast_batch_shape = tf.convert_to_tensor(
+      functools.reduce(
+          prefer_static.broadcast_shape,
+          tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
+          []), dtype=tf.int32)
+  log_uniform_weights = prefer_static.zeros(
+      prefer_static.concat([
+          [num_particles],
+          broadcast_batch_shape], axis=0),
+      dtype=tf.float32) - prefer_static.log(num_particles)
 
   # Initialize from the prior, and incorporate the first observation.
   initial_step_results = _filter_one_step(
@@ -573,19 +618,12 @@ def _update_loop_variables(step,
   history_is_empty = (tf.is_tensor(state_history) and
                       state_history.shape[0] == 0)
   if not history_is_empty:
-    batch_shape = prefer_static.shape(
-        current_step_results.parent_indices)[1:-1]
-    batch_rank = prefer_static.rank_from_shape(batch_shape)
-
     # Permute the particles from previous steps to match the current resampled
     # indices, so that the state history reflects coherent trajectories.
-    def update_state_history_to_use_current_indices(x):
-      return tf.gather(x[-1:],
-                       current_step_results.parent_indices,
-                       axis=batch_rank + 1,
-                       batch_dims=batch_rank)
     resampled_state_history = tf.nest.map_structure(
-        update_state_history_to_use_current_indices,
+        lambda x: _batch_gather(x[1:],  # pylint: disable=g-long-lambda
+                                current_step_results.parent_indices,
+                                axis=1),
         state_history)
 
     # Update the history by concat'ing the carried-forward elements with the
@@ -614,7 +652,7 @@ def _filter_one_step(step,
   """Advances the particle filter by a single time step."""
   with tf.name_scope('filter_one_step'):
     seed = SeedStream(seed, 'filter_one_step')
-    num_particles = prefer_static.shape(log_weights)[-1]
+    num_particles = prefer_static.shape(log_weights)[0]
 
     proposed_particles, proposal_log_weights = _propose_with_log_weights(
         step=step - 1,
@@ -629,14 +667,11 @@ def _filter_one_step(step,
         proposal_log_weights + observation_log_weights)
 
     step_log_marginal_likelihood = tf.math.reduce_logsumexp(
-        unnormalized_log_weights, axis=-1)
-    log_weights = (unnormalized_log_weights -
-                   step_log_marginal_likelihood[..., tf.newaxis])
+        unnormalized_log_weights, axis=0)
+    log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)
 
-    # Adaptive resampling: resample particles iff the specified criterion.
-    do_resample = tf.convert_to_tensor(
-        resample_criterion_fn(unnormalized_log_weights)
-        )[..., tf.newaxis]  # Broadcast over particles.
+    # Adaptive resampling: resample iff the specified criterion is met.
+    do_resample = resample_criterion_fn(unnormalized_log_weights)
 
     # Some batch elements may require resampling and others not, so
     # we first do the resampling for all elements, then select whether to use
@@ -647,9 +682,7 @@
     # for statistical (not computational) purposes, so this isn't a dealbreaker.
     resampled_particles, resample_indices = _resample(
         proposed_particles, log_weights, seed=seed)
-    dummy_indices = tf.broadcast_to(
-        prefer_static.range(num_particles),
-        prefer_static.shape(resample_indices))
+
     uniform_weights = (prefer_static.zeros_like(log_weights) -
                        prefer_static.log(num_particles))
     (resampled_particles,
      resample_indices,
      log_weights) = tf.nest.map_structure(
         lambda r, p: prefer_static.where(do_resample, r, p),
         (resampled_particles, resample_indices, uniform_weights),
-         (proposed_particles, dummy_indices, log_weights))
+         (proposed_particles,
+          _dummy_indices_like(resample_indices),
+          log_weights))
 
     return ParticleFilterStepResults(
         particles=resampled_particles,
@@ -713,26 +748,18 @@ def _compute_observation_log_weights(step,
   Args:
     step: int `Tensor` current step.
     particles: Nested structure of `Tensor`s, each of shape
-      `concat([[b1, ..., bN], [num_particles], latent_part_event_shape])`, where
+      `concat([[num_particles, b1, ..., bN], latent_part_event_shape])`, where
      `b1, ..., bN` are optional batch dimensions.
    observation: Nested structure of `Tensor`s, each of shape
      `concat([[b1, ..., bN], observation_part_event_shape])` where
      `b1, ..., bN` are optional batch dimensions.
    observation_fn: callable, producing a distribution over `observation`s.
  Returns:
-    log_weights: `Tensor` of shape `concat([[b1, ..., bN], [num_particles]])`.
+    log_weights: `Tensor` of shape `[num_particles, b1, ..., bN]`.
""" with tf.name_scope('compute_observation_log_weights'): observation_dist = observation_fn(step, particles) - def _add_right_batch_dim(obs, event_shape): - ndims = prefer_static.rank_from_shape(prefer_static.shape(obs)) - event_ndims = prefer_static.rank_from_shape(event_shape) - return tf.expand_dims(obs, ndims - event_ndims) - observation_broadcast_over_particles = tf.nest.map_structure( - _add_right_batch_dim, - observation, - observation_dist.event_shape_tensor()) - return observation_dist.log_prob(observation_broadcast_over_particles) + return observation_dist.log_prob(observation) def _resample(particles, log_weights, seed=None): @@ -740,26 +767,24 @@ def _resample(particles, log_weights, seed=None): Args: particles: Nested structure of `Tensor`s each of shape - `[b1, ..., bN, num_particles, ...]`, where + `[num_particles, b1, ..., bN, ...]`, where `b1, ..., bN` are optional batch dimensions. - log_weights: float `Tensor` of shape `[b1, ..., bN, num_particles]`, where + log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`, where `b1, ..., bN` are optional batch dimensions. seed: Python `int` random seed. Returns: resampled_particles: Nested structure of `Tensor`s, matching `particles`. - resample_indices: int `Tensor` of shape `[b1, ..., bN, num_particles]`. + resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`. """ with tf.name_scope('resample'): - weights_shape = prefer_static.shape(log_weights) - batch_shape, num_particles = weights_shape[:-1], weights_shape[-1] - batch_rank = prefer_static.rank_from_shape(batch_shape) - - resample_indices = dist_util.move_dimension( - categorical.Categorical(log_weights).sample(num_particles, seed=seed), - 0, -1) + num_particles = prefer_static.shape(log_weights)[0] + resample_indices = categorical.Categorical( + dist_util.move_dimension( + log_weights, + source_idx=0, + dest_idx=-1)).sample(num_particles, seed=seed) resampled_particles = tf.nest.map_structure( - lambda x: tf.gather( # pylint: disable=g-long-lambda - x, resample_indices, axis=batch_rank, batch_dims=batch_rank), + lambda x: _batch_gather(x, resample_indices, axis=0), particles) return resampled_particles, resample_indices @@ -767,22 +792,14 @@ def _resample(particles, log_weights, seed=None): def reconstruct_trajectories(particles, parent_indices, name=None): """Reconstructs the ancestor trajectory that generated each final particle.""" with tf.name_scope(name or 'reconstruct_trajectories'): - indices_shape = prefer_static.shape(parent_indices) - batch_shape, num_trajectories = indices_shape[1:-1], indices_shape[-1] - batch_rank = prefer_static.rank_from_shape(batch_shape) - # Walk backwards to compute the ancestor of each final particle at time t. 
-    final_indices = tf.broadcast_to(
-        tf.range(0, num_trajectories), indices_shape[1:])
+    final_indices = _dummy_indices_like(parent_indices[-1])
     ancestor_indices = tf.scan(
-        fn=lambda ancestor, parent: tf.gather(  # pylint: disable=g-long-lambda
-            parent, ancestor, axis=batch_rank, batch_dims=batch_rank),
+        fn=lambda ancestor, parent: _batch_gather(parent, ancestor, axis=0),
         elems=parent_indices[1:],
         initializer=final_indices,
         reverse=True)
   ancestor_indices = tf.concat([ancestor_indices, [final_indices]], axis=0)
 
   return tf.nest.map_structure(
-      lambda part: tf.gather(part, ancestor_indices,  # pylint: disable=g-long-lambda
-                             axis=batch_rank + 1, batch_dims=batch_rank + 1),
-      particles)
+      lambda part: _batch_gather(part, ancestor_indices, axis=1), particles)
diff --git a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py
index d5ed80a34b..749381f06d 100644
--- a/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py
+++ b/tensorflow_probability/python/experimental/mcmc/particle_filter_test.py
@@ -21,13 +21,80 @@
 import numpy as np
 import tensorflow.compat.v2 as tf
 import tensorflow_probability as tfp
+from tensorflow_probability.python.experimental.mcmc.particle_filter import SampleParticles
 from tensorflow_probability.python.internal import prefer_static
+from tensorflow_probability.python.internal import tensorshape_util
 from tensorflow_probability.python.internal import test_util
+
 tfb = tfp.bijectors
 tfd = tfp.distributions
 
 
+@test_util.test_all_tf_execution_regimes
+class SampleParticlesTest(test_util.TestCase):
+
+  def test_sample_particles_works_with_joint_distributions(self):
+    num_particles = 3
+    jd = tfd.JointDistributionNamed({'x': tfd.Normal(0., 1.)})
+    sp = SampleParticles(jd, num_particles=num_particles)
+
+    # Check that SampleParticles has the correct shapes.
+    self.assertAllEqualNested(jd.event_shape, sp.event_shape)
+    self.assertAllEqualNested(
+        *self.evaluate((jd.event_shape_tensor(), sp.event_shape_tensor())))
+    self.assertAllEqualNested(
+        tf.nest.map_structure(
+            lambda x: np.concatenate([[num_particles], x], axis=0),
+            jd.batch_shape),
+        tf.nest.map_structure(tensorshape_util.as_list, sp.batch_shape))
+    self.assertAllEqualNested(
+        *self.evaluate(
+            (tf.nest.map_structure(
+                lambda x: tf.concat([[num_particles], x], axis=0),
+                jd.batch_shape_tensor()),
+             sp.batch_shape_tensor())))
+
+    # Check that sample and log-prob work, and that we can take the log-prob
+    # of a sample.
+    x = self.evaluate(sp.sample())
+    lp = self.evaluate(sp.log_prob(x))
+    self.assertAllEqual(
+        [part.shape for part in tf.nest.flatten(x)], [[num_particles]])
+    self.assertAllEqual(
+        [part.shape for part in tf.nest.flatten(lp)], [[num_particles]])
+
+  def test_sample_particles_works_with_batch_and_event_shape(self):
+    num_particles = 3
+    d = tfd.MultivariateNormalDiag(loc=tf.zeros([2, 4]),
+                                   scale_diag=tf.ones([2, 4]))
+    sp = SampleParticles(d, num_particles=num_particles)
+
+    # Check that SampleParticles has the correct shapes.
+    self.assertAllEqual(sp.event_shape, d.event_shape)
+    self.assertAllEqual(sp.batch_shape,
+                        np.concatenate([[num_particles],
+                                        d.batch_shape], axis=0))
+
+    # Draw a sample, combining sample shape, batch shape, num_particles, *and*
+    # event_shape, and check that it has the correct shape, and that we can
+    # compute a log_prob with the correct shape.
+    sample_shape = [5, 1]
+    x = self.evaluate(sp.sample(sample_shape, seed=test_util.test_seed()))
+    self.assertAllEqual(x.shape,  # [5, 1, 3, 2, 4]
+                        np.concatenate([sample_shape,
+                                        [num_particles],
+                                        d.batch_shape,
+                                        d.event_shape],
+                                       axis=0))
+    lp = self.evaluate(sp.log_prob(x))
+    self.assertAllEqual(lp.shape,
+                        np.concatenate([sample_shape,
+                                        [num_particles],
+                                        d.batch_shape],
+                                       axis=0))
+
+
 @test_util.test_all_tf_execution_regimes
 class _ParticleFilterTest(test_util.TestCase):
 
@@ -85,7 +152,7 @@ def test_batch_of_filters(self):
     def transition_fn(_, previous_state):
       return tfd.JointDistributionNamed({
           'position': tfd.Normal(
-              loc=previous_state['position'] +previous_state['velocity'],
+              loc=previous_state['position'] + previous_state['velocity'],
               scale=0.1),
           'velocity': tfd.Normal(loc=previous_state['velocity'], scale=0.01)})
 
@@ -114,38 +181,38 @@ def observation_fn(_, state):
             seed=test_util.test_seed()))
 
     self.assertAllEqual(particles['position'].shape,
-                        [num_timesteps] + batch_shape + [num_particles])
+                        [num_timesteps, num_particles] + batch_shape)
     self.assertAllEqual(particles['velocity'].shape,
-                        [num_timesteps] + batch_shape + [num_particles])
+                        [num_timesteps, num_particles] + batch_shape)
     self.assertAllEqual(parent_indices.shape,
-                        [num_timesteps] + batch_shape + [num_particles])
+                        [num_timesteps, num_particles] + batch_shape)
     self.assertAllEqual(step_log_marginal_likelihoods.shape,
                         [num_timesteps] + batch_shape)
 
     self.assertAllClose(
        self.evaluate(
            tf.reduce_sum(tf.exp(log_weights) *
-                          particles['position'], axis=-1)),
+                          particles['position'], axis=1)),
        observed_positions,
        atol=0.1)
 
     velocity_means = tf.reduce_sum(tf.exp(log_weights) *
-                                   particles['velocity'], axis=-1)
+                                   particles['velocity'], axis=1)
    self.assertAllClose(
        self.evaluate(tf.reduce_mean(velocity_means, axis=0)),
        true_velocities, atol=0.05)
 
    # Uncertainty in velocity should decrease over time.
    velocity_stddev = self.evaluate(
-        tf.math.reduce_std(particles['velocity'], axis=-1))
+        tf.math.reduce_std(particles['velocity'], axis=1))
    self.assertAllLess((velocity_stddev[-1] - velocity_stddev[0]), 0.)
 
    trajectories = self.evaluate(
        tfp.experimental.mcmc.reconstruct_trajectories(particles,
                                                       parent_indices))
-    self.assertAllEqual([num_timesteps] + batch_shape + [num_particles],
+    self.assertAllEqual([num_timesteps, num_particles] + batch_shape,
                        trajectories['position'].shape)
-    self.assertAllEqual([num_timesteps] + batch_shape + [num_particles],
+    self.assertAllEqual([num_timesteps, num_particles] + batch_shape,
                        trajectories['velocity'].shape)
 
    # Verify that `infer_trajectories` also works on batches.
    trajectories, step_log_marginal_likelihoods = self.evaluate(
        tfp.experimental.mcmc.infer_trajectories(
@@ -157,9 +224,9 @@ def observation_fn(_, state):
            observation_fn=observation_fn,
            num_particles=num_particles,
            seed=test_util.test_seed()))
-    self.assertAllEqual([num_timesteps] + batch_shape + [num_particles],
+    self.assertAllEqual([num_timesteps, num_particles] + batch_shape,
                        trajectories['position'].shape)
-    self.assertAllEqual([num_timesteps] + batch_shape + [num_particles],
+    self.assertAllEqual([num_timesteps, num_particles] + batch_shape,
                        trajectories['velocity'].shape)
    self.assertAllEqual(step_log_marginal_likelihoods.shape,
                        [num_timesteps] + batch_shape)
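
The sketches below are illustrative notes, not part of the commit. This first one shows the particle-dimension convention the patch establishes: `SampleParticles` now prepends `num_particles` as the leftmost batch dimension (previously it was appended on the right). The constructor call matches the new tests above; `SampleParticles` remains a private experimental symbol, so the import path may change, and the snippet assumes a TFP build that includes this patch.

import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.experimental.mcmc.particle_filter import SampleParticles

tfd = tfp.distributions

dist = tfd.Normal(loc=tf.zeros([2]), scale=1.)  # batch_shape == [2]
sp = SampleParticles(dist, num_particles=5)

print(sp.batch_shape)        # (5, 2): particles are now the leftmost batch dim.
x = sp.sample([3])           # Sample shape first, then particles, then batch.
print(x.shape)               # (3, 5, 2)
print(sp.log_prob(x).shape)  # (3, 5, 2)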
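
`_batch_gather` is the workhorse for the new axis convention: it moves the gathered axis to the rightmost position, broadcasts `indices` against `params` there, and applies `tf.gather` with full `batch_dims` before moving the axis back. A rough standalone equivalent of the 2-D case used by `_resample` is below; `resample_axis0` is a hypothetical helper name, assuming particles on axis 0 and a single batch dimension on axis 1.

import tensorflow.compat.v2 as tf

def resample_axis0(particles, indices):
  # Move the particle axis to the right, gather per batch element with
  # batch_dims, then move the particle axis back to the front.
  p = tf.transpose(particles)                  # [batch, num_particles]
  i = tf.transpose(indices)                    # [batch, num_particles]
  out = tf.gather(p, i, axis=1, batch_dims=1)  # Per-batch-element gather.
  return tf.transpose(out)                     # [num_particles, batch]

particles = tf.constant([[10., 20.],   # particle 0 in batches 0 and 1
                         [11., 21.],   # particle 1
                         [12., 22.]])  # particle 2
indices = tf.constant([[2, 0],
                       [2, 0],
                       [1, 1]])
print(resample_axis0(particles, indices))
# [[12. 20.]
#  [12. 20.]
#  [11. 21.]]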
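
The adaptive-resampling criterion is based on effective sample size: for normalized weights w_k, ESS = 1 / sum(w_k**2), computed in log space as -logsumexp(2 * log_weights). After this change the reduction runs over axis 0 (particles) rather than the last axis. A small worked example of the default criterion (threshold=0.5), with made-up weights:

import tensorflow.compat.v2 as tf

log_w = tf.math.log([0.7, 0.1, 0.1, 0.1])         # Unnormalized log-weights.
log_weights = tf.math.log_softmax(log_w, axis=0)  # Normalize over particles.
log_ess = -tf.math.reduce_logsumexp(2 * log_weights, axis=0)
print(tf.exp(log_ess))  # ~1.92 effective particles out of 4.
# ess_below_threshold compares in log space: resample iff
# log_ess < log(num_particles) + log(threshold); here 1.92 < 2.0, so resample.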
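
Finally, `reconstruct_trajectories` walks `parent_indices` backwards with a reversed `tf.scan` to recover each final particle's ancestor at every step. In the batch-free case `_batch_gather(parent, ancestor, axis=0)` reduces to `tf.gather(parent, ancestor)`, so the backward pass can be sketched directly; the `[num_timesteps, num_particles]` index values below are made up for illustration.

import tensorflow.compat.v2 as tf

parent_indices = tf.constant([[0, 0, 0],   # t=0 (row unused by the scan)
                              [0, 0, 1],   # t=1: particle k's parent at t=0
                              [1, 2, 2]])  # t=2: particle k's parent at t=1
final_indices = tf.range(3)  # Each final particle heads its own trajectory.
ancestors = tf.scan(
    fn=lambda ancestor, parent: tf.gather(parent, ancestor),
    elems=parent_indices[1:],
    initializer=final_indices,
    reverse=True)
ancestor_indices = tf.concat([ancestors, [final_indices]], axis=0)
print(ancestor_indices)  # [[0 1 1], [1 2 2], [0 1 2]]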