Commit e36ae05

Adds a pivoted Cholesky decomposition to tfp.math. This can be useful as a preconditioner for conjugate gradient solves.

PiperOrigin-RevId: 245225276
brianwa84 authored and tensorflower-gardener committed Apr 25, 2019
1 parent a10b9a6 commit e36ae05
Showing 4 changed files with 320 additions and 1 deletion.
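Before the per-file diffs, here is a minimal, hypothetical usage sketch of the new API (not part of the commit): it builds the low-rank factor with `tfp.math.pivoted_cholesky` and wraps it in `tf.linalg.LinearOperatorLowRankUpdate`, so solves against the preconditioner are cheap via the Woodbury identity, as the new docstring describes. The kernel construction, `dim`, `shift`, and rank below are illustrative assumptions.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

dim, shift = 100, 1.
w = np.random.RandomState(42).randn(dim, 5).astype(np.float32)
# An SPD matrix with approximately low-rank structure plus jitter.
kernel = w.dot(w.T) + shift * np.eye(dim, dtype=np.float32)

# lr has shape [dim, 10] and lr @ lr.T approximates `kernel`.
lr = tfp.math.pivoted_cholesky(kernel, max_rank=10)

# Preconditioner M = shift * I + lr @ lr.T; M^{-1} @ x is cheap because
# LinearOperatorLowRankUpdate applies the Woodbury identity internally.
precond = tf.linalg.LinearOperatorLowRankUpdate(
    tf.linalg.LinearOperatorScaledIdentity(dim, shift),
    u=lr, is_self_adjoint=True, is_positive_definite=True)
approx_solve = precond.solve(tf.ones([dim, 1]))  # Usable inside CG on `kernel`.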
4 changes: 3 additions & 1 deletion tensorflow_probability/python/math/BUILD
@@ -120,17 +120,19 @@ py_library(
    ],
    srcs_version = "PY2AND3",
    deps = [
        # absl/testing:parameterized dep,
        # numpy dep,
        # tensorflow dep,
        "//tensorflow_probability/python/internal:dtype_util",
        "//tensorflow_probability/python/internal:tensorshape_util",
    ],
)

py_test(
    name = "linalg_test",
    size = "small",
    srcs = ["linalg_test.py"],
    shard_count = 3,
    shard_count = 5,
    deps = [
        # hypothesis dep,
        # numpy dep,
2 changes: 2 additions & 0 deletions tensorflow_probability/python/math/__init__.py
@@ -30,6 +30,7 @@
from tensorflow_probability.python.math.linalg import lu_solve
from tensorflow_probability.python.math.linalg import matrix_rank
from tensorflow_probability.python.math.linalg import pinv
from tensorflow_probability.python.math.linalg import pivoted_cholesky
from tensorflow_probability.python.math.linalg import sparse_or_dense_matmul
from tensorflow_probability.python.math.linalg import sparse_or_dense_matvecmul
from tensorflow_probability.python.math.numeric import clip_by_value_preserve_gradient
@@ -57,6 +58,7 @@
    'lu_solve',
    'matrix_rank',
    'pinv',
    'pivoted_cholesky',
    'random_rademacher',
    'random_rayleigh',
    'secant_root',
183 changes: 183 additions & 0 deletions tensorflow_probability/python/math/linalg.py
@@ -28,6 +28,7 @@

from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow.python.ops.linalg import linear_operator_util


@@ -38,6 +39,7 @@
    'lu_solve',
    'matrix_rank',
    'pinv',
    'pivoted_cholesky',
    'sparse_or_dense_matmul',
    'sparse_or_dense_matvecmul',
]
@@ -137,6 +139,187 @@ def cholesky_concat_slow(chol, cols):  # cols shaped (n + m) x m = z x m
      axis=-2)


def _swap_m_with_i(vecs, m, i):
  """Swaps `m` and `i` on axis -1. (Helper for pivoted_cholesky.)

  Given a batch of int64 vectors `vecs`, scalar index `m`, and compatibly shaped
  per-vector indices `i`, this function swaps elements `m` and `i` in each
  vector. For the use-case below, these are permutation vectors.

  Args:
    vecs: Vectors on which we perform the swap, int64 `Tensor`.
    m: Scalar int64 `Tensor`, the index into which the `i`th element is going.
    i: Batch int64 `Tensor`, shaped like vecs.shape[:-1] + [1]; the index into
      which the `m`th element is going.

  Returns:
    vecs: The updated vectors.
  """
  vecs = tf.convert_to_tensor(value=vecs, dtype=tf.int64, name='vecs')
  m = tf.convert_to_tensor(value=m, dtype=tf.int64, name='m')
  i = tf.convert_to_tensor(value=i, dtype=tf.int64, name='i')
  trailing_elts = tf.broadcast_to(
      tf.range(m + 1,
               prefer_static.shape(vecs, out_type=tf.int64)[-1]),
      prefer_static.shape(vecs[..., m + 1:]))
  shp = prefer_static.shape(trailing_elts)
  trailing_elts = tf.where(
      tf.equal(trailing_elts, tf.broadcast_to(i, shp)),
      tf.broadcast_to(tf.gather(vecs, [m], axis=-1), shp),
      tf.broadcast_to(vecs[..., m + 1:], shp))
  # TODO(bjp): Could we use tensor_scatter_nd_update?
  vecs_shape = vecs.shape
  vecs = tf.concat([
      vecs[..., :m],
      tf.gather(vecs, i, batch_dims=prefer_static.rank(vecs) - 1),
      trailing_elts
  ], axis=-1)
  tensorshape_util.set_shape(vecs, vecs_shape)
  return vecs


def _invert_permutation(perm):  # TODO(b/130217510): Remove this function.
  return tf.cast(
      tf.math.top_k(perm, k=prefer_static.shape(perm)[-1],
                    sorted=True).indices[..., ::-1], perm.dtype)


def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
  """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied Numerical Mathematics_,
       62(4):428-440, 2012.
  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
  with tf.compat.v2.name_scope(name or 'pivoted_cholesky'):
    dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                    preferred_dtype=tf.float32)
    matrix = tf.convert_to_tensor(value=matrix, name='matrix', dtype=dtype)
    if tensorshape_util.rank(matrix.shape) is None:
      raise NotImplementedError('Rank of `matrix` must be known statically')

    max_rank = tf.convert_to_tensor(
        value=max_rank, name='max_rank', dtype=tf.int64)
    max_rank = tf.minimum(max_rank,
                          prefer_static.shape(matrix, out_type=tf.int64)[-1])
    diag_rtol = tf.convert_to_tensor(
        value=diag_rtol, dtype=dtype, name='diag_rtol')
    matrix_diag = tf.linalg.diag_part(matrix)
    # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
    orig_error = tf.reduce_max(input_tensor=matrix_diag, axis=-1)

    def cond(m, pchol, perm, matrix_diag):
      """Condition for `tf.while_loop` continuation."""
      del pchol
      del perm
      error = tf.linalg.norm(tensor=matrix_diag, ord=1, axis=-1)
      max_err = tf.reduce_max(input_tensor=error / orig_error)
      return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

    batch_dims = tensorshape_util.rank(matrix.shape) - 2

    def batch_gather(params, indices, axis=-1):
      return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

    def body(m, pchol, perm, matrix_diag):
      """Body of a single `tf.while_loop` iteration."""
      # Here is roughly a numpy, non-batched version of what's going to happen.
      # (See also Algorithm 1 of Harbrecht et al.)
      # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
      # 2: maxval = matrix_diag[perm][maxi]
      # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
      # 4: row = matrix[perm[m]][perm[m + 1:]]
      # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
      # 6: pivot = np.sqrt(maxval); row /= pivot
      # 7: row = np.concatenate([[[pivot]], row], -1)
      # 8: matrix_diag[perm[m:]] -= row**2
      # 9: pchol[m, perm[m:]] = row

      # Find the maximal position of the (remaining) permuted diagonal.
      # Steps 1, 2 above.
      permuted_diag = batch_gather(matrix_diag, perm[..., m:])
      maxi = tf.argmax(
          input=permuted_diag, axis=-1, output_type=tf.int64)[..., tf.newaxis]
      maxval = batch_gather(permuted_diag, maxi)
      maxi = maxi + m
      maxval = maxval[..., 0]
      # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
      perm = _swap_m_with_i(perm, m, maxi)
      # Step 4.
      row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
      row = batch_gather(row, perm[..., m + 1:])
      # Step 5.
      prev_rows = pchol[..., :m, :]
      prev_rows_perm_m_onward = batch_gather(prev_rows, perm[..., m + 1:])
      prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
      row -= tf.reduce_sum(
          input_tensor=prev_rows_perm_m_onward * prev_rows_pivot_col,
          axis=-2)[..., tf.newaxis, :]
      # Step 6.
      pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
      # Step 7.
      row = tf.concat([pivot, row / pivot], axis=-1)
      # TODO(b/130899118): Pad grad fails with int64 paddings.
      # Step 8.
      paddings = tf.concat([
          tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
          [[tf.cast(m, tf.int32), 0]]], axis=0)
      diag_update = tf.pad(tensor=row**2, paddings=paddings)[..., 0, :]
      reverse_perm = _invert_permutation(perm)
      matrix_diag -= batch_gather(diag_update, reverse_perm)
      # Step 9.
      row = tf.pad(tensor=row, paddings=paddings)
      # TODO(bjp): Defer the reverse permutation all-at-once at the end?
      row = batch_gather(row, reverse_perm)
      pchol_shape = pchol.shape
      pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                        axis=-2)
      tensorshape_util.set_shape(pchol, pchol_shape)
      return m + 1, pchol, perm, matrix_diag

    m = np.int64(0)
    pchol = tf.zeros_like(matrix[..., :max_rank, :])
    matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
    perm = tf.broadcast_to(
        prefer_static.range(matrix_shape[-1]), matrix_shape[:-1])
    _, pchol, _, _ = tf.while_loop(
        cond=cond, body=body, loop_vars=(m, pchol, perm, matrix_diag))
    pchol = tf.linalg.matrix_transpose(pchol)
    tensorshape_util.set_shape(
        pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
    return pchol


def pinv(a, rcond=None, validate_args=False, name=None):
  """Compute the Moore-Penrose pseudo-inverse of a matrix.
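For reference, the nine commented steps in `body` above can be transcribed into a hypothetical, non-batched NumPy routine (illustration only; `pivoted_cholesky_np` is not part of the commit). NumPy's fancy indexing can scatter through `perm` directly, which is why this version needs no `_invert_permutation` (the TF version recovers the inverse permutation via `tf.math.top_k`; in NumPy, `np.argsort(perm)` would do the same).

import numpy as np

def pivoted_cholesky_np(matrix, max_rank, diag_rtol=1e-3):
  """Non-batched sketch of the `tf.while_loop` above; steps match `body`."""
  n = matrix.shape[-1]
  pchol = np.zeros([max_rank, n])
  perm = np.arange(n)
  matrix_diag = np.diag(matrix).copy()
  orig_error = matrix_diag.max()
  for m in range(max_rank):
    # Steps 1-3: pivot on the worst-approximated diagonal element.
    maxi = np.argmax(matrix_diag[perm[m:]]) + m
    maxval = matrix_diag[perm[maxi]]
    perm[m], perm[maxi] = perm[maxi], perm[m]
    # Steps 4-5: residual of row perm[m] against the rows computed so far.
    row = matrix[perm[m], perm[m + 1:]]
    row = row - np.sum(
        pchol[:m, perm[m + 1:]] * pchol[:m, perm[m]][:, np.newaxis], axis=-2)
    # Steps 6-7: prepend the pivot and normalize.
    pivot = np.sqrt(maxval)
    row = np.concatenate([[pivot], row / pivot], axis=-1)
    # Steps 8-9: update the diagonal error estimate; store the new row.
    matrix_diag[perm[m:]] -= row**2
    pchol[m, perm[m:]] = row
    # Early exit mirrors `cond`: relative L1 diagonal error vs. diag_rtol.
    if np.abs(matrix_diag).sum() / orig_error <= diag_rtol:
      break
  return pchol.T  # Shape [n, max_rank]; result @ result.T ~= matrix.

# At full rank the factorization is exact, e.g.:
a = np.random.RandomState(0).randn(6, 6)
spd = a @ a.T + 6 * np.eye(6)
lr = pivoted_cholesky_np(spd, max_rank=6, diag_rtol=-1.)
assert np.allclose(lr @ lr.T, spd)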
132 changes: 132 additions & 0 deletions tensorflow_probability/python/math/linalg_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function

# Dependency imports
from absl.testing import parameterized
import hypothesis as hp
from hypothesis import strategies as hps
from hypothesis.extra import numpy as hpnp
@@ -197,6 +198,137 @@ class CholeskyExtend64Dynamic(_CholeskyExtend):
del _CholeskyExtend


class _PivotedCholesky(tf.test.TestCase, parameterized.TestCase):

  def _random_batch_psd(self, dim):
    matrix = np.random.random([2, dim, dim])
    matrix = np.matmul(matrix, np.swapaxes(matrix, -2, -1))
    matrix = (matrix + np.diag(np.arange(dim) * .1)).astype(self.dtype)
    masked_shape = (
        matrix.shape if self.use_static_shape else [None] * len(matrix.shape))
    matrix = tf.compat.v1.placeholder_with_default(matrix, shape=masked_shape)
    return matrix

  def testPivotedCholesky(self):
    dim = 11
    matrix = self._random_batch_psd(dim)
    true_diag = tf.linalg.diag_part(matrix)

    pchol = tfp.math.pivoted_cholesky(matrix, max_rank=1)
    mat = tf.matmul(pchol, pchol, transpose_b=True)
    diag_diff_prev = self.evaluate(tf.abs(tf.linalg.diag_part(mat) - true_diag))
    diff_norm_prev = self.evaluate(
        tf.linalg.norm(tensor=mat - matrix, ord='fro', axis=[-1, -2]))
    for rank in range(2, dim + 1):
      # Specifying diag_rtol forces the full max_rank decomposition.
      pchol = tfp.math.pivoted_cholesky(matrix, max_rank=rank, diag_rtol=-1)
      zeros_per_col = dim - tf.math.count_nonzero(pchol, axis=-2)
      mat = tf.matmul(pchol, pchol, transpose_b=True)
      pchol_shp, diag_diff, diff_norm, zeros_per_col = self.evaluate([
          tf.shape(input=pchol),
          tf.abs(tf.linalg.diag_part(mat) - true_diag),
          tf.linalg.norm(tensor=mat - matrix, ord='fro', axis=[-1, -2]),
          zeros_per_col
      ])
      self.assertAllEqual([2, dim, rank], pchol_shp)
      self.assertAllEqual(
          np.ones([2, rank], dtype=np.bool), zeros_per_col >= np.arange(rank))
      self.assertAllLessEqual(diag_diff - diag_diff_prev,
                              np.finfo(self.dtype).resolution)
      self.assertAllLessEqual(diff_norm - diff_norm_prev,
                              np.finfo(self.dtype).resolution)
      diag_diff_prev, diff_norm_prev = diag_diff, diff_norm

  def testGradient(self):
    dim = 11
    matrix = self._random_batch_psd(dim)
    _, dmatrix = tfp.math.value_and_gradient(
        lambda matrix: tfp.math.pivoted_cholesky(matrix, max_rank=dim // 3),
        matrix)
    self.assertIsNotNone(dmatrix)
    self.assertAllGreater(
        tf.linalg.norm(tensor=dmatrix, ord='fro', axis=[-1, -2]), 0.)

  @test_util.enable_control_flow_v2
  def testGradientTapeCFv2(self):
    dim = 11
    matrix = self._random_batch_psd(dim)
    with tf.GradientTape() as tape:
      tape.watch(matrix)
      pchol = tfp.math.pivoted_cholesky(matrix, max_rank=dim // 3)
    dmatrix = tape.gradient(
        pchol, matrix, output_gradients=tf.ones_like(pchol) * .01)
    self.assertIsNotNone(dmatrix)
    self.assertAllGreater(
        tf.linalg.norm(tensor=dmatrix, ord='fro', axis=[-1, -2]), 0.)

  # pyformat: disable
  @parameterized.parameters(
      # Inputs are randomly shuffled arange->tril; outputs from gpytorch.
      (
          np.array([
              [7., 0, 0, 0, 0, 0],
              [9, 13, 0, 0, 0, 0],
              [4, 10, 6, 0, 0, 0],
              [18, 1, 2, 14, 0, 0],
              [5, 11, 20, 3, 17, 0],
              [19, 12, 16, 15, 8, 21]
          ]),
          np.array([
              [3.4444, -1.3545, 4.084, 1.7674, -1.1789, 3.7562],
              [8.4685, 1.2821, 3.1179, 12.9197, 0.0000, 0.0000],
              [7.5621, 4.8603, 0.0634, 7.3942, 4.0637, 0.0000],
              [15.435, -4.8864, 16.2137, 0.0000, 0.0000, 0.0000],
              [18.8535, 22.103, 0.0000, 0.0000, 0.0000, 0.0000],
              [38.6135, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]
          ])),
      (
          np.array([
              [1, 0, 0],
              [2, 3, 0],
              [4, 5, 6.]
          ]),
          np.array([
              [0.4558, 0.3252, 0.8285],
              [2.6211, 2.4759, 0.0000],
              [8.7750, 0.0000, 0.0000]
          ])),
      (
          np.array([
              [6, 0, 0],
              [3, 2, 0],
              [4, 1, 5.]
          ]),
          np.array([
              [3.7033, 4.7208, 0.0000],
              [2.1602, 2.1183, 1.9612],
              [6.4807, 0.0000, 0.0000]
          ])))
  # pyformat: enable
  def testOracleExamples(self, mat, oracle_pchol):
    mat = np.matmul(mat, mat.T)
    for rank in range(1, mat.shape[-1] + 1):
      self.assertAllClose(
          oracle_pchol[..., :rank],
          tfp.math.pivoted_cholesky(mat, max_rank=rank, diag_rtol=-1),
          atol=1e-4)


@test_util.run_all_in_graph_and_eager_modes
class PivotedCholesky32Static(_PivotedCholesky):
  dtype = np.float32
  use_static_shape = True


@test_util.run_all_in_graph_and_eager_modes
class PivotedCholesky64Dynamic(_PivotedCholesky):
  dtype = np.float64
  use_static_shape = False


del _PivotedCholesky


def make_tensor_hiding_attributes(value, hide_shape, hide_value=True):
  if not hide_value:
    return tf.convert_to_tensor(value=value)
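As a quick sanity check outside the test harness, the second oracle case above can be reproduced directly; at full rank the pivoted Cholesky factorization of an SPD matrix is exact. A hypothetical snippet (TF2 eager execution assumed):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tril = np.array([[1., 0, 0],
                 [2, 3, 0],
                 [4, 5, 6]])
mat = tril.dot(tril.T)  # Symmetric positive definite by construction.
pchol = tfp.math.pivoted_cholesky(mat, max_rank=3, diag_rtol=-1)
# Full-rank factorization is exact: pchol @ pchol.T recovers `mat`.
np.testing.assert_allclose(
    tf.matmul(pchol, pchol, transpose_b=True), mat, atol=1e-4)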
