Skip to content

Commit

Permalink
Ensure JDPinned produces types consistent with underlying JD
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 339919800
  • Loading branch information
csuter authored and tensorflower-gardener committed Oct 30, 2020
1 parent e5ee9e6 commit 6baf348
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -355,16 +355,17 @@ def _model_flatten(self, x):

def _model_unflatten(self, xs):
"""Unflattens `xs` to a structure like-typed to `self.distribution`."""
model = self.distribution.dtype
if isinstance(model, dict):
# Use the underlying JD dtype to infer model structure.
dtype = self.distribution.dtype
if isinstance(dtype, dict):
ks = self._flat_resolve_names()
if len(ks) != len(xs):
raise ValueError('Invalid xs length {}, ks={}'.format(len(xs), ks))
return type(model)(zip(ks, xs))
if hasattr(model, '_fields') and hasattr(model, '_asdict'):
ks = [k for k in model._fields if k not in self.pins]
return type(dtype)(zip(ks, xs))
if hasattr(dtype, '_fields') and hasattr(dtype, '_asdict'):
ks = [k for k in dtype._fields if k not in self.pins]
return structural_tuple.structtuple(ks)(*xs)
return tuple(xs)
return type(dtype)(xs)

def _prune(self, xs, retain=None):
"""Drops fields from `xs`, retaining those specified by `retain`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ def test_bijector(self):
lambda a: tfd.Uniform(a + tf.ones_like(a), a + tf.constant(2, a.dtype)),
lambda b, a: tfd.Uniform(a, b, name='c')])
bij = jd.experimental_default_event_space_bijector(a=-.5, b=1.)
self.assertAllClose((2/3,), tf.math.sigmoid(bij.inverse((0.5,))))
test_input = (0.5,)
self.assertIs(type(jd.dtype), type(bij.inverse(test_input)))
self.assertAllClose((2/3,), tf.math.sigmoid(bij.inverse(test_input)))

@tfd.JointDistributionCoroutine
def model():
Expand Down

0 comments on commit 6baf348

Please sign in to comment.