Enable named axes in windowed sampling
PiperOrigin-RevId: 394764288
sharadmv authored and tensorflower-gardener committed Sep 3, 2021
1 parent a6c0639 commit 9f9eec8
Showing 9 changed files with 300 additions and 34 deletions.
@@ -18,6 +18,7 @@
 from __future__ import print_function

 from tensorflow_probability.python.experimental.distribute.joint_distribution import JointDistributionCoroutine
+from tensorflow_probability.python.experimental.distribute.joint_distribution import JointDistributionDistributedMixin
 from tensorflow_probability.python.experimental.distribute.joint_distribution import JointDistributionNamed
 from tensorflow_probability.python.experimental.distribute.joint_distribution import JointDistributionSequential
 from tensorflow_probability.python.experimental.distribute.sharded import Sharded
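
Note: the newly re-exported `JointDistributionDistributedMixin` is the layer that the `distribute` variants of the `JointDistribution*` classes add on top of their stock counterparts. A minimal sketch of composing it yourself — the subclass name is illustrative, and this assumes the mixin composes with any standard `JointDistribution`:

import tensorflow_probability as tfp
from tensorflow_probability.python.experimental import distribute

tfd = tfp.distributions

# Mixin first, so its shard-aware sample/log_prob plumbing wins the MRO.
class MyDistributedJD(distribute.JointDistributionDistributedMixin,
                      tfd.JointDistributionNamed):
  pass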
@@ -167,10 +167,12 @@ def _parameter_control_dependencies(self, is_init):
     return self.distribution._parameter_control_dependencies(is_init=is_init)  # pylint: disable=protected-access

   def _default_event_space_bijector(self, *args, **kwargs):
+    bij = self.distribution.experimental_default_event_space_bijector(
+        *args, **kwargs)
+    if bij is None:
+      return None
     return sharded_bij.Sharded(
-        self.distribution.experimental_default_event_space_bijector(
-            *args, **kwargs),
-        shard_axis_name=self.experimental_shard_axis_names)
+        bij, shard_axis_name=self.experimental_shard_axis_names)


 @log_prob_ratio.RegisterLogProbRatio(Sharded)
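
Note: this guard matters for base distributions that have no default event-space bijector (discrete distributions return `None`); previously the wrapper passed that `None` straight into `sharded_bij.Sharded`. A minimal sketch of the fixed behavior, with an illustrative axis name:

import tensorflow_probability as tfp
from tensorflow_probability.python.experimental import distribute

tfd = tfp.distributions

sharded = distribute.Sharded(tfd.Bernoulli(probs=0.5), shard_axis_name='data')
# Bernoulli is discrete, so there is no unconstraining bijector; the wrapper
# now forwards the None instead of wrapping it.
assert sharded.experimental_default_event_space_bijector() is None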
@@ -617,6 +617,11 @@ def _model_coroutine(self):
     except StopIteration:
       pass

+  @property
+  def experimental_shard_axis_names(self):
+    return self._prune(self.distribution.experimental_shard_axis_names,
+                       retain='unpinned')
+

 def _to_pins(dist, *args, **kwargs):
   """Converts *args and **kwargs to a dict of pins.
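
Note: with this property, pinning a variable also prunes its shard-axis entry, so downstream consumers (such as the initializer changed below) see named axes only for the variables that remain free. A hedged sketch — the variable and axis names are illustrative:

import tensorflow_probability as tfp
from tensorflow_probability.python.experimental import distribute

tfd = tfp.distributions

def model():
  x = yield tfd.Normal(0., 1., name='x')
  yield distribute.Sharded(tfd.Normal(x, 1.), shard_axis_name='data')

jd = distribute.JointDistributionCoroutine(model)
pinned = jd.experimental_pin(x=0.)
# Only the unpinned, sharded variable reports its axis name.
print(pinned.experimental_shard_axis_names)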
7 changes: 7 additions & 0 deletions tensorflow_probability/python/experimental/mcmc/BUILD
@@ -102,6 +102,7 @@ multi_substrate_py_library(
         # tensorflow dep,
         "//tensorflow_probability/python/bijectors",
         "//tensorflow_probability/python/distributions",
+        "//tensorflow_probability/python/experimental/distribute",
     ],
 )

@@ -113,6 +114,8 @@ multi_substrate_py_test(
         ":initialization",
         # tensorflow dep,
         "//tensorflow_probability",
+        "//tensorflow_probability/python/experimental/distribute",
+        "//tensorflow_probability/python/internal:distribute_test_lib",
         "//tensorflow_probability/python/internal:test_util",
     ],
 )

@@ -997,13 +1000,15 @@ multi_substrate_py_library(
         ":preconditioned_hmc",
         ":preconditioned_nuts",
         ":preconditioning_utils",
+        ":sharded",
         # tensorflow dep,
         "//tensorflow_probability/python/bijectors:bijector",
         "//tensorflow_probability/python/bijectors:invert",
         "//tensorflow_probability/python/bijectors:joint_map",
         "//tensorflow_probability/python/bijectors:reshape",
         "//tensorflow_probability/python/bijectors:restructure",
         "//tensorflow_probability/python/experimental/stats:sample_stats",
+        "//tensorflow_probability/python/internal:distribute_lib",
         "//tensorflow_probability/python/internal:dtype_util",
         "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:samplers",

@@ -1031,6 +1036,8 @@ multi_substrate_py_test(
         # absl/testing:parameterized dep,
         # tensorflow dep,
         "//tensorflow_probability",
+        "//tensorflow_probability/python/experimental/distribute",
+        "//tensorflow_probability/python/internal:distribute_test_lib",
         "//tensorflow_probability/python/internal:prefer_static",
         "//tensorflow_probability/python/internal:samplers",
         "//tensorflow_probability/python/internal:test_util",
31 changes: 24 additions & 7 deletions tensorflow_probability/python/experimental/mcmc/initialization.py
@@ -23,10 +23,13 @@
 from tensorflow_probability.python import bijectors as tfb
 from tensorflow_probability.python import distributions as tfd
+from tensorflow_probability.python.experimental import distribute
 from tensorflow_probability.python.internal import batched_rejection_sampler as brs
 from tensorflow_probability.python.internal import nest_util
 from tensorflow_probability.python.internal import tensorshape_util

+from tensorflow.python.util import nest  # pylint: disable=g-direct-tensorflow-import
+
 __all__ = [
     'init_near_unconstrained_zero',
 ]

@@ -35,7 +38,7 @@
 def init_near_unconstrained_zero(
     model=None, constraining_bijector=None, event_shapes=None,
     event_shape_tensors=None, batch_shapes=None, batch_shape_tensors=None,
-    dtypes=None):
+    dtypes=None, shard_axis_names=None):
   """Returns an initialization Distribution for starting a Markov chain.

   This initialization scheme follows Stan: we sample every latent

@@ -83,6 +86,9 @@ def init_near_unconstrained_zero(
       the desired samples. Must be an acceptable input to
      `constraining_bijector.inverse_dtype`. If supplied together
       with `model`, acts as an override.
+    shard_axis_names: A structure of `str`s indicating the named axes by which
+      the distribution event is sharded. See
+      `tfp.experimental.distribute.Sharded` for more context.

   Returns:
     init_dist: A `Distribution` representing the initialization
@@ -126,6 +132,8 @@ def model():
     if batch_shape_tensors is None:
       batch_shape_tensors = nest_util.broadcast_structure(
           dtypes, model.batch_shape_tensor())
+    if shard_axis_names is None:
+      shard_axis_names = model.experimental_shard_axis_names

   else:
     if constraining_bijector is None or event_shapes is None or dtypes is None:
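
Note: with a sharded model in hand, the named axes are therefore picked up automatically. Reusing the `jd` from the sketch above, the following two calls should be equivalent (a hedged usage sketch, not part of this diff):

init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(jd)
init_dist = tfp.experimental.mcmc.init_near_unconstrained_zero(
    jd, shard_axis_names=jd.experimental_shard_axis_names)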
@@ -157,13 +165,15 @@ def model():

   # Actually initialize
   def one_term(event_shape, event_shape_tensor, batch_shape, batch_shape_tensor,
-               dtype):
+               dtype, shard_axes=None):
     if not tensorshape_util.is_fully_defined(event_shape):
       event_shape = event_shape_tensor
     result = tfd.Sample(
         tfd.Uniform(low=tf.constant(-2., dtype=dtype),
                     high=tf.constant(2., dtype=dtype)),
         sample_shape=event_shape)
+    if shard_axes:
+      result = distribute.Sharded(result, shard_axes)
     if not tensorshape_util.is_fully_defined(batch_shape):
       batch_shape = batch_shape_tensor
       needs_bcast = True
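
Note: for a latent that carries shard axes, `one_term` now builds a `Sharded` uniform over the unconstrained space, so each replica draws its own independent slice. A standalone sketch of the resulting term, with illustrative event shape and axis name:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.experimental import distribute

tfd = tfp.distributions

term = distribute.Sharded(
    tfd.Sample(
        tfd.Uniform(low=tf.constant(-2., dtype=tf.float32),
                    high=tf.constant(2., dtype=tf.float32)),
        sample_shape=[3]),
    shard_axis_name='data')
# Under a pmap/named-axis context, each replica samples a distinct value of
# this term -- exactly the per-replica behavior the new test below asserts.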
@@ -180,11 +190,18 @@ def one_term(event_shape, event_shape_tensor, batch_shape, batch_shape_tensor,
   else:
     inv_shape_tensors = tf.nest.map_structure(lambda _: None, inv_shapes)
   inv_dtypes = constraining_bijector.inverse_dtype(dtypes)
-  terms = tf.nest.map_structure(
-      one_term, inv_shapes, inv_shape_tensors, batch_shapes,
-      batch_shape_tensors, inv_dtypes)
-  unconstrained = tfb.pack_sequence_as(inv_shapes)(
-      tfd.JointDistributionSequential(tf.nest.flatten(terms)))
+  if shard_axis_names is None:
+    shard_axis_names = tf.nest.map_structure(lambda _: None, batch_shapes)
+  terms = nest.map_structure_up_to(inv_shapes, one_term, inv_shapes,
+                                   inv_shape_tensors, batch_shapes,
+                                   batch_shape_tensors,
+                                   inv_dtypes, shard_axis_names)
+  if shard_axis_names and any(shard_axes for shard_axes in nest.flatten_up_to(
+      batch_shapes, shard_axis_names)):
+    dist = distribute.JointDistributionSequential(tf.nest.flatten(terms))
+  else:
+    dist = tfd.JointDistributionSequential(tf.nest.flatten(terms))
+  unconstrained = tfb.pack_sequence_as(inv_shapes)(dist)
   return tfd.TransformedDistribution(
       unconstrained, bijector=constraining_bijector)
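
Note the dispatch at the tail of this hunk: if any flattened term carries shard axes, the terms are joined under `distribute.JointDistributionSequential`, whose log-prob machinery reduces over the named axes; otherwise the stock `tfd.JointDistributionSequential` is kept, so non-distributed callers see no change in behavior.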
@@ -22,6 +22,8 @@
 import tensorflow.compat.v2 as tf

 import tensorflow_probability as tfp
+from tensorflow_probability.python.experimental import distribute
+from tensorflow_probability.python.internal import distribute_test_lib
 from tensorflow_probability.python.internal import test_util

 tfb = tfp.bijectors

@@ -158,5 +160,34 @@ def model():
         seed=test_util.test_seed()).x.shape)


+@test_util.disable_test_for_backend(
+    disable_numpy=True,
+    reason='Sharding not available for NumPy backend.')
+class DistributedTest(distribute_test_lib.DistributedTest):
+
+  def test_can_initialize_from_sharded_distribution(self):
+
+    def model():
+      x = yield tfd.Normal(0., 1.)
+      yield distribute.Sharded(tfd.Normal(x, 1.), self.axis_name)
+
+    jd = distribute.JointDistributionCoroutine(model)
+
+    def run(seed):
+      init_jd = tfp.experimental.mcmc.init_near_unconstrained_zero(jd)
+      return init_jd.sample(seed=seed)
+
+    x, y = self.evaluate(
+        self.per_replica_to_tensor(
+            self.strategy_run(run, args=(test_util.test_seed(),),
+                              in_axes=None)))
+    for i in range(distribute_test_lib.NUM_DEVICES):
+      for j in range(distribute_test_lib.NUM_DEVICES):
+        if i == j:
+          continue
+        self.assertAllClose(x[i], x[j])
+        self.assertNotAllClose(y[i], y[j])
+
+
 if __name__ == '__main__':
   test_util.main()
5 changes: 3 additions & 2 deletions tensorflow_probability/python/experimental/mcmc/sharded.py
@@ -78,8 +78,9 @@ def is_calibrated(self):

   @property
   def experimental_shard_axis_names(self):
-    return self.kernel.experimental_shard_axis_names
+    return self.inner_kernel.experimental_shard_axis_names

   def experimental_with_shard_axes(self, shard_axis_names):
     return self.copy(
-        kernel=kernel.experimental_with_shard_axes(shard_axis_names))
+        inner_kernel=self.inner_kernel.experimental_with_shard_axes(
+            shard_axis_names))
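
Note: both accessors previously dereferenced the wrong names — `self.kernel`, which does not exist on this wrapper (AttributeError), and a bare `kernel` (NameError) — so any use of the shard-axis plumbing here would have raised. A hedged sketch of the now-working delegation, assuming the wrapper is constructed as `Sharded(inner_kernel, chain_axis_names)`; the kernel and axis names are illustrative:

import tensorflow_probability as tfp

hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda x: -x**2,
    step_size=0.1,
    num_leapfrog_steps=3)
kernel = tfp.experimental.mcmc.Sharded(hmc, chain_axis_names='chains')
# Delegates to the inner kernel rather than raising:
print(kernel.experimental_shard_axis_names)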