Skip to content

Commit

Permalink
Restrain Hypothesis from generating absurd limit arguments to TruncatedNormal.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 411620996
  • Loading branch information
axch authored and tensorflower-gardener committed Nov 22, 2021
1 parent 8ac3ac1 commit b305c9c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
14 changes: 11 additions & 3 deletions tensorflow_probability/python/distributions/hypothesis_testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def fix_triangular(d):
return dict(d, peak=peak, high=high)


def fix_truncated_normal(params):
  """Returns a copy of `params` with usable bounds for TruncatedNormal.

  Ensures `high > low` (broadcastably), and additionally keeps both bounds
  within five scales of the location, so Hypothesis does not generate
  truncation limits absurdly far into the tails of the base Normal.

  Args:
    params: Dict of constructor kwargs containing 'loc', 'scale', 'low',
      and 'high'.

  Returns:
    A new dict with adjusted 'low' and 'high'; other entries are unchanged.
  """
  fixed = dict(params)
  # First make sure the interval is non-degenerate: high strictly above low.
  fixed['high'] = tfp_hps.ensure_high_gt_low(params['low'], params['high'])
  # Then clamp the bounds toward the bulk of the distribution:
  # high must exceed loc - 5*scale, and low must fall below loc + 5*scale.
  fixed['high'] = tfp_hps.ensure_high_gt_low(
      params['loc'] - 5 * params['scale'], fixed['high'])
  fixed['low'] = tfp_hps.ensure_low_lt_high(
      fixed['low'], params['loc'] + 5 * params['scale'])
  return fixed


def fix_wishart(d):
df = d['df']
scale = d.get('scale', d.get('scale_tril'))
Expand Down Expand Up @@ -341,9 +351,7 @@ def fix_bates(d):
'TruncatedCauchy':
lambda d: dict(d, high=tfp_hps.ensure_high_gt_low( # pylint:disable=g-long-lambda
d['low'], d['high'])),
'TruncatedNormal':
lambda d: dict(d, high=tfp_hps.ensure_high_gt_low( # pylint:disable=g-long-lambda
d['low'], d['high'])),
'TruncatedNormal': fix_truncated_normal,
'Uniform':
lambda d: dict(d, high=tfp_hps.ensure_high_gt_low( # pylint:disable=g-long-lambda
d['low'], d['high'])),
Expand Down
21 changes: 21 additions & 0 deletions tensorflow_probability/python/internal/hypothesis_testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,27 @@ def ensure_high_gt_low(low, high):
return new_high


def ensure_low_lt_high(low, high):
  """Returns a value shaped like `low` that is broadcastably less than `high`.

  Computes a candidate lower bound strictly below `high` (by a relative and
  absolute margin), takes the elementwise minimum with `low`, then reduces
  the result back down so its shape matches `low`'s (collapsing any extra
  leading axes and any axes where `low` has size 1 but the candidate does
  not, via `reduce_min` so the "less than `high`" property is preserved
  under broadcasting).

  Args:
    low: Tensor of proposed lower bounds.
    high: Tensor of upper bounds, broadcast-compatible with `low`.

  Returns:
    Tensor with the same shape as `low`, elementwise <= `low`, and
    broadcastably less than `high`.
  """
  # Candidate: strictly below high by 10% of |high| plus an absolute 0.1
  # margin (the absolute term matters when high is near zero).
  new_low = tf.minimum(high - tf.abs(high) * .1 - .1, low)
  # Collapse any extra leading axes picked up from broadcasting with `high`.
  if (tensorshape_util.rank(new_low.shape) >
      tensorshape_util.rank(low.shape)):
    reduced_leading_axes = tf.range(
        tensorshape_util.rank(new_low.shape) -
        tensorshape_util.rank(low.shape))
    new_low = tf.math.reduce_min(
        new_low, axis=reduced_leading_axes)
  # Collapse (with keepdims) any axes where `low` broadcast from size 1.
  # NOTE(review): assumes static shapes here — shape[d] comparisons require
  # known dimensions; confirm callers only pass statically-shaped tensors.
  reduce_dims = [
      d for d in range(tensorshape_util.rank(low.shape))
      if low.shape[d] < new_low.shape[d]
  ]
  if reduce_dims:
    new_low = tf.math.reduce_min(
        new_low, axis=reduce_dims, keepdims=True)
  return new_low


def softplus_plus_eps(eps=1e-6):
  """Returns a callable mapping `x` to `softplus(x) + eps`.

  Useful as a strictly-positive constraint: softplus alone can underflow to
  exactly zero, so a small `eps` is added to keep results positive.

  Args:
    eps: Positive offset added after the softplus. Default 1e-6.

  Returns:
    A function of one tensor argument.
  """
  def _shifted_softplus(x):
    return tf.nn.softplus(x) + eps
  return _shifted_softplus

Expand Down

0 comments on commit b305c9c

Please sign in to comment.