Fix distributions with broken slicing caused by not overriding parameters, as identified by hypothesis testing.

PiperOrigin-RevId: 237472521
brianwa84 authored and tensorflower-gardener committed Mar 8, 2019
1 parent a3d2fec commit 051dbaf
Showing 6 changed files with 15 additions and 6 deletions.
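
For readers skimming the diffs below, here is a minimal sketch of the pattern this commit applies in each file: the distribution's `__init__` snapshots its own arguments with `dict(locals())` and forwards them to the base class, so copy/slice operations can later re-instantiate the subclass from its original constructor arguments instead of from the base class's locals. The class and argument names here are hypothetical illustrations, not actual TFP internals; only the `parameters = dict(locals())` idiom mirrors the diff.

```python
# Hypothetical, simplified classes; only the `dict(locals())` capture
# mirrors the actual change below.
class BaseDistribution(object):

  def __init__(self, parameters=None, name=None):
    # Stored so later copy/slice operations can rebuild the subclass
    # from its original constructor arguments.
    self._parameters = {} if parameters is None else parameters
    self._name = name


class MyChiLike(BaseDistribution):

  def __init__(self, df, validate_args=False, name='MyChiLike'):
    parameters = dict(locals())  # capture the *subclass* constructor args
    super(MyChiLike, self).__init__(parameters=parameters, name=name)
    self._df = df
```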
4 changes: 3 additions & 1 deletion tensorflow_probability/python/distributions/chi.py
@@ -78,6 +78,7 @@ def __init__(self,
name: Python `str` name prefixed to Ops created by this class.
Default value: `'Chi'`.
"""
parameters = dict(locals())
with tf.compat.v1.name_scope(name, values=[df]) as name:
df = tf.convert_to_tensor(
value=df,
@@ -93,7 +94,8 @@ def __init__(self,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name),
bijector=invert_bijector.Invert(square_bijector.Square()))
bijector=invert_bijector.Invert(square_bijector.Square()),
parameters=parameters)

def _params_event_ndims(self):
return dict(df=0)
2 changes: 2 additions & 0 deletions tensorflow_probability/python/distributions/gumbel.py
@@ -127,6 +127,7 @@ def __init__(self,
Raises:
TypeError: if loc and scale are different dtypes.
"""
parameters = dict(locals())
with tf.compat.v1.name_scope(name, values=[loc, scale]) as name:
dtype = dtype_util.common_dtype([loc, scale], preferred_dtype=tf.float32)
loc = tf.convert_to_tensor(value=loc, name="loc", dtype=dtype)
@@ -153,6 +154,7 @@ def __init__(self,
# be inverted.
bijector=invert_bijector.Invert(self._gumbel_bijector),
batch_shape=distribution_util.get_broadcast_shape(loc, scale),
parameters=parameters,
name=name)

@staticmethod
4 changes: 2 additions & 2 deletions tensorflow_probability/python/distributions/inverse_gamma.py
@@ -162,8 +162,8 @@ def __init__(self,
with tf.control_dependencies([
tf.compat.v1.assert_positive(
concentration, message="Concentration must be positive."),
tf.compat.v1
.assert_positive(scale, message="Scale must be positive."),
tf.compat.v1.assert_positive(
scale, message="Scale must be positive."),
] if validate_args else []):
self._concentration = tf.identity(concentration, name="concentration")
self._scale = tf.identity(scale, name="scale")
4 changes: 2 additions & 2 deletions tensorflow_probability/python/distributions/kumaraswamy.py
@@ -27,7 +27,6 @@
from tensorflow_probability.python.distributions import uniform
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import reparameterization

__all__ = [
"Kumaraswamy",
@@ -146,6 +145,7 @@ def __init__(self,
more of the statistic's batch members are undefined.
name: Python `str` name prefixed to Ops created by this class.
"""
parameters = dict(locals())
with tf.compat.v1.name_scope(
name, values=[concentration1, concentration0]) as name:
dtype = dtype_util.common_dtype([concentration1, concentration0],
@@ -165,8 +165,8 @@
validate_args=validate_args),
batch_shape=distribution_util.get_broadcast_shape(
concentration1, concentration0),
parameters=parameters,
name=name)
self._reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

def _params_event_ndims(self):
return dict(concentration1=0, concentration0=0)
2 changes: 2 additions & 0 deletions tensorflow_probability/python/distributions/lognormal.py
@@ -61,6 +61,7 @@ def __init__(self,
undefined statistics will return NaN for this statistic.
name: The name to give Ops created by the initializer.
"""
parameters = dict(locals())
with tf.compat.v1.name_scope(name, values=[loc, scale]) as name:
dtype = dtype_util.common_dtype([loc, scale], tf.float32)
super(LogNormal, self).__init__(
@@ -70,6 +71,7 @@
value=scale, name="scale", dtype=dtype)),
bijector=exp_bijector.Exp(),
validate_args=validate_args,
parameters=parameters,
name=name)

def _params_event_ndims(self):
5 changes: 4 additions & 1 deletion tensorflow_probability/python/distributions/transformed_distribution.py
@@ -227,6 +227,7 @@ def __init__(self,
batch_shape=None,
event_shape=None,
validate_args=False,
parameters=None,
name=None):
"""Construct a Transformed Distribution.
@@ -243,10 +244,12 @@
parameters are checked for validity despite possibly degrading runtime
performance. When `False` invalid inputs may silently render incorrect
outputs.
parameters: Locals dict captured by subclass constructor, to be used for
copy/slice re-instantiation operations.
name: Python `str` name prefixed to Ops created by this class. Default:
`bijector.name + distribution.name`.
"""
parameters = dict(locals())
parameters = dict(locals()) if parameters is None else parameters
name = name or (("" if bijector is None else bijector.name) +
distribution.name)
with tf.compat.v1.name_scope(
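
The fallback on the last changed line is the crux of the fix: when a subclass does not pass `parameters`, `TransformedDistribution` falls back to its own `locals()`, which describe the wrapped `distribution` and `bijector` rather than the subclass's arguments, and re-instantiating the subclass from those during batch slicing is what broke. A hedged usage sketch of the behavior the commit restores, assuming TFP's `dist[...]` batch slicing and the `tfp.distributions.LogNormal` constructor:

```python
# Hypothetical usage; assumes batch slicing via __getitem__, which this
# commit fixes for the distributions changed above.
import tensorflow_probability as tfp

dist = tfp.distributions.LogNormal(loc=[[0.], [1.]], scale=[0.5, 2.0])
print(dist.batch_shape)    # batch shape [2, 2] after broadcasting loc against scale
sliced = dist[0]           # rebuilt from the captured constructor parameters
print(sliced.batch_shape)  # batch shape [2]
```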
