Skip to content

Commit

Permalink
Use the new robust implementation of log determinant (linalg.slogdet)…
Browse files Browse the repository at this point in the history
… in LinearOperator.

Move creation of aliases into linalg_impl, such that these can be used inside TensorFlow without creating circular dependencies.

PiperOrigin-RevId: 173558892
  • Loading branch information
tensorflower-gardener committed Oct 26, 2017
1 parent 8269c66 commit f0aa811
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 42 deletions.
5 changes: 2 additions & 3 deletions tensorflow/python/ops/linalg/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ py_library(
srcs = glob(["*.py"]),
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
":linalg_impl",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:special_math_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//third_party/py/numpy",
Expand All @@ -33,6 +31,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:special_math_ops",
],
)

Expand Down
40 changes: 5 additions & 35 deletions tensorflow/python/ops/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops

# go/tf-wildcard-import
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops.linalg.linalg_impl import *
Expand All @@ -36,39 +30,15 @@
from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
# pylint: enable=wildcard-import

# Linear algebra ops.
band_part = array_ops.matrix_band_part
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
# pylint: disable=protected-access
slogdet = gen_linalg_ops._log_matrix_determinant
# pylint: disable=protected-access
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve

# Seal API.
# pylint: disable=undefined-variable
del absolute_import
del array_ops
del division
del print_function
del ops
del array_ops
del gen_linalg_ops
del linalg_ops
del math_ops
del ops
del print_function
del special_math_ops
# pylint: enable=undefined-variable
28 changes: 28 additions & 0 deletions tensorflow/python/ops/linalg/linalg_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,35 @@
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops

# Linear algebra ops.
band_part = array_ops.matrix_band_part
cholesky = linalg_ops.cholesky
cholesky_solve = linalg_ops.cholesky_solve
det = linalg_ops.matrix_determinant
# pylint: disable=protected-access
slogdet = gen_linalg_ops._log_matrix_determinant
# pylint: disable=protected-access
diag = array_ops.matrix_diag
diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
lstsq = linalg_ops.matrix_solve_ls
norm = linalg_ops.norm
qr = linalg_ops.qr
set_diag = array_ops.matrix_set_diag
solve = linalg_ops.matrix_solve
svd = linalg_ops.svd
tensordot = math_ops.tensordot
trace = math_ops.trace
transpose = array_ops.matrix_transpose
triangular_solve = linalg_ops.matrix_triangular_solve


def logdet(matrix, name=None):
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/python/ops/linalg/linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,8 @@ def _log_abs_determinant(self):
if self._can_use_cholesky():
diag = array_ops.matrix_diag_part(self._get_cached_chol())
return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
abs_det = math_ops.abs(self.determinant())
return math_ops.log(abs_det)
_, log_abs_det = linalg.slogdet(self._matrix)
return log_abs_det

def log_abs_determinant(self, name="log_abs_det"):
"""Log absolute value of determinant for every batch member.
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/python/ops/linalg/linear_operator_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,7 @@ def test_log_abs_det(self):
operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
shape, dtype, use_placeholder=use_placeholder)
op_log_abs_det = operator.log_abs_determinant()
mat_log_abs_det = math_ops.log(
math_ops.abs(linalg_ops.matrix_determinant(mat)))
_, mat_log_abs_det = linalg.slogdet(mat)
if not use_placeholder:
self.assertAllEqual(shape[:-2], op_log_abs_det.get_shape())
op_log_abs_det_v, mat_log_abs_det_v = sess.run(
Expand Down

0 comments on commit f0aa811

Please sign in to comment.