Skip to content

Commit

Permalink
Allow instantiating Gaussian Process/Student T Process Regression Mod…
Browse files Browse the repository at this point in the history
…els and Schur Complement kernel with precomputed Cholesky factors (for use in Vizier).

PiperOrigin-RevId: 507908547
  • Loading branch information
emilyfertig authored and tensorflower-gardener committed Feb 7, 2023
1 parent 20dd41c commit 7a65f0f
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,9 @@ def precompute_regression_model(
always_yield_multivariate_normal=False,
validate_args=False,
allow_nan_stats=False,
name='PrecomputedGaussianProcessRegressionModel'):
name='PrecomputedGaussianProcessRegressionModel',
_precomputed_divisor_matrix_cholesky=None,
_precomputed_solve_on_observation=None):
"""Returns a GaussianProcessRegressionModel with precomputed quantities.
This differs from the constructor by precomputing quantities associated with
Expand Down Expand Up @@ -756,6 +758,8 @@ def precompute_regression_model(
Default value: `False`.
name: Python `str` name prefixed to Ops created by this class.
Default value: 'PrecomputedGaussianProcessRegressionModel'.
_precomputed_divisor_matrix_cholesky: Internal parameter -- do not use.
_precomputed_solve_on_observation: Internal parameter -- do not use.
Returns
An instance of `GaussianProcessRegressionModel` with precomputed
quantities associated with observations.
Expand Down Expand Up @@ -800,23 +804,26 @@ def precompute_regression_model(
fixed_inputs=observation_index_points,
fixed_inputs_is_missing=observations_is_missing,
cholesky_fn=cholesky_fn,
diag_shift=observation_noise_variance)

observation_cholesky_operator = tf.linalg.LinearOperatorLowerTriangular(
conditional_kernel.divisor_matrix_cholesky())
diag_shift=observation_noise_variance,
_precomputed_divisor_matrix_cholesky=(
_precomputed_divisor_matrix_cholesky))

if mean_fn is None:
mean_fn = lambda x: tf.zeros([1], dtype=dtype)
else:
if not callable(mean_fn):
raise ValueError('`mean_fn` must be a Python callable')

diff = observations - mean_fn(observation_index_points)
if observations_is_missing is not None:
diff = tf.where(
observations_is_missing, tf.zeros([], dtype=diff.dtype), diff)
solve_on_observation = observation_cholesky_operator.solvevec(
observation_cholesky_operator.solvevec(diff), adjoint=True)
solve_on_observation = _precomputed_solve_on_observation
if solve_on_observation is None:
observation_cholesky_operator = tf.linalg.LinearOperatorLowerTriangular(
conditional_kernel.divisor_matrix_cholesky())
diff = observations - mean_fn(observation_index_points)
if observations_is_missing is not None:
diff = tf.where(
observations_is_missing, tf.zeros([], dtype=diff.dtype), diff)
solve_on_observation = observation_cholesky_operator.solvevec(
observation_cholesky_operator.solvevec(diff), adjoint=True)

def conditional_mean_fn(x):
k_x_obs = kernel.matrix(x, observation_index_points)
Expand All @@ -841,6 +848,11 @@ def conditional_mean_fn(x):
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
name=name)
# pylint: disable=protected-access
gprm._precomputed_divisor_matrix_cholesky = (
conditional_kernel._precomputed_divisor_matrix_cholesky)
gprm._precomputed_solve_on_observation = solve_on_observation
# pylint: enable=protected-access

return gprm

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from unittest import mock
# Dependency imports
from absl.testing import parameterized
import numpy as np
Expand Down Expand Up @@ -656,6 +657,38 @@ def testStructuredIndexPoints(self):
gprm_with_lists.batch_shape_tensor())
self.assertAllClose(base_gprm.log_prob(s), gprm_with_lists.log_prob(s))

def testPrivateArgPreventsCholeskyRecomputation(self):
x = np.random.uniform(-1, 1, (4, 7)).astype(np.float32)
x_obs = np.random.uniform(-1, 1, (4, 7)).astype(np.float32)
y_obs = np.random.uniform(-1, 1, (4,)).astype(np.float32)
chol = np.eye(4).astype(np.float32)
mock_cholesky_fn = mock.Mock(return_value=chol)
base_kernel = exponentiated_quadratic.ExponentiatedQuadratic()
d = gprm.GaussianProcessRegressionModel.precompute_regression_model(
base_kernel,
index_points=x,
observation_index_points=x_obs,
observations=y_obs,
cholesky_fn=mock_cholesky_fn)
mock_cholesky_fn.assert_called_once()

mock_cholesky_fn.reset_mock()
d2 = gprm.GaussianProcessRegressionModel.precompute_regression_model(
base_kernel,
index_points=x,
observation_index_points=x_obs,
observations=y_obs,
cholesky_fn=mock_cholesky_fn,
_precomputed_divisor_matrix_cholesky=(
d._precomputed_divisor_matrix_cholesky),
_precomputed_solve_on_observation=d._precomputed_solve_on_observation)
mock_cholesky_fn.assert_not_called()

# The Cholesky is computed just once in each call to log_prob (on the
# index points kernel matrix).
self.assertAllClose(d.log_prob(y_obs), d2.log_prob(y_obs))
self.assertEqual(mock_cholesky_fn.call_count, 2)


class GaussianProcessRegressionModelStaticTest(
_GaussianProcessRegressionModelTest):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,9 @@ def precompute_regression_model(
cholesky_fn=None,
validate_args=False,
allow_nan_stats=False,
name='PrecomputedStudentTProcessRegressionModel'):
name='PrecomputedStudentTProcessRegressionModel',
_precomputed_divisor_matrix_cholesky=None,
_precomputed_solve_on_observation=None):
"""Returns a StudentTProcessRegressionModel with precomputed quantities.
This differs from the constructor by precomputing quantities associated with
Expand Down Expand Up @@ -564,6 +566,8 @@ def precompute_regression_model(
Default value: `False`.
name: Python `str` name prefixed to Ops created by this class.
Default value: 'PrecomputedStudentTProcessRegressionModel'.
_precomputed_divisor_matrix_cholesky: Internal parameter -- do not use.
_precomputed_solve_on_observation: Internal parameter -- do not use.
Returns
An instance of `StudentTProcessRegressionModel` with precomputed
quantities associated with observations.
Expand All @@ -584,33 +588,17 @@ def precompute_regression_model(
observation_noise_variance, dtype=dtype)
observations = tf.convert_to_tensor(observations, dtype=dtype)

observation_cholesky = kernel.matrix(
observation_index_points, observation_index_points)

broadcast_shape = distribution_util.get_broadcast_shape(
observation_cholesky,
observation_noise_variance[..., tf.newaxis, tf.newaxis])

observation_cholesky = tf.broadcast_to(
observation_cholesky, broadcast_shape)

observation_cholesky = tf.linalg.set_diag(
observation_cholesky,
tf.linalg.diag_part(observation_cholesky) +
observation_noise_variance[..., tf.newaxis])
if cholesky_fn is None:
cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()

observation_cholesky = cholesky_fn(observation_cholesky)
observation_cholesky_operator = tf.linalg.LinearOperatorLowerTriangular(
observation_cholesky)

conditional_kernel = DampedSchurComplement(
df=df,
schur_complement=schur_complement_lib.SchurComplement(
schur_complement=schur_complement_lib.SchurComplement.with_precomputed_divisor(
base_kernel=kernel,
fixed_inputs=observation_index_points,
diag_shift=observation_noise_variance),
diag_shift=observation_noise_variance,
_precomputed_divisor_matrix_cholesky=(
_precomputed_divisor_matrix_cholesky)),
fixed_inputs_observations=observations,
validate_args=validate_args)

Expand All @@ -620,9 +608,29 @@ def precompute_regression_model(
if not callable(mean_fn):
raise ValueError('`mean_fn` must be a Python callable')

diff = observations - mean_fn(observation_index_points)
solve_on_observation = observation_cholesky_operator.solvevec(
observation_cholesky_operator.solvevec(diff), adjoint=True)
solve_on_observation = _precomputed_solve_on_observation
if solve_on_observation is None:
observation_cholesky = kernel.matrix(
observation_index_points, observation_index_points)

broadcast_shape = distribution_util.get_broadcast_shape(
observation_cholesky,
observation_noise_variance[..., tf.newaxis, tf.newaxis])

observation_cholesky = tf.broadcast_to(
observation_cholesky, broadcast_shape)

observation_cholesky = tf.linalg.set_diag(
observation_cholesky,
tf.linalg.diag_part(observation_cholesky) +
observation_noise_variance[..., tf.newaxis])
observation_cholesky = cholesky_fn(observation_cholesky)
observation_cholesky_operator = tf.linalg.LinearOperatorLowerTriangular(
observation_cholesky)

diff = observations - mean_fn(observation_index_points)
solve_on_observation = observation_cholesky_operator.solvevec(
observation_cholesky_operator.solvevec(diff), adjoint=True)

def conditional_mean_fn(x):
k_x_obs = kernel.matrix(x, observation_index_points)
Expand All @@ -643,6 +651,13 @@ def conditional_mean_fn(x):
allow_nan_stats=allow_nan_stats,
name=name)

# pylint: disable=protected-access
stprm._precomputed_divisor_matrix_cholesky = (
conditional_kernel.schur_complement
._precomputed_divisor_matrix_cholesky)
stprm._precomputed_solve_on_observation = solve_on_observation
# pylint: enable=protected-access

return stprm

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from unittest import mock
# Dependency imports
import numpy as np
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -383,6 +384,41 @@ def cholesky_fn(x):
# self.assertAllClose(actual, call_log_prob(dist))
# self.assertAllClose(actual, call_log_prob(unflat))

def testPrivateArgPreventsCholeskyRecomputation(self):
df = np.float32(5.)
x = np.random.uniform(-1, 1, (4, 7)).astype(np.float32)
x_obs = np.random.uniform(-1, 1, (4, 7)).astype(np.float32)
y_obs = np.random.uniform(-1, 1, (4,)).astype(np.float32)
chol = np.eye(4).astype(np.float32)
mock_cholesky_fn = mock.Mock(return_value=chol)
base_kernel = psd_kernels.ExponentiatedQuadratic()
d = stprm.StudentTProcessRegressionModel.precompute_regression_model(
df,
base_kernel,
index_points=x,
observation_index_points=x_obs,
observations=y_obs,
cholesky_fn=mock_cholesky_fn)
mock_cholesky_fn.assert_called_once()

mock_cholesky_fn.reset_mock()
d2 = stprm.StudentTProcessRegressionModel.precompute_regression_model(
df,
base_kernel,
index_points=x,
observation_index_points=x_obs,
observations=y_obs,
cholesky_fn=mock_cholesky_fn,
_precomputed_divisor_matrix_cholesky=(
d._precomputed_divisor_matrix_cholesky),
_precomputed_solve_on_observation=d._precomputed_solve_on_observation)
mock_cholesky_fn.assert_not_called()

# The Cholesky is computed just once in each call to log_prob (on the
# index points kernel matrix).
self.assertAllClose(d.log_prob(y_obs), d2.log_prob(y_obs))
self.assertEqual(mock_cholesky_fn.call_count, 2)


if __name__ == '__main__':
test_util.main()
22 changes: 13 additions & 9 deletions tensorflow_probability/python/math/psd_kernels/schur_complement.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _compute_divisor_matrix(
base_kernel,
diag_shift,
fixed_inputs):
"""Compute the the modified kernel with respect to the fixed inputs."""
"""Compute the modified kernel with respect to the fixed inputs."""
divisor_matrix = base_kernel.matrix(fixed_inputs, fixed_inputs)
if diag_shift is not None:
diag_shift = tf.convert_to_tensor(diag_shift)
Expand Down Expand Up @@ -287,7 +287,8 @@ def with_precomputed_divisor(
diag_shift=None,
cholesky_fn=None,
validate_args=False,
name='PrecomputedSchurComplement'):
name='PrecomputedSchurComplement',
_precomputed_divisor_matrix_cholesky=None):
"""Returns a `SchurComplement` with a precomputed divisor matrix.
This method is the same as creating a `SchurComplement` kernel, but assumes
Expand Down Expand Up @@ -337,6 +338,7 @@ def with_precomputed_divisor(
Default value: `False`
name: Python `str` name prefixed to Ops created by this class.
Default value: `"PrecomputedSchurComplement"`
_precomputed_divisor_matrix_cholesky: Internal arg -- do not use.
"""
if tf.nest.is_nested(base_kernel.feature_ndims):
dtype = dtype_util.common_dtype(
Expand Down Expand Up @@ -364,13 +366,15 @@ def with_precomputed_divisor(
from tensorflow_probability.python.distributions import cholesky_util # pylint:disable=g-import-not-at-top
cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn()

# TODO(b/196219597): Add a check to ensure that we have a `base_kernel`
# that is explicitly concretized.
divisor_matrix_cholesky = cholesky_fn(util.mask_matrix(
_compute_divisor_matrix(base_kernel,
diag_shift=diag_shift,
fixed_inputs=fixed_inputs),
is_missing=fixed_inputs_is_missing))
divisor_matrix_cholesky = _precomputed_divisor_matrix_cholesky
if divisor_matrix_cholesky is None:
# TODO(b/196219597): Add a check to ensure that we have a `base_kernel`
# that is explicitly concretized.
divisor_matrix_cholesky = cholesky_fn(util.mask_matrix(
_compute_divisor_matrix(base_kernel,
diag_shift=diag_shift,
fixed_inputs=fixed_inputs),
is_missing=fixed_inputs_is_missing))

schur_complement = SchurComplement(
base_kernel=base_kernel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import functools
import itertools
from unittest import mock

from absl.testing import parameterized
import numpy as np
Expand Down Expand Up @@ -407,6 +408,27 @@ def testStructuredBaseKernel(self):
self.assertAllClose(masked_schur.apply(x, y),
masked_structured_schur.apply(struct_x, struct_y))

def testPrivateArgPreventsCholeskyRecomputation(self):
x = np.random.uniform(-1, 1, (4, 7)).astype(np.float32)
chol = np.eye(4).astype(np.float32)
mock_cholesky_fn = mock.Mock(return_value=chol)
base_kernel = exponentiated_quadratic.ExponentiatedQuadratic()
k = schur_complement.SchurComplement.with_precomputed_divisor(
base_kernel, x, cholesky_fn=mock_cholesky_fn)
mock_cholesky_fn.assert_called_once()

# Assert that the Cholesky is not recomputed when the kernel is rebuilt and
# its methods are called.
mock_cholesky_fn.reset_mock()
k2 = schur_complement.SchurComplement.with_precomputed_divisor(
base_kernel, x, cholesky_fn=mock_cholesky_fn,
_precomputed_divisor_matrix_cholesky=(
k._precomputed_divisor_matrix_cholesky))
y = np.random.uniform(-1, 1, size=(3, 7))
z = np.random.uniform(-1, 1, size=(2, 7))
self.assertAllClose(k.matrix(y, z), k2.matrix(y, z))
mock_cholesky_fn.assert_not_called()


if __name__ == '__main__':
test_util.main()

0 comments on commit 7a65f0f

Please sign in to comment.