Update linear operators to use prefer_static for concat/shape and other various backend fixes

PiperOrigin-RevId: 414003888
sharadmv authored and tensorflower-gardener committed Dec 3, 2021
1 parent 22f2f06 commit 914251a
Showing 29 changed files with 199 additions and 94 deletions.
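The recurring change in the hunks below swaps `array_ops.shape`/`array_ops.concat` for `prefer_static.shape`/`prefer_static.concat`. TFP's `prefer_static` helpers return concrete NumPy values whenever a shape is statically known and fall back to dynamic ops otherwise, so shape arithmetic stays usable at trace time in the NumPy and JAX substrates. A minimal sketch of the distinction (illustrative values, not part of this commit):

    import tensorflow as tf
    from tensorflow_probability.python.internal import prefer_static as ps

    x = tf.zeros([3, 4])
    tf.shape(x)   # dynamic: a Tensor, value known only when the graph runs
    ps.shape(x)   # static: a NumPy array [3, 4], usable immediately
    # Static concat of shape pieces, e.g. batch dims plus a square tail:
    ps.concat([ps.shape(x)[:-1], [7, 7]], axis=0)  # -> [3, 7, 7]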
@@ -66,6 +66,9 @@
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")
"""

LINOP_UTIL_SUFFIX = """
@@ -174,6 +177,8 @@ def gen_module(module_name):
  code = re.sub(r'([_a-zA-Z0-9.\[\]]+).is_integer',
                'np.issubdtype(\\1, np.integer)', code)

+  code = code.replace('array_ops.shape', 'prefer_static.shape')
+  code = code.replace('array_ops.concat', 'prefer_static.concat')
  code = code.replace('array_ops.broadcast_static_shape',
                      '_ops.broadcast_static_shape')
  code = code.replace('array_ops.broadcast_to', '_ops.broadcast_to')
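These `code.replace` calls rewrite the vendored TF linear-operator source text before it is emitted as a NumPy-substrate module. A toy illustration of the effect (hypothetical input string):

    code = ("batch = array_ops.shape(x)[:-2]\n"
            "return array_ops.concat([batch, [n, n]], 0)")
    code = code.replace('array_ops.shape', 'prefer_static.shape')
    code = code.replace('array_ops.concat', 'prefer_static.concat')
    print(code)
    # batch = prefer_static.shape(x)[:-2]
    # return prefer_static.concat([batch, [n, n]], 0)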
1 change: 1 addition & 0 deletions tensorflow_probability/python/internal/backend/numpy/BUILD
@@ -133,6 +133,7 @@ py_library(
    srcs_version = "PY3",
    deps = [
        ":_utils",
+        ":dtype",
        ":v1",
        ":v2",
        # six dep,
21 changes: 21 additions & 0 deletions tensorflow_probability/python/internal/backend/numpy/debugging.py
@@ -17,6 +17,7 @@
import six

from tensorflow_probability.python.internal.backend.numpy import _utils as utils
+from tensorflow_probability.python.internal.backend.numpy import dtype as dt
from tensorflow_probability.python.internal.backend.numpy import ops


@@ -40,6 +41,7 @@
    'assert_rank_at_least',
    'assert_rank_in',
    'assert_scalar',
+    'assert_same_float_dtype',
    'check_numerics',
]

@@ -210,6 +212,21 @@ def _assert_rank_in(*_, **__):  # pylint: disable=unused-argument
  pass


+def _assert_same_float_dtype(tensors=None, dtype=None):  # pylint: disable=unused-argument,redefined-outer-name
+  """Checks that all tensors have the same dtype."""
+  expected_dtype = None
+  if tensors:
+    expected_dtype = dtype
+    for t in tensors:
+      if not expected_dtype:
+        expected_dtype = t.dtype
+      elif expected_dtype != t.dtype:
+        raise ValueError(f'Mismatched dtypes: {expected_dtype} vs. {t.dtype}')
+  if not expected_dtype:
+    expected_dtype = dt.float32
+  return expected_dtype
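Behavior of the shim above, sketched with hypothetical inputs (in the NumPy backend, tensors are `np.ndarray`s and `dt.float32` is NumPy's float32):

    import numpy as np

    a = np.zeros(3, dtype=np.float64)
    b = np.ones(3, dtype=np.float64)
    _assert_same_float_dtype([a, b])    # -> float64
    _assert_same_float_dtype()          # -> float32 (the default)
    _assert_same_float_dtype([a, b.astype(np.float32)])  # ValueError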


# --- Begin Public Functions --------------------------------------------------


@@ -290,6 +307,10 @@ def _assert_rank_in(*_, **__):  # pylint: disable=unused-argument
    'tf.debugging.assert_rank_in',
    _assert_rank_in)

+assert_same_float_dtype = utils.copy_docstring(
+    'tf.debugging.assert_same_float_dtype',
+    _assert_same_float_dtype)

check_numerics = utils.copy_docstring(
    'tf.debugging.check_numerics',
    lambda x, *_, **__: x)
@@ -159,4 +159,7 @@ def _adjoint_householder(householder_operator):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -126,4 +126,7 @@ def _cholesky_kronecker(kronecker_operator):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -250,4 +250,7 @@ def _inverse_householder(householder_operator):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -1479,4 +1479,7 @@ def _trace(x, name=None):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -474,4 +474,7 @@ def _type(operator):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -199,7 +199,7 @@ def _shape(self):
  def _shape_tensor(self):
    # Rotate last dimension
    shape = self.operator.shape_tensor()
-    return array_ops.concat([
+    return prefer_static.concat([
        shape[:-2], [shape[-1], shape[-2]]], axis=-1)
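Worked example of the concat above (hypothetical shape): an operator of shape [2, 3, 5, 7] (batch [2, 3], matrix 5 x 7) has an adjoint of shape [2, 3, 7, 5]:

    shape = [2, 3, 5, 7]
    adjoint_shape = shape[:-2] + [shape[-1], shape[-2]]  # -> [2, 3, 7, 5]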

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
@@ -262,4 +262,7 @@ def _composite_tensor_fields(self):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -435,4 +435,7 @@ def __call__(self, inverse_fn):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -312,9 +312,9 @@ def _shape_tensor(self):
    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
    for operator in self.operators[1:]:
      zeros = zeros + array_ops.zeros(shape=operator.batch_shape_tensor())
-    batch_shape = array_ops.shape(zeros)
+    batch_shape = prefer_static.shape(zeros)

-    return array_ops.concat((batch_shape, matrix_shape), 0)
+    return prefer_static.concat((batch_shape, matrix_shape), 0)
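The zeros-addition loop above is the usual trick for computing the broadcast of several batch shapes without building anything large: adding zero tensors broadcasts their shapes. A NumPy sketch with hypothetical shapes:

    import numpy as np

    batch_shapes = [(2, 1), (1, 3)]
    zeros = np.zeros(batch_shapes[0])
    for s in batch_shapes[1:]:
      zeros = zeros + np.zeros(s)   # addition broadcasts the shapes
    zeros.shape                     # -> (2, 3), the broadcast batch shape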

  # TODO(b/188080761): Add a more efficient implementation of `cond` that
  # constructs the condition number from the blockwise singular values.
@@ -431,7 +431,7 @@ def _matmul(self, x, adjoint=False, adjoint_arg=False):

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
-    return array_ops.concat(result_list, axis=-2)
+    return prefer_static.concat(result_list, axis=-2)

def matvec(self, x, adjoint=False, name="matvec"):
"""Transform [batch] vector `x` with left multiplication: `x --> Ax`.
@@ -619,7 +619,7 @@ def _check_operators_agree(r, l, message):

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
-    return array_ops.concat(solution_list, axis=-2)
+    return prefer_static.concat(solution_list, axis=-2)

def solvevec(self, rhs, adjoint=False, name="solve"):
"""Solve single equation with best effort: `A X = rhs`.
@@ -694,7 +694,7 @@ def _diag_part(self):
      # Extend the axis for broadcasting.
      diag_list = diag_list + [operator.diag_part()[..., _ops.newaxis]]
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
-    diagonal = array_ops.concat(diag_list, axis=-2)
+    diagonal = prefer_static.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)
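Sketch of the assembly above with concrete, hypothetical blocks (batch broadcasting omitted): each block's diagonal gets a trailing singleton axis, the pieces are concatenated along axis -2, and the squeeze drops the singleton:

    import numpy as np

    diag_parts = [np.array([1., 2.]), np.array([3., 4., 5.])]
    cols = [d[..., np.newaxis] for d in diag_parts]  # shapes (2, 1), (3, 1)
    diagonal = np.concatenate(cols, axis=-2)         # shape (5, 1)
    np.squeeze(diagonal, axis=-1)                    # -> [1. 2. 3. 4. 5.]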

def _trace(self):
@@ -714,23 +714,23 @@ def _to_dense(self):
    broadcasted_blocks = linear_operator_util.broadcast_matrix_batch_dims(
        broadcasted_blocks)
    for block in broadcasted_blocks:
-      batch_row_shape = array_ops.shape(block)[:-1]
+      batch_row_shape = prefer_static.shape(block)[:-1]

-      zeros_to_pad_before_shape = array_ops.concat(
+      zeros_to_pad_before_shape = prefer_static.concat(
          [batch_row_shape, [num_cols]], axis=-1)
      zeros_to_pad_before = array_ops.zeros(
          shape=zeros_to_pad_before_shape, dtype=block.dtype)
-      num_cols = num_cols + array_ops.shape(block)[-1]
-      zeros_to_pad_after_shape = array_ops.concat(
+      num_cols = num_cols + prefer_static.shape(block)[-1]
+      zeros_to_pad_after_shape = prefer_static.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=block.dtype)

-      rows.append(array_ops.concat(
+      rows.append(prefer_static.concat(
          [zeros_to_pad_before, block, zeros_to_pad_after], axis=-1))

-    mat = array_ops.concat(rows, axis=-2)
+    mat = prefer_static.concat(rows, axis=-2)
    tensorshape_util.set_shape(mat, tensor_shape.TensorShape(self.shape))
    return mat
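The padding scheme above, traced with two hypothetical blocks (batch dims omitted): each block is padded with zeros before and after so every row spans the full width, then the rows are stacked:

    import numpy as np

    blocks = [np.ones((1, 1)), 2. * np.ones((2, 2))]
    total_cols = sum(b.shape[-1] for b in blocks)
    rows, num_cols = [], 0
    for block in blocks:
      before = np.zeros((block.shape[0], num_cols))
      num_cols += block.shape[-1]
      after = np.zeros((block.shape[0], total_cols - num_cols))
      rows.append(np.concatenate([before, block, after], axis=-1))
    mat = np.concatenate(rows, axis=-2)
    # [[1. 0. 0.]
    #  [0. 2. 2.]
    #  [0. 2. 2.]]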

@@ -756,7 +756,7 @@ def _eigvals(self):
      # Extend the axis for broadcasting.
      eig_list = eig_list + [operator.eigvals()[..., _ops.newaxis]]
    eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
-    eigs = array_ops.concat(eig_list, axis=-2)
+    eigs = prefer_static.concat(eig_list, axis=-2)
    return array_ops.squeeze(eigs, axis=-1)

@property
@@ -775,4 +775,7 @@ def _composite_tensor_fields(self):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -411,7 +411,7 @@ def _shape_tensor(self):
      batch_shape = array_ops.broadcast_dynamic_shape(
          batch_shape, operator.batch_shape_tensor())

return array_ops.concat((batch_shape, matrix_shape), 0)
return prefer_static.concat((batch_shape, matrix_shape), 0)

def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
"""Transform [batch] matrix `x` with left multiplication: `x --> Ax`.
@@ -539,7 +539,7 @@ def _matmul(self, x, adjoint=False, adjoint_arg=False):

    result_list = linear_operator_util.broadcast_matrix_batch_dims(
        result_list)
-    return array_ops.concat(result_list, axis=-2)
+    return prefer_static.concat(result_list, axis=-2)
return prefer_static.concat(result_list, axis=-2)

def matvec(self, x, adjoint=False, name="matvec"):
"""Transform [batch] vector `x` with left multiplication: `x --> Ax`.
@@ -779,7 +779,7 @@ def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):

    solution_list = linear_operator_util.broadcast_matrix_batch_dims(
        solution_list)
-    return array_ops.concat(solution_list, axis=-2)
+    return prefer_static.concat(solution_list, axis=-2)
return prefer_static.concat(solution_list, axis=-2)

def solvevec(self, rhs, adjoint=False, name="solve"):
"""Solve single equation with best effort: `A X = rhs`.
@@ -850,7 +850,7 @@ def _diag_part(self):
      # final two dimensions as batch dimensions.
      diag_list.append(op.diag_part()[..., _ops.newaxis])
    diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
-    diagonal = array_ops.concat(diag_list, axis=-2)
+    diagonal = prefer_static.concat(diag_list, axis=-2)
    return array_ops.squeeze(diagonal, axis=-1)
return array_ops.squeeze(diagonal, axis=-1)

def _trace(self):
@@ -868,18 +868,18 @@ def _to_dense(self):
        flat_broadcast_operators[i * (i + 1) // 2:(i + 1) * (i + 2) // 2]
        for i in range(len(self.operators))]
    for row_blocks in broadcast_operators:
-      batch_row_shape = array_ops.shape(row_blocks[0])[:-1]
-      num_cols = num_cols + array_ops.shape(row_blocks[-1])[-1]
-      zeros_to_pad_after_shape = array_ops.concat(
+      batch_row_shape = prefer_static.shape(row_blocks[0])[:-1]
+      num_cols = num_cols + prefer_static.shape(row_blocks[-1])[-1]
+      zeros_to_pad_after_shape = prefer_static.concat(
          [batch_row_shape,
           [self.domain_dimension_tensor() - num_cols]], axis=-1)
      zeros_to_pad_after = array_ops.zeros(
          shape=zeros_to_pad_after_shape, dtype=self.dtype)

      row_blocks.append(zeros_to_pad_after)
-      dense_rows.append(array_ops.concat(row_blocks, axis=-1))
+      dense_rows.append(prefer_static.concat(row_blocks, axis=-1))
dense_rows.append(prefer_static.concat(row_blocks, axis=-1))

mat = array_ops.concat(dense_rows, axis=-2)
mat = prefer_static.concat(dense_rows, axis=-2)
tensorshape_util.set_shape(mat, tensor_shape.TensorShape(self.shape))
return mat

@@ -893,7 +893,7 @@ def _eigvals(self):
# Extend the axis for broadcasting.
eig_list.append(op.eigvals()[..., _ops.newaxis])
eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
eigs = array_ops.concat(eig_list, axis=-2)
eigs = prefer_static.concat(eig_list, axis=-2)
return array_ops.squeeze(eigs, axis=-1)

@property
@@ -912,4 +912,7 @@ def _composite_tensor_fields(self):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

@@ -48,6 +48,9 @@
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")

from tensorflow_probability.python.internal.backend.numpy import linalg_impl as linalg
from tensorflow_probability.python.internal.backend.numpy.gen import linear_operator
@@ -211,7 +214,7 @@ def _block_shape_tensor(self, spectrum_shape=None):
      return linear_operator_util.shape_tensor(
          self.block_shape.as_list(), name="block_shape")
    spectrum_shape = (
-        array_ops.shape(self.spectrum)
+        prefer_static.shape(self.spectrum)
        if spectrum_shape is None else spectrum_shape)
    return spectrum_shape[-self.block_depth:]

@@ -248,8 +251,8 @@ def _vectorize_then_blockify(self, matrix):
      vec_leading_shape = tensor_shape.TensorShape(vec.shape)[:-1]
      final_shape = vec_leading_shape.concatenate(self.block_shape)
    else:
-      vec_leading_shape = array_ops.shape(vec)[:-1]
-      final_shape = array_ops.concat(
+      vec_leading_shape = prefer_static.shape(vec)[:-1]
+      final_shape = prefer_static.concat(
          (vec_leading_shape, self.block_shape_tensor()), 0)
    return array_ops.reshape(vec, final_shape)
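Concretely (hypothetical sizes): with `block_depth = 2` and `block_shape = [2, 3]`, a vector whose trailing dimension is 6 gets its last axis reshaped into a 2 x 3 block; the `prefer_static.concat` branch only runs when the shape is not statically known:

    import numpy as np

    vec = np.arange(12.).reshape(2, 6)      # leading shape (2,), flat dim 6
    final_shape = list(vec.shape[:-1]) + [2, 3]
    vec.reshape(final_shape).shape          # -> (2, 2, 3)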

@@ -276,10 +279,10 @@ def _unblockify_then_matricize(self, vec):
      # flat_shape = [v0, v1, v2*v3]
      flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
    else:
-      vec_shape = array_ops.shape(vec)
+      vec_shape = prefer_static.shape(vec)
      vec_leading_shape = vec_shape[:-self.block_depth]
      vec_block_shape = vec_shape[-self.block_depth:]
-      flat_shape = array_ops.concat(
+      flat_shape = prefer_static.concat(
          (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
    vec_flat = array_ops.reshape(vec, flat_shape)

@@ -353,12 +356,12 @@ def _shape(self):
  def _shape_tensor(self, spectrum=None):
    spectrum = self.spectrum if spectrum is None else spectrum
    # See tensor_shape.TensorShape(self.shape) for explanation of steps
-    s_shape = array_ops.shape(spectrum)
+    s_shape = prefer_static.shape(spectrum)
    batch_shape = s_shape[:-self.block_depth]
    trailing_dims = s_shape[-self.block_depth:]
    n = math_ops.reduce_prod(trailing_dims)
    n_x_n = [n, n]
-    return array_ops.concat((batch_shape, n_x_n), 0)
+    return prefer_static.concat((batch_shape, n_x_n), 0)
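A worked example of the steps above (hypothetical spectrum): with `block_depth = 2` and a spectrum of shape [7, 4, 5], the batch shape is [7], n = 4 * 5 = 20, and the operator shape is [7, 20, 20]:

    import numpy as np

    s_shape = np.array([7, 4, 5])
    batch_shape = s_shape[:-2]              # [7]
    n = np.prod(s_shape[-2:])               # 20
    np.concatenate([batch_shape, [n, n]])   # -> [ 7 20 20]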

  def assert_hermitian_spectrum(self, name="assert_hermitian_spectrum"):
    """Returns an `Op` that asserts this operator has Hermitian spectrum.
@@ -415,16 +418,16 @@ def _broadcast_batch_dims(self, x, spectrum):
    batch_shape = self._batch_shape_tensor(
        shape=self._shape_tensor(spectrum=spectrum))
    spec_mat = array_ops.reshape(
-        spectrum, array_ops.concat((batch_shape, [-1, 1]), axis=0))
+        spectrum, prefer_static.concat((batch_shape, [-1, 1]), axis=0))
    # Second, broadcast, possibly requiring an addition of array of zeros.
    x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims((x,
                                                                    spec_mat))
    # Third, put the block shape back into spectrum.
-    x_batch_shape = array_ops.shape(x)[:-2]
-    spectrum_shape = array_ops.shape(spectrum)
+    x_batch_shape = prefer_static.shape(x)[:-2]
+    spectrum_shape = prefer_static.shape(spectrum)
    spectrum = array_ops.reshape(
        spec_mat,
-        array_ops.concat(
+        prefer_static.concat(
            (x_batch_shape,
             self._block_shape_tensor(spectrum_shape=spectrum_shape)),
            axis=0))
@@ -1173,4 +1176,7 @@ def _to_complex(x):
tensorshape_util = private.LazyLoader(
    "tensorshape_util", globals(),
    "tensorflow_probability.substrates.numpy.internal.tensorshape_util")
+prefer_static = private.LazyLoader(
+    "prefer_static", globals(),
+    "tensorflow_probability.substrates.numpy.internal.prefer_static")
