This repository has been archived by the owner on Jan 23, 2022. It is now read-only.

Add tests for SampleParticles and fix some issues around batch shape.
PiperOrigin-RevId: 305728146
davmre authored and tensorflower-gardener committed Apr 9, 2020
1 parent 2fcf8af commit ed8edb8
Showing 2 changed files with 187 additions and 103 deletions.
201 changes: 109 additions & 92 deletions tensorflow_probability/python/experimental/mcmc/particle_filter.py
@@ -17,6 +17,7 @@
import collections
import functools

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import categorical
from tensorflow_probability.python.distributions import distribution as distribution_lib
@@ -36,7 +37,7 @@
]


# TODO(davmre): add unit tests for SampleParticles.
# TODO(b/153467570): Move SampleParticles into `tfp.distributions`.
class SampleParticles(distribution_lib.Distribution):
"""Like tfd.Sample, but inserts new rightmost batch (vs event) dim."""

@@ -70,38 +71,78 @@ def _event_shape_tensor(self, **kwargs):

def _batch_shape(self):
return tf.nest.map_structure(
lambda b: tensorshape_util.concatenate(b, [self.num_particles]),
lambda b: tensorshape_util.concatenate( # pylint: disable=g-long-lambda
[tf.get_static_value(self.num_particles)], b),
self.distribution.batch_shape)

def _batch_shape_tensor(self, **kwargs):
return tf.nest.map_structure(
lambda b: prefer_static.concat([b, [self.num_particles]], axis=0),
lambda b: prefer_static.concat([[self.num_particles], b], axis=0),
self.distribution.batch_shape_tensor(**kwargs))
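  # An illustrative aside on the convention above (hypothetical values; `tfd`
  # abbreviates `tfp.distributions`): for a wrapped distribution with batch
  # shape [3], e.g. SampleParticles(tfd.Normal(tf.zeros([3]), 1.), 100),
  # the reported batch shape is [100, 3] (the particle count becomes a new
  # leading batch dimension), and a plain `sample()` call returns a Tensor
  # of shape [100, 3].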

def _log_prob(self, x, **kwargs):
return self.distribution.log_prob(x, **kwargs)
return self._call_log_measure('_log_prob', x, kwargs)

def _log_cdf(self, x, **kwargs):
return self.distribution.log_cdf(x, **kwargs)
return self._call_log_measure('_log_cdf', x, kwargs)

def _log_sf(self, x, **kwargs):
return self.distribution.log_sf(x, **kwargs)
return self._call_log_measure('_log_sf', x, kwargs)

def _call_log_measure(self, attr, x, kwargs):
return getattr(self.distribution, attr)(x, **kwargs)

# TODO(b/152797117): Override _sample_n, once it supports joint distributions.
def sample(self, sample_shape=(), seed=None, name=None):
with tf.name_scope(name or 'sample_particles'):
sample_shape = prefer_static.concat([
          [self.num_particles],
          dist_util.expand_to_vector(sample_shape)], axis=0)
      x = self.distribution.sample(sample_shape, seed=seed)
      def move_particles_to_rightmost_batch_dim(x, event_shape):
        ndims = prefer_static.rank_from_shape(prefer_static.shape(x))
        event_ndims = prefer_static.rank_from_shape(event_shape)
        return dist_util.move_dimension(x, 0, ndims - event_ndims - 1)
      return tf.nest.map_structure(
          move_particles_to_rightmost_batch_dim,
          x, self.distribution.event_shape_tensor())
          dist_util.expand_to_vector(sample_shape),
          [self.num_particles]], axis=0)
      return self.distribution.sample(sample_shape, seed=seed)


# TODO(davmre): Replace this hack with a more efficient TF builtin.
def _batch_gather(params, indices, axis=0):
  """Gathers a batch of indices from `params` along the given axis.

  Args:
    params: `Tensor` of shape `[d[0], d[1], ..., d[N - 1]]`.
    indices: int `Tensor` of shape broadcastable to that of `params`.
    axis: int `Tensor` dimension of `params` (and of the broadcast indices) to
      gather over.
  Returns:
    result: `Tensor` of the same type and shape as `params`.
  """
params_rank = prefer_static.rank_from_shape(prefer_static.shape(params))
indices_rank = prefer_static.rank_from_shape(prefer_static.shape(indices))
params_with_axis_on_right = dist_util.move_dimension(
params, source_idx=axis, dest_idx=-1)
indices_with_axis_on_right = prefer_static.broadcast_to(
dist_util.move_dimension(indices,
source_idx=axis - (params_rank - indices_rank),
dest_idx=-1),
prefer_static.shape(params_with_axis_on_right))

result = tf.gather(params_with_axis_on_right,
indices_with_axis_on_right,
axis=params_rank - 1,
batch_dims=params_rank - 1)
return dist_util.move_dimension(result, source_idx=-1, dest_idx=axis)
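
The helper above implements a batched gather along an arbitrary axis by moving that axis to the right, gathering with `batch_dims`, and moving it back. A standalone sketch of that idea in the simplest 2-D case (made-up values, plain `tf.gather` only):

import tensorflow as tf

particles2d = tf.constant([[10., 20.],
                           [11., 21.],
                           [12., 22.]])   # shape [3 particles, 2 batch members]
idx = tf.constant([[2, 0],
                   [2, 1],
                   [0, 0]])               # shape [3, 2], one index per slot and batch member

# Move the particle axis to the right, gather per batch member, move it back.
gathered = tf.transpose(
    tf.gather(tf.transpose(particles2d),  # [batch, particles]
              tf.transpose(idx),
              axis=1, batch_dims=1))
# gathered == [[12., 20.], [12., 21.], [10., 20.]], i.e.
# gathered[i, b] == particles2d[idx[i, b], b].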


def _dummy_indices_like(indices):
"""Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
indices_shape = prefer_static.shape(indices)
num_particles = indices_shape[0]
return tf.broadcast_to(
prefer_static.reshape(
prefer_static.range(num_particles),
prefer_static.concat([[num_particles],
prefer_static.ones([
prefer_static.rank_from_shape(
indices_shape) - 1], dtype=np.int32)],
axis=0)),
indices_shape)
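
For intuition, a numpy re-expression (illustrative only) of what the helper returns for `indices` of shape [4, 2]:

import numpy as np

num_particles, batch = 4, 2
dummy = np.broadcast_to(np.arange(num_particles).reshape([num_particles, 1]),
                        [num_particles, batch])
# dummy == [[0, 0], [1, 1], [2, 2], [3, 3]]: each particle index maps to
# itself in every batch member, which is what the "no resampling" branch in
# `_filter_one_step` needs.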


def _gather_history(structure, step, num_steps):
@@ -115,9 +156,9 @@ def _gather_history(structure, step, num_steps):
def ess_below_threshold(unnormalized_log_weights, threshold=0.5):
"""Determines if the effective sample size is much less than num_particles."""
with tf.name_scope('ess_below_threshold'):
num_particles = prefer_static.shape(unnormalized_log_weights)[-1]
log_weights = tf.math.log_softmax(unnormalized_log_weights, axis=-1)
log_ess = -tf.math.reduce_logsumexp(2 * log_weights, axis=-1)
num_particles = prefer_static.shape(unnormalized_log_weights)[0]
log_weights = tf.math.log_softmax(unnormalized_log_weights, axis=0)
log_ess = -tf.math.reduce_logsumexp(2 * log_weights, axis=0)
return log_ess < (prefer_static.log(num_particles) +
prefer_static.log(threshold))
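
The quantity being thresholded is the effective sample size of the normalized weights, ESS = 1 / sum(w ** 2), computed above in log space. A standalone numpy sketch with illustrative values:

import numpy as np

def ess(weights):
  w = np.asarray(weights, dtype=np.float64)
  w = w / np.sum(w)
  return 1. / np.sum(w ** 2)

ess(np.ones(100))          # == 100.0: uniform weights, every particle counts.
ess([1., 1.] + [0.] * 98)  # == 2.0: only two particles carry weight.
# `ess_below_threshold` flags resampling when ESS < threshold * num_particles
# (with `threshold` defaulting to 0.5).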

@@ -253,7 +294,7 @@ def infer_trajectories(observations,
Returns:
trajectories: a (structure of) Tensor(s) matching the latent state, each
of shape
`concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
`concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
representing unbiased samples from the posterior distribution
`p(latent_states | observations)`.
step_log_marginal_likelihoods: float `Tensor` of shape
@@ -369,18 +410,13 @@ def observation_fn(_, state):
weighted_trajectories = reconstruct_trajectories(particles, parent_indices)

# Resample all steps of the trajectories using the final weights.
final_log_weights = log_weights[-1, ...]
batch_rank = prefer_static.rank_from_shape(
prefer_static.shape(final_log_weights)[:-1])
final_indices = dist_util.move_dimension(
categorical.Categorical(final_log_weights).sample(
num_particles, seed=seed), source_idx=0, dest_idx=-1)
final_indices_tiled_over_time = tf.broadcast_to(
final_indices, prefer_static.shape(parent_indices))
resample_indices = categorical.Categorical(
dist_util.move_dimension(
log_weights[-1, ...],
source_idx=0,
dest_idx=-1)).sample(num_particles, seed=seed)
trajectories = tf.nest.map_structure(
lambda x: tf.gather( # pylint: disable=g-long-lambda
x, final_indices_tiled_over_time,
axis=batch_rank + 1, batch_dims=batch_rank + 1),
lambda x: _batch_gather(x, resample_indices, axis=1),
weighted_trajectories)

return trajectories, step_log_marginal_likelihoods
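
A stripped-down, batch-free numpy sketch of the final-weights resampling performed above (illustrative values; the library code uses `categorical.Categorical` and `_batch_gather`):

import numpy as np

rng = np.random.default_rng(0)
num_timesteps, num_particles = 5, 4
weighted_trajectories = rng.normal(size=[num_timesteps, num_particles])
final_weights = np.array([0.1, 0.2, 0.3, 0.4])   # exp of the last step's log-weights

# Draw num_particles ancestors in proportion to the final weights, then reindex
# every timestep with the same draw to get equally-weighted trajectories.
resample_indices = rng.choice(num_particles, size=num_particles, p=final_weights)
trajectories = weighted_trajectories[:, resample_indices]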
@@ -416,18 +452,18 @@ def particle_filter(observations,
Returns:
particles: a (structure of) Tensor(s) matching the latent state, each
of shape
`concat([[num_timesteps, b1, ..., bN, num_particles], event_shape])`,
`concat([[num_timesteps, num_particles, b1, ..., bN], event_shape])`,
representing (possibly weighted) samples from the series of filtering
distributions `p(latent_states[t] | observations[:t])`.
log_weights: `float` `Tensor` of shape
`[num_timesteps, b1, ..., bN, num_particles]`, such that
`[num_timesteps, num_particles, b1, ..., bN]`, such that
`log_weights[t, :]` are the logarithms of normalized importance weights
(such that `exp(reduce_logsumexp(log_weights), axis=-1) == 1.`) of
the particles at time `t`. These may be used in conjunction with
`particles` to compute expectations under the series of filtering
distributions.
parent_indices: `int` `Tensor` of shape
`[num_timesteps, b1, ..., bN, num_particles]`,
`[num_timesteps, num_particles, b1, ..., bN]`,
such that `parent_indices[t, k]` gives the index of the particle at
time `t - 1` that the `k`th particle at time `t` is immediately descended
from. See also
@@ -460,9 +496,18 @@ def particle_filter(observations,
None if initial_state_proposal is None
else lambda _1, _2: SampleParticles( # pylint: disable=g-long-lambda
initial_state_proposal, num_particles))
log_uniform_weights = (
prefer_static.zeros([num_particles], dtype=tf.float32) -
prefer_static.log(num_particles))

# Initially the particles all have the same weight, `1. / num_particles`.
broadcast_batch_shape = tf.convert_to_tensor(
functools.reduce(
prefer_static.broadcast_shape,
tf.nest.flatten(initial_state_prior.batch_shape_tensor()),
[]), dtype=tf.int32)
log_uniform_weights = prefer_static.zeros(
prefer_static.concat([
[num_particles],
broadcast_batch_shape], axis=0),
dtype=tf.float32) - prefer_static.log(num_particles)

# Initialize from the prior, and incorporate the first observation.
initial_step_results = _filter_one_step(
@@ -573,19 +618,12 @@ def _update_loop_variables(step,
history_is_empty = (tf.is_tensor(state_history) and
state_history.shape[0] == 0)
if not history_is_empty:
batch_shape = prefer_static.shape(
current_step_results.parent_indices)[1:-1]
batch_rank = prefer_static.rank_from_shape(batch_shape)

# Permute the particles from previous steps to match the current resampled
# indices, so that the state history reflects coherent trajectories.
def update_state_history_to_use_current_indices(x):
return tf.gather(x[-1:],
current_step_results.parent_indices,
axis=batch_rank + 1,
batch_dims=batch_rank)
resampled_state_history = tf.nest.map_structure(
update_state_history_to_use_current_indices,
lambda x: _batch_gather(x[1:], # pylint: disable=g-long-lambda
current_step_results.parent_indices,
axis=1),
state_history)

# Update the history by concat'ing the carried-forward elements with the
Expand Down Expand Up @@ -614,7 +652,7 @@ def _filter_one_step(step,
"""Advances the particle filter by a single time step."""
with tf.name_scope('filter_one_step'):
seed = SeedStream(seed, 'filter_one_step')
num_particles = prefer_static.shape(log_weights)[-1]
num_particles = prefer_static.shape(log_weights)[0]

proposed_particles, proposal_log_weights = _propose_with_log_weights(
step=step - 1,
@@ -629,14 +667,11 @@
proposal_log_weights +
observation_log_weights)
step_log_marginal_likelihood = tf.math.reduce_logsumexp(
unnormalized_log_weights, axis=-1)
log_weights = (unnormalized_log_weights -
step_log_marginal_likelihood[..., tf.newaxis])
unnormalized_log_weights, axis=0)
log_weights = (unnormalized_log_weights - step_log_marginal_likelihood)
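    # Aside (illustrative only): normalizing in log space and accumulating the
    # per-step marginal likelihood use the same logsumexp. For example, with
    # unnormalized log-weights [-1., -2., -3.]:
    #   lse = np.logaddexp.reduce([-1., -2., -3.])    # the step's log-marginal term
    #   normalized = np.array([-1., -2., -3.]) - lse  # exp(normalized) sums to 1.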

    # Adaptive resampling: resample particles iff the specified criterion is met.
do_resample = tf.convert_to_tensor(
resample_criterion_fn(unnormalized_log_weights)
)[..., tf.newaxis] # Broadcast over particles.
do_resample = resample_criterion_fn(unnormalized_log_weights)

# Some batch elements may require resampling and others not, so
# we first do the resampling for all elements, then select whether to use
@@ -647,17 +682,17 @@
# for statistical (not computational) purposes, so this isn't a dealbreaker.
resampled_particles, resample_indices = _resample(
proposed_particles, log_weights, seed=seed)
dummy_indices = tf.broadcast_to(
prefer_static.range(num_particles),
prefer_static.shape(resample_indices))

uniform_weights = (prefer_static.zeros_like(log_weights) -
prefer_static.log(num_particles))
(resampled_particles,
resample_indices,
log_weights) = tf.nest.map_structure(
lambda r, p: prefer_static.where(do_resample, r, p),
(resampled_particles, resample_indices, uniform_weights),
(proposed_particles, dummy_indices, log_weights))
(proposed_particles,
_dummy_indices_like(resample_indices),
log_weights))

return ParticleFilterStepResults(
particles=resampled_particles,
@@ -713,76 +748,58 @@ def _compute_observation_log_weights(step,
Args:
step: int `Tensor` current step.
particles: Nested structure of `Tensor`s, each of shape
`concat([[b1, ..., bN], [num_particles], latent_part_event_shape])`, where
`concat([[num_particles, b1, ..., bN], latent_part_event_shape])`, where
`b1, ..., bN` are optional batch dimensions.
observation: Nested structure of `Tensor`s, each of shape
`concat([[b1, ..., bN], observation_part_event_shape])` where
`b1, ..., bN` are optional batch dimensions.
observation_fn: callable, producing a distribution over `observation`s.
Returns:
log_weights: `Tensor` of shape `concat([[b1, ..., bN], [num_particles]])`.
log_weights: `Tensor` of shape `concat([num_particles, b1, ..., bN])`.
"""
with tf.name_scope('compute_observation_log_weights'):
observation_dist = observation_fn(step, particles)
def _add_right_batch_dim(obs, event_shape):
ndims = prefer_static.rank_from_shape(prefer_static.shape(obs))
event_ndims = prefer_static.rank_from_shape(event_shape)
return tf.expand_dims(obs, ndims - event_ndims)
observation_broadcast_over_particles = tf.nest.map_structure(
_add_right_batch_dim,
observation,
observation_dist.event_shape_tensor())
return observation_dist.log_prob(observation_broadcast_over_particles)
return observation_dist.log_prob(observation)
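
The plain `log_prob(observation)` call works because the particles enter `observation_fn` with a leading particle dimension, so the resulting distribution's batch shape already includes `num_particles` and a single observation broadcasts against it. A hypothetical, batch-free sketch (assuming `tfd = tfp.distributions`):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

particles = tf.random.normal([100])                      # 100 scalar latent particles
observation_dist = tfd.Normal(loc=particles, scale=1.)   # batch shape [100]
log_weights = observation_dist.log_prob(2.7)             # shape [100], one weight per particle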


def _resample(particles, log_weights, seed=None):
"""Resamples the current particles according to provided weights.
Args:
particles: Nested structure of `Tensor`s each of shape
`[b1, ..., bN, num_particles, ...]`, where
`[num_particles, b1, ..., bN, ...]`, where
`b1, ..., bN` are optional batch dimensions.
log_weights: float `Tensor` of shape `[b1, ..., bN, num_particles]`, where
log_weights: float `Tensor` of shape `[num_particles, b1, ..., bN]`, where
`b1, ..., bN` are optional batch dimensions.
seed: Python `int` random seed.
Returns:
resampled_particles: Nested structure of `Tensor`s, matching `particles`.
resample_indices: int `Tensor` of shape `[b1, ..., bN, num_particles]`.
resample_indices: int `Tensor` of shape `[num_particles, b1, ..., bN]`.
"""
with tf.name_scope('resample'):
weights_shape = prefer_static.shape(log_weights)
batch_shape, num_particles = weights_shape[:-1], weights_shape[-1]
batch_rank = prefer_static.rank_from_shape(batch_shape)

resample_indices = dist_util.move_dimension(
categorical.Categorical(log_weights).sample(num_particles, seed=seed),
0, -1)
num_particles = prefer_static.shape(log_weights)[0]
resample_indices = categorical.Categorical(
dist_util.move_dimension(
log_weights,
source_idx=0,
dest_idx=-1)).sample(num_particles, seed=seed)
resampled_particles = tf.nest.map_structure(
lambda x: tf.gather( # pylint: disable=g-long-lambda
x, resample_indices, axis=batch_rank, batch_dims=batch_rank),
lambda x: _batch_gather(x, resample_indices, axis=0),
particles)
return resampled_particles, resample_indices


def reconstruct_trajectories(particles, parent_indices, name=None):
"""Reconstructs the ancestor trajectory that generated each final particle."""
with tf.name_scope(name or 'reconstruct_trajectories'):
indices_shape = prefer_static.shape(parent_indices)
batch_shape, num_trajectories = indices_shape[1:-1], indices_shape[-1]
batch_rank = prefer_static.rank_from_shape(batch_shape)

# Walk backwards to compute the ancestor of each final particle at time t.
final_indices = tf.broadcast_to(
tf.range(0, num_trajectories), indices_shape[1:])
final_indices = _dummy_indices_like(parent_indices[-1])
ancestor_indices = tf.scan(
fn=lambda ancestor, parent: tf.gather( # pylint: disable=g-long-lambda
parent, ancestor, axis=batch_rank, batch_dims=batch_rank),
fn=lambda ancestor, parent: _batch_gather(parent, ancestor, axis=0),
elems=parent_indices[1:],
initializer=final_indices,
reverse=True)
ancestor_indices = tf.concat([ancestor_indices, [final_indices]], axis=0)

return tf.nest.map_structure(
lambda part: tf.gather(part, ancestor_indices, # pylint: disable=g-long-lambda
axis=batch_rank + 1, batch_dims=batch_rank + 1),
particles)
lambda part: _batch_gather(part, ancestor_indices, axis=1), particles)
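
A batch-free numpy sketch of the backward ancestor walk implemented above (illustrative indices; the library code vectorizes the walk with `tf.scan` and `_batch_gather`):

import numpy as np

# parent_indices[t, k]: which particle at time t - 1 spawned particle k at time t.
parent_indices = np.array([[0, 1, 2],   # t = 0 (never read by the walk)
                           [0, 0, 2],   # t = 1
                           [1, 2, 2]])  # t = 2
num_timesteps, num_particles = parent_indices.shape

ancestors = np.zeros_like(parent_indices)
ancestors[-1] = np.arange(num_particles)            # final particles index themselves
for t in range(num_timesteps - 1, 0, -1):           # walk backwards through time
  ancestors[t - 1] = parent_indices[t][ancestors[t]]
# ancestors[:, k] now lists, for each time step, the index of the particle on
# the trajectory that ends in final particle k; gathering particles with these
# indices yields the reconstructed trajectories.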