Skip to content

Commit

Permalink
Retain shape information through a reduce_sum.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 440949677
  • Loading branch information
ColCarroll authored and tensorflower-gardener committed Apr 11, 2022
1 parent bd74bdf commit f81ed49
Showing 1 changed file with 1 addition and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ def _symmetric_update_chol(chol, idx, value):
"""Sets the value of a row and column in a Cholesky-factorized matrix."""
# TODO(davmre): is a more efficient direct implementation possible?
old_value = tf.reduce_sum(chol * chol[..., idx : idx + 1, :], axis=-1)
old_value = tf.ensure_shape(old_value, value.shape)
return _symmetric_increment_chol(chol, idx, increment=value - old_value)


Expand Down

0 comments on commit f81ed49

Please sign in to comment.