Add an all_gather-based fallback for TF implementation of pmin/pmax.
This can be used when the non-named reduction axes reduce the size of the
tensor enough that memory and bandwidth inefficiencies of `all_gather` are
insignificant.

PiperOrigin-RevId: 410679597
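
The fallback pattern itself is small: each replica stacks every replica's value with `all_gather` and then reduces the stacked axis locally. A minimal, self-contained sketch of that pattern (not the library code; the `MirroredStrategy` and the example tensor are assumptions for illustration):

import tensorflow as tf

# Any multi-replica strategy would do; MirroredStrategy is only for illustration.
strategy = tf.distribute.MirroredStrategy()

def replica_max(x):
  ctx = tf.distribute.get_replica_context()
  # Gather a [num_replicas, ...] stack of the per-replica values, then reduce
  # the stacked axis locally; every replica ends up with the global maximum.
  gathered = ctx.all_gather(x[tf.newaxis], axis=0)
  return tf.reduce_max(gathered, axis=0)

per_replica_max = strategy.run(replica_max, args=(tf.constant([1., 5., 3.]),))
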
SiegeLordEx authored and tensorflower-gardener committed Nov 18, 2021
1 parent c73714b commit 059d789
Showing 4 changed files with 133 additions and 63 deletions.
58 changes: 45 additions & 13 deletions tensorflow_probability/python/internal/distribute_lib.py
@@ -62,8 +62,13 @@ def _make_reduce_op(tensor_reduce_fn, collective_reduce_fn):
 
   def reduce_fn(x, axis=None, named_axis=None, **kwargs):
     named_axis = canonicalize_named_axis(named_axis)
+    allow_all_gather = kwargs.pop('allow_all_gather', None)
     x = tensor_reduce_fn(x, axis=axis, **kwargs)
-    return collective_reduce_fn(x, named_axis=named_axis)
+    if allow_all_gather is None:
+      collection_kwargs = {}
+    else:
+      collection_kwargs = {'allow_all_gather': allow_all_gather}
+    return collective_reduce_fn(x, named_axis=named_axis, **collection_kwargs)
 
   return reduce_fn
 
@@ -95,35 +100,62 @@ def pmean(x, named_axis=None):
 reduce_mean = _make_reduce_op(tf.reduce_mean, pmean)
 
 
-def pmax(x, named_axis=None):
+def pmax(x, named_axis=None, allow_all_gather=False):
   """Generic `pmax` implementation."""
   # TODO(b/187173243): fix gradients for pmax
   axes = canonicalize_named_axis(named_axis)
   for axis in axes:
-    if not JAX_MODE:
-      raise NotImplementedError('`pmax` not supported in TF')
-    x = lax.pmax(x, axis)
+    if JAX_MODE:
+      x = lax.pmax(x, axis)
+    elif allow_all_gather:
+      ctx = tf.distribute.get_replica_context()
+      x = tf.reduce_max(ctx.all_gather(x[tf.newaxis], axis=0), axis=0)
+    else:
+      raise NotImplementedError(
+          '`pmax` has no native implementation in TF. Pass in '
+          '`allow_all_gather=True` to enable a potentially '
+          'inefficient `all_gather`-based fallback. Also see b/191501877.'
+      )
   return x
 
 
 reduce_max = _make_reduce_op(tf.reduce_max, pmax)
 
 
-def pmin(x, named_axis=None):
+def pmin(x, named_axis=None, allow_all_gather=False):
   """Generic `pmin` implementation."""
   # TODO(b/187173243): fix gradients for pmin
-  axis_name = canonicalize_named_axis(named_axis)
-  for name in axis_name:
-    if not JAX_MODE:
-      raise NotImplementedError('`pmax` not supported in TF')
-    x = lax.pmin(x, name)
+  axes = canonicalize_named_axis(named_axis)
+  for axis in axes:
+    if JAX_MODE:
+      x = lax.pmin(x, axis)
+    elif allow_all_gather:
+      ctx = tf.distribute.get_replica_context()
+      x = tf.reduce_min(ctx.all_gather(x[tf.newaxis], axis=0), axis=0)
+    else:
+      raise NotImplementedError(
+          '`pmin` has no native implementation in TF. Pass in '
+          '`allow_all_gather=True` to enable a potentially '
+          'inefficient `all_gather`-based fallback. Also see b/191501877.'
+      )
   return x
 
 
 reduce_min = _make_reduce_op(tf.reduce_min, pmin)
 
 
-def reduce_logsumexp(x, axis=None, named_axis=None, **kwargs):
+def reduce_logsumexp(x,
+                     axis=None,
+                     named_axis=None,
+                     allow_all_gather=False,
+                     **kwargs):
   """`logsumexp` wrapper."""
   xmax = reduce_max(
-      tf.stop_gradient(x), axis=axis, named_axis=named_axis, keepdims=True)
+      tf.stop_gradient(x),
+      axis=axis,
+      named_axis=named_axis,
+      keepdims=True,
+      allow_all_gather=allow_all_gather)
   xmax = tf.where(tf.math.is_finite(xmax), xmax, tf.zeros_like(xmax))
   result = tf.math.log(
       reduce_sum(tf.exp(x - xmax), axis=axis, named_axis=named_axis, **kwargs))
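
For callers of `distribute_lib`, the change is opt-in. A hypothetical usage sketch (the surrounding `tf.distribute` strategy setup and the axis name are assumptions, and the function must run inside a replica context):

from tensorflow_probability.python.internal import distribute_lib

def replica_fn(x):
  # Under the TF backend, reducing over a named (replica) axis used to raise
  # NotImplementedError for max; opting in routes the cross-replica max
  # through the all_gather fallback shown above.
  return distribute_lib.reduce_max(
      x, axis=[0, 1], named_axis='replica', allow_all_gather=True)
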
63 changes: 34 additions & 29 deletions tensorflow_probability/python/internal/distribute_lib_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Tests for tensorflow_probability.python.experimental.distribute.distribute_lib."""
+import functools
 import itertools
 
 from absl.testing import parameterized
@@ -31,6 +32,10 @@
 from jax import random  # pylint: disable=g-import-not-at-top
 
 
+def _allow_all_gather(fn):
+  return functools.partial(fn, allow_all_gather=True)
+
+
 @test_util.test_all_tf_execution_regimes
 class CollectiveTest(test_lib.DistributedTest):
 
@@ -44,13 +49,12 @@ def test_tf_should_error_with_more_than_one_named_axis(self):
   @parameterized.named_parameters(
       ('sum', tf.reduce_sum, distribute_lib.reduce_sum),
       ('mean', tf.reduce_mean, distribute_lib.reduce_mean),
-      ('max', tf.reduce_max, distribute_lib.reduce_max, True),
-      ('min', tf.reduce_min, distribute_lib.reduce_min, True),
-      ('logsumexp', tf.reduce_logsumexp, distribute_lib.reduce_logsumexp, True))
+      ('max', tf.reduce_max, _allow_all_gather(distribute_lib.reduce_max)),
+      ('min', tf.reduce_min, _allow_all_gather(distribute_lib.reduce_min)),
+      ('logsumexp', tf.reduce_logsumexp,
+       _allow_all_gather(distribute_lib.reduce_logsumexp)))
   def test_distributed_reduce_works_as_normal_with_int_axes(
-      self, reduce_op, distributed_op, jax_only=False):
-    if not JAX_MODE and jax_only:
-      self.skipTest('Only supported in JAX.')
+      self, reduce_op, distributed_op):
     x = tf.reshape(
         tf.range(test_lib.NUM_DEVICES * 6.) / 5., [test_lib.NUM_DEVICES, 3, 2])
 
@@ -63,24 +67,24 @@ def make_run(axis):
     self.assertAllEqual(reduce_out, dist_out)
 
   @parameterized.named_parameters(*(
-      (f'{name} {ax}', (op, d_op, jo), ax)  # pylint: disable=g-complex-comprehension
-      for (name, op, d_op, jo), ax in itertools.product((
-          ('sum', tf.reduce_sum, distribute_lib.reduce_sum, False),
-          ('mean', tf.reduce_mean, distribute_lib.reduce_mean, False),
-          ('max', tf.reduce_max, distribute_lib.reduce_max, True),
-          ('min', tf.reduce_min, distribute_lib.reduce_min, True),
-          ('logsumexp', tf.reduce_logsumexp, distribute_lib.reduce_logsumexp,
-           True)), (None, 0, 1, 2, [0, 1], [1, 2], [0, 2], [0, 1, 2]))))
+      (f'{name} {ax}', (op, d_op), ax)  # pylint: disable=g-complex-comprehension
+      for (name, op, d_op), ax in itertools.product((
+          ('sum', tf.reduce_sum, distribute_lib.reduce_sum),
+          ('mean', tf.reduce_mean, distribute_lib.reduce_mean),
+          ('max', tf.reduce_max, _allow_all_gather(distribute_lib.reduce_max)),
+          ('min', tf.reduce_min, _allow_all_gather(distribute_lib.reduce_min)),
+          ('logsumexp', tf.reduce_logsumexp,
+           _allow_all_gather(distribute_lib.reduce_logsumexp))), (
+               None, 0, 1, 2, [0, 1], [1, 2], [0, 2], [0, 1, 2]))))
   def test_reduce_with_collectives_matches_reduce_without_collectives(
       self, ops, axes):
-    reduce_op, distributed_op, jax_only = ops
-    if not JAX_MODE and jax_only:
-      self.skipTest('Only supported in JAX.')
+    reduce_op, distributed_op = ops
     x = tf.reshape(
         tf.range(test_lib.NUM_DEVICES * 6.) / 5., [test_lib.NUM_DEVICES, 3, 2])
 
     def run(x):
-      return distributed_op(x, axis=pos_axes, named_axis=named_axes)
+      return distributed_op(
+          x, axis=pos_axes, named_axis=named_axes)
 
     def distributed_run(x):
       return self.per_replica_to_tensor(
@@ -104,25 +108,26 @@ def distributed_run(x):
     self.assertAllClose(reduce_out, dist_out)
 
   @parameterized.named_parameters(
-      ('sum', tf.reduce_sum, distribute_lib.reduce_sum, False, True),
-      ('mean', tf.reduce_mean, distribute_lib.reduce_mean, False, True),
-      ('max', tf.reduce_max, distribute_lib.reduce_max, True, False),
-      ('min', tf.reduce_min, distribute_lib.reduce_min, True, False),
-      ('logsumexp', tf.reduce_logsumexp, distribute_lib.reduce_logsumexp, True,
-       True))
+      ('sum', tf.reduce_sum, distribute_lib.reduce_sum, True),
+      ('mean', tf.reduce_mean, distribute_lib.reduce_mean, True),
+      ('max', tf.reduce_max, _allow_all_gather(
+          distribute_lib.reduce_max), False),
+      ('min', tf.reduce_min, _allow_all_gather(distribute_lib.reduce_min),
+       False), ('logsumexp', tf.reduce_logsumexp,
+                _allow_all_gather(distribute_lib.reduce_logsumexp), True))
   def test_reduce_with_collective_grads_matches_without_collectives(
-      self, reduce_op, distributed_op, jax_only, is_supported):
-    if not JAX_MODE and jax_only:
-      self.skipTest('Only supported in JAX.')
+      self, reduce_op, distributed_op, is_supported):
     if not is_supported:
       self.skipTest('Gradient of operation not supported.')
     x = tf.reshape(
         tf.range(test_lib.NUM_DEVICES * 6.) / 5., [test_lib.NUM_DEVICES, 3, 2])
 
     def compute_dist_grads(x):
       return tfp.math.value_and_gradient(
-          lambda x: distributed_op(x, axis=[0, 1], named_axis=self.axis_name),
-          [x])[1][0]
+          lambda x: distributed_op(  # pylint: disable=g-long-lambda
+              x,
+              axis=[0, 1],
+              named_axis=self.axis_name), [x])[1][0]
 
     def distributed_run(x):
       return self.per_replica_to_tensor(
49 changes: 39 additions & 10 deletions tensorflow_probability/python/math/generic.py
@@ -212,8 +212,12 @@ def reduce_kahan_sum(input_tensor, axis=None, keepdims=False, name=None):
       _reduce_kahan_sum(operands, inits, axis=axis, keepdims=keepdims))
 
 
-def reduce_logmeanexp(input_tensor, axis=None, keepdims=False,
-                      experimental_named_axis=None, name=None):
+def reduce_logmeanexp(input_tensor,
+                      axis=None,
+                      keepdims=False,
+                      experimental_named_axis=None,
+                      experimental_allow_all_gather=False,
+                      name=None):
   """Computes `log(mean(exp(input_tensor)))`.
 
   Reduces `input_tensor` along the dimensions given in `axis`. Unless
@@ -237,6 +241,10 @@ def reduce_logmeanexp(input_tensor, axis=None, keepdims=False,
       Default value: `False` (i.e., squeeze the reduced dimensions).
     experimental_named_axis: A `str or list of `str` axis names to additionally
       reduce over. Providing `None` will not reduce over any axes.
+    experimental_allow_all_gather: Allow using an `all_gather`-based fallback
+      under TensorFlow when computing the distributed maximum. This fallback is
+      only efficient when `axis` reduces away most of the dimensions of
+      `input_tensor`.
     name: Python `str` name prefixed to Ops created by this function.
       Default value: `None` (i.e., `'reduce_logmeanexp'`).
@@ -245,9 +253,12 @@ def reduce_logmeanexp(input_tensor, axis=None, keepdims=False,
   """
   with tf.name_scope(name or 'reduce_logmeanexp'):
     named_axes = distribute_lib.canonicalize_named_axis(experimental_named_axis)
-    lse = distribute_lib.reduce_logsumexp(input_tensor, axis=axis,
-                                          keepdims=keepdims,
-                                          named_axis=named_axes)
+    lse = distribute_lib.reduce_logsumexp(
+        input_tensor,
+        axis=axis,
+        keepdims=keepdims,
+        named_axis=named_axes,
+        allow_all_gather=experimental_allow_all_gather)
     n = ps.size(input_tensor) // ps.size(lse)
     for named_axis in named_axes:
       n = n * distribute_lib.get_axis_size(named_axis)
@@ -261,6 +272,7 @@ def reduce_weighted_logsumexp(logx,
                               keep_dims=False,
                               return_sign=False,
                               experimental_named_axis=None,
+                              experimental_allow_all_gather=False,
                               name=None):
   """Computes `log(abs(sum(weight * exp(elements across tensor dimensions))))`.
@@ -315,6 +327,11 @@ def reduce_weighted_logsumexp(logx,
     return_sign: If `True`, returns the sign of the result.
     experimental_named_axis: A `str or list of `str` axis names to additionally
       reduce over. Providing `None` will not reduce over any axes.
+    experimental_allow_all_gather: Allow using an `all_gather`-based fallback
+      under TensorFlow when computing the distributed maximum. This fallback is
+      only efficient when `axis` reduces away most of the dimensions of
+      `input_tensor`.
+
     name: A name for the operation (optional).
 
   Returns:
@@ -324,9 +341,12 @@ def reduce_weighted_logsumexp(logx,
   with tf.name_scope(name or 'reduce_weighted_logsumexp'):
     logx = tf.convert_to_tensor(logx, name='logx')
     if w is None:
-      lswe = distribute_lib.reduce_logsumexp(logx, axis=axis,
-                                             keepdims=keep_dims,
-                                             named_axis=experimental_named_axis)
+      lswe = distribute_lib.reduce_logsumexp(
+          logx,
+          axis=axis,
+          keepdims=keep_dims,
+          named_axis=experimental_named_axis,
+          allow_all_gather=experimental_allow_all_gather)
       if return_sign:
         sgn = tf.ones_like(lswe)
         return lswe, sgn
@@ -361,6 +381,7 @@ def reduce_log_harmonic_mean_exp(input_tensor,
                                  axis=None,
                                  keepdims=False,
                                  experimental_named_axis=None,
+                                 experimental_allow_all_gather=False,
                                  name=None):
   """Computes `log(1 / mean(1 / exp(input_tensor)))`.
@@ -385,15 +406,23 @@ def reduce_log_harmonic_mean_exp(input_tensor,
       Default value: `False` (i.e., squeeze the reduced dimensions).
     experimental_named_axis: A `str or list of `str` axis names to additionally
       reduce over. Providing `None` will not reduce over any axes.
+    experimental_allow_all_gather: Allow using an `all_gather`-based fallback
+      under TensorFlow when computing the distributed maximum. This fallback is
+      only efficient when `axis` reduces away most of the dimensions of
+      `input_tensor`.
     name: Python `str` name prefixed to Ops created by this function.
       Default value: `None` (i.e., `'reduce_log_harmonic_mean_exp'`).
 
   Returns:
     log_mean_exp: The reduced tensor.
   """
   with tf.name_scope(name or 'reduce_log_harmonic_mean_exp'):
-    return -reduce_logmeanexp(-input_tensor, axis=axis, keepdims=keepdims,
-                              experimental_named_axis=experimental_named_axis)
+    return -reduce_logmeanexp(
+        -input_tensor,
+        axis=axis,
+        keepdims=keepdims,
+        experimental_named_axis=experimental_named_axis,
+        experimental_allow_all_gather=experimental_allow_all_gather)
 
 
 def soft_threshold(x, threshold, name=None):
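
At the public level, the new behavior surfaces as `experimental_allow_all_gather` on the `tfp.math` reductions above. A hedged usage sketch (tensor values are illustrative, and the keyword is only accepted once this change is in the installed TFP):

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.math.log(tf.constant([[1., 2.], [3., 4.]]))
# Single-device call: with no experimental_named_axis there is nothing to
# gather across replicas, so the flag is accepted but has no effect.
y = tfp.math.reduce_logmeanexp(x, axis=0, experimental_allow_all_gather=True)
# Inside a tf.distribute replica context with experimental_named_axis set, the
# same keyword lets the internal distributed max use the all_gather fallback
# instead of raising NotImplementedError.
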
26 changes: 15 additions & 11 deletions tensorflow_probability/python/math/generic_test.py
@@ -37,6 +37,10 @@
 JAX_MODE = False
 
 
+def _allow_all_gather(fn):
+  return functools.partial(fn, experimental_allow_all_gather=True)
+
+
 @test_util.test_all_tf_execution_regimes
 class LogHarmonicMeanExpTest(test_util.TestCase):
 
@@ -704,19 +708,19 @@ class KahanSumTest(_KahanSumTest):
 @test_util.test_all_tf_execution_regimes
 class CollectiveTest(distribute_test_lib.DistributedTest):
 
+  @test_util.numpy_disable_test_missing_functionality(
+      'NumPy backend does not support distributed computation.')
   @parameterized.named_parameters(*(
-      (f'{name} {ax}', (op, jo), ax)  # pylint: disable=g-complex-comprehension
-      for (name, op, jo), ax in itertools.product((
-          ('logmeanexp', tfp.math.reduce_logmeanexp, True),
-          ('log_harmonicmeanexp', tfp.math.reduce_log_harmonic_mean_exp, True),
-          ('reduce_weighted_logsumexp', tfp.math.reduce_weighted_logsumexp,
-           True),
-      ), (None, 0, 1, 2, [0, 1], [1, 2], [0, 2], [0, 1, 2]))))
+      (f'{name} {ax}', op, ax)  # pylint: disable=g-complex-comprehension
+      for (name, op), ax in itertools.product((
+          ('logmeanexp', _allow_all_gather(tfp.math.reduce_logmeanexp)),
+          ('log_harmonicmeanexp',
+           _allow_all_gather(tfp.math.reduce_log_harmonic_mean_exp)),
+          ('reduce_weighted_logsumexp',
+           _allow_all_gather(tfp.math.reduce_weighted_logsumexp)),
+      ), (None, 0, 1, 2, [0, 1], [1, 2], [0, 2], [0, 1, 2]))))
   def test_reduce_with_collectives_matches_reduce_without_collectives(
-      self, op_info, axes):
-    reduce_op, jax_only = op_info
-    if not JAX_MODE and jax_only:
-      self.skipTest('Only supported in JAX.')
+      self, reduce_op, axes):
 
     if axes is None:
       pos_axes = list(range(2))
