Skip to content

Commit

Permalink
Avoid feeding Nones to tf.control_dependencies in mixed eager/graph…
Browse files Browse the repository at this point in the history
… contexts.

Add a property-based test that tries to create Distributions in eager mode then
sample from them in graph mode, to exercise the failure mode.

PiperOrigin-RevId: 347692243
  • Loading branch information
csuter authored and tensorflower-gardener committed Dec 15, 2020
1 parent 5ba9d60 commit 0963d79
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow_probability/python/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,13 @@ def _name_and_control_scope(self, name=None, value=UNSET_VALUE, kwargs=None):
if not deps:
yield name_scope
return
# In eager mode, some `assert_util.assert_xyz` calls return None. If a
# Distribution is created in eager mode with `validate_args=True`, then
# used in a `tf.function` context, it can result in errors when
# `tf.convert_to_tensor` is called on the inputs to
# `tf.control_dependencies` below. To avoid these errors, we drop the
# `None`s here.
deps = [x for x in deps if x is not None]
with tf.control_dependencies(deps) as deps_scope:
yield deps_scope

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,29 @@ def disabled_testFailureCase(self): # pylint: disable=invalid-name
self.assertAllClose(dist.log_prob(samps)[0], dist[0].log_prob(samps[0]))


# Don't decorate with test_util.test_all_tf_execution_regimes, since we're
# explicitly mixing modes.
class TestMixingGraphAndEagerModes(test_util.TestCase):

@parameterized.named_parameters(
{'testcase_name': dname, 'dist_name': dname}
for dname in sorted(list(dhps.INSTANTIABLE_BASE_DISTS.keys()) +
list(dhps.INSTANTIABLE_META_DISTS))
)
@hp.given(hps.data())
@tfp_hps.tfp_hp_settings()
def testSampleEagerCreatedDistributionInGraphMode(self, dist_name, data):
if not tf.executing_eagerly():
self.skipTest('Only test mixed eager/graph behavior in eager tests.')
# Create in eager mode.
dist = data.draw(dhps.distributions(dist_name=dist_name, enable_vars=False))

@tf.function
def f():
dist.sample()
f()


if __name__ == '__main__':
# Hypothesis often finds numerical near misses. Debugging them is much aided
# by seeing all the digits of every floating point number, instead of the
Expand Down

0 comments on commit 0963d79

Please sign in to comment.