Skip to content

Commit

Permalink
Fix a bug in which JointDistributionCoroutine wouldn't always respect…
Browse files Browse the repository at this point in the history
… its sample_dtype.

PiperOrigin-RevId: 341542985
  • Loading branch information
davmre authored and tensorflower-gardener committed Nov 10, 2020
1 parent de48765 commit aede47b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ def __init__(
"""
parameters = dict(locals())
with tf.name_scope(name or 'JointDistributionCoroutine') as name:
self._sample_dtype = sample_dtype
self._model_coroutine = model
# Hint `no_dependency` to tell tf.Module not to screw up the sample dtype
# with extraneous wrapping (list => ListWrapper, etc.).
self._sample_dtype = self._no_dependency(sample_dtype)
self._single_sample_distributions = {}
super(JointDistributionCoroutine, self).__init__(
dtype=sample_dtype,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@ def noncentered_horseshoe_prior(num_features):
tfd.Sample(tfd.Normal(0., 1.), num_features))
yield tfd.Independent(tfd.Deterministic(weights_noncentered * scale),
reinterpreted_batch_ndims=1)

# Currently sample_dtype is only used for `tf.nest.pack_structure_as`. In
# the future we may use it for error checking and/or casting.
sample_dtype = collections.namedtuple('Model', [
Expand All @@ -645,6 +646,18 @@ def noncentered_horseshoe_prior(num_features):
self.assertEqual([3, 4], joint.log_prob(
joint.sample([3, 4], seed=test_util.test_seed())).shape)

# Check that a list dtype doesn't get corrupted by `tf.Module` wrapping.
sample_dtype = [None, None, None, None]
joint = tfd.JointDistributionCoroutine(
lambda: noncentered_horseshoe_prior(4),
sample_dtype=sample_dtype,
validate_args=True)
ds, xs = joint.sample_distributions([2, 3], seed=test_util.test_seed())
self.assertEqual(type(sample_dtype), type(xs))
self.assertEqual(type(sample_dtype), type(ds))
tf.nest.assert_same_structure(sample_dtype, ds)
tf.nest.assert_same_structure(sample_dtype, xs)

def test_repr_with_custom_sample_dtype(self):
def model():
s = yield tfd.JointDistributionCoroutine.Root(
Expand Down

0 comments on commit aede47b

Please sign in to comment.