Skip to content

Commit

Permalink
Add mean, mode, covariance, variance, stddev methods to MVN precision…
Browse files Browse the repository at this point in the history
… factor linear operator distribution.

This fills out the API, and can help confirm that the distribution is written as intended if, for example, a covariance is known.

PiperOrigin-RevId: 355852942
  • Loading branch information
ColCarroll authored and tensorflower-gardener committed Feb 5, 2021
1 parent 239de96 commit 54bb464
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util


__all__ = ['MultivariateNormalPrecisionFactorLinearOperator']
Expand Down Expand Up @@ -230,6 +231,46 @@ def precision_factor(self):
def precision(self):
return self._precision

def _mean(self):
  """Returns the mean, broadcast to batch_shape + event_shape.

  If `loc` was not supplied, the mean is a zero tensor of the full
  batch + event shape.
  """
  static_shape = tensorshape_util.concatenate(
      self.batch_shape, self.event_shape)
  if tensorshape_util.is_fully_defined(static_shape):
    target_shape = static_shape
  else:
    # Fall back to runtime shape tensors when any dim is unknown statically.
    target_shape = tf.concat(
        [self.batch_shape_tensor(), self.event_shape_tensor()], axis=0)

  if self.loc is None:
    return tf.zeros(target_shape, self.dtype)
  return tf.broadcast_to(self.loc, target_shape)

def _covariance(self):
  """Returns the dense covariance matrix, `inv(precision)`."""
  if self._precision is not None:
    return self._precision.inverse().to_dense()
  # cov = inv(F F^H) = inv(F)^H inv(F), with F the precision factor.
  inv_factor = self._precision_factor.inverse()
  return inv_factor.matmul(inv_factor, adjoint=True).to_dense()

def _variance(self):
  """Returns per-dimension variance, `diag_part(inv(precision))`.

  The result is broadcast against the shape of `loc` (when supplied),
  since `loc` may carry extra batch dimensions.
  """
  if self._precision is None:
    # precision = F F^H, with F the precision factor.
    precision = self._precision_factor.matmul(
        self._precision_factor, adjoint_arg=True)
  else:
    precision = self._precision
  variance = precision.inverse().diag_part()
  if self.loc is None:
    # Bug fix: `_mean` supports loc=None, but `ps.shape(self.loc)` below
    # would raise in that case. Without a loc there is nothing extra to
    # broadcast against beyond the operator's own batch/event shape.
    return variance
  return tf.broadcast_to(
      variance,
      ps.broadcast_shape(ps.shape(variance),
                         ps.shape(self.loc)))

def _stddev(self):
  """Returns elementwise standard deviation: sqrt of `_variance()`."""
  return tf.sqrt(self._variance())

def _mode(self):
  """Returns the mode; for a multivariate normal it coincides with the mean."""
  return self._mean()

def _log_prob_unnormalized(self, value):
"""Unnormalized log probability.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,61 @@ def test_dynamic_shape(self):
self.assertAllClose(self.evaluate(dynamic_dist.log_prob(in_)),
static_dist.log_prob(in_))

@test_combinations.generate(
    test_combinations.combine(
        batch_shape=[(), (2,)],
        dtype=[np.float32, np.float64],
    ),
)
def test_mean_and_mode(self, batch_shape, dtype):
  """Checks that mean() and mode() both reproduce the supplied loc."""
  dims = 3
  covariance = self._random_constant_spd_linop(
      dims, batch_shape=batch_shape, dtype=dtype)
  factor = covariance.inverse().cholesky()

  # Evaluate eagerly so the same loc is used throughout; a lazy tensor
  # would re-draw a fresh random vector on each access.
  loc = self.evaluate(
      tf.random.normal(
          batch_shape + (dims,),
          dtype=dtype,
          seed=test_util.test_seed()))

  dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
      loc=loc,
      precision_factor=factor)
  self.assertAllClose(dist.mean(), loc)
  self.assertAllClose(dist.mode(), loc)

@test_combinations.generate(
    test_combinations.combine(
        batch_shape=[(), (2,)],
        use_precision=[True, False],
        dtype=[np.float32, np.float64],
    ),
)
def test_cov_var_stddev(self, batch_shape, use_precision, dtype):
  """Checks covariance/variance/stddev against the known covariance."""
  dims = 3
  covariance = self._random_constant_spd_linop(
      dims, batch_shape=batch_shape, dtype=dtype)
  prec = covariance.inverse()
  factor = prec.cholesky()

  # Evaluate eagerly so the same loc is used throughout; a lazy tensor
  # would re-draw a fresh random vector on each access.
  loc = self.evaluate(
      tf.random.normal(
          batch_shape + (dims,),
          dtype=dtype,
          seed=test_util.test_seed()))

  dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
      loc=loc,
      precision_factor=factor,
      precision=prec if use_precision else None)
  expected_diag = covariance.diag_part()
  self.assertAllClose(dist.covariance(), covariance.to_dense(), atol=1e-4)
  self.assertAllClose(dist.variance(), expected_diag, atol=1e-4)
  self.assertAllClose(dist.stddev(), tf.sqrt(expected_diag), atol=1e-5)


if __name__ == '__main__':
  # Run the TensorFlow test runner when this file is invoked as a script.
  tf.test.main()

0 comments on commit 54bb464

Please sign in to comment.