Merge pull request jax-ml#8606 from hawkinsp:laxsplit
PiperOrigin-RevId: 411065571
jax authors committed Nov 19, 2021
2 parents 602dd39 + 45d7ade commit 75e063c
Showing 7 changed files with 817 additions and 755 deletions.
9 changes: 5 additions & 4 deletions jax/_src/lax/control_flow.py
@@ -35,6 +35,7 @@
 from jax._src import source_info_util
 from jax._src import util
 from jax._src.lax import lax
+from jax._src.lax import windowed_reductions
 from jax import linear_util as lu
 from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
 from jax._src.api_util import flatten_fun_nokwargs
@@ -2747,11 +2748,11 @@ def _cumulative_reduction_primitive(name,
                                     reducer_p)
   return reducer_p
 
-cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, lax._reduce_window_sum)
+cumsum_p = _cumulative_reduction_primitive("cumsum", lax.add, windowed_reductions._reduce_window_sum)
 ad.deflinear2(cumsum_p, _cumsum_transpose_rule)
-cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, lax._reduce_window_prod)
-cummax_p = _cumulative_reduction_primitive("cummax", lax.max, lax._reduce_window_max)
-cummin_p = _cumulative_reduction_primitive("cummin", lax.min, lax._reduce_window_min)
+cumprod_p = _cumulative_reduction_primitive("cumprod", lax.mul, windowed_reductions._reduce_window_prod)
+cummax_p = _cumulative_reduction_primitive("cummax", lax.max, windowed_reductions._reduce_window_max)
+cummin_p = _cumulative_reduction_primitive("cummin", lax.min, windowed_reductions._reduce_window_min)
 
 
 def _cumulative_jvp_rule(primals, tangents, *, axis: int, reverse: bool,
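
This hunk only re-points the cumulative-reduction primitives at the implementations' new home in windowed_reductions.py; the public API is unchanged. A minimal sketch of the user-facing behavior these primitives back (the example values are illustrative, not part of the commit):

import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 4.])
# lax.cumsum / lax.cumprod are the public entry points for the
# cumsum_p / cumprod_p primitives bound above.
print(lax.cumsum(x, axis=0))   # [ 1.  3.  6. 10.]
print(lax.cumprod(x, axis=0))  # [ 1.  2.  6. 24.]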
708 changes: 0 additions & 708 deletions jax/_src/lax/lax.py

Large diffs are not rendered by default.

763 changes: 763 additions & 0 deletions jax/_src/lax/windowed_reductions.py

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions jax/experimental/jax2tf/impl_no_xla.py
@@ -22,6 +22,7 @@
 from jax._src import dtypes
 from jax._src import util
 from jax._src.lax import lax
+from jax._src.lax import windowed_reductions
 
 from . import jax2tf
 
@@ -479,18 +480,18 @@ def error(msg):
 
 
 # pylint: disable=protected-access
-tf_impl_no_xla[lax.reduce_window_sum_p] = (
+tf_impl_no_xla[windowed_reductions.reduce_window_sum_p] = (
     partial(_reduce_window, name="reduce_window_sum"))
-tf_impl_no_xla[lax.reduce_window_max_p] = (
+tf_impl_no_xla[windowed_reductions.reduce_window_max_p] = (
     partial(_reduce_window, name="reduce_window_max"))
 # pylint: enable=protected-access
 
-tf_impl_no_xla[lax.reduce_window_min_p] = _unimplemented("reduce_window_min")
-tf_impl_no_xla[lax.reduce_window_p] = _unimplemented("reduce_window")
+tf_impl_no_xla[windowed_reductions.reduce_window_min_p] = _unimplemented("reduce_window_min")
+tf_impl_no_xla[windowed_reductions.reduce_window_p] = _unimplemented("reduce_window")
 
 tf_impl_no_xla[lax.reduce_p] = _unimplemented("reduce")
 
-tf_impl_no_xla[lax.select_and_scatter_add_p] = _unimplemented(
+tf_impl_no_xla[windowed_reductions.select_and_scatter_add_p] = _unimplemented(
     "select_and_scatter_add")
 
 tf_impl_no_xla[lax.rng_bit_generator_p] = _unimplemented("rng_bit_generator")
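
The registrations above map the windowed-reduction primitives to their TF lowerings. For reference, reduce_window_sum_p and friends are what lax.reduce_window lowers to: it slides a window across the operand and reduces each window with the given monoid. A hedged sketch (our own example, not from the commit):

import jax.numpy as jnp
from jax import lax

x = jnp.array([1., 2., 3., 4.])
# Width-2 sliding sum; this lowers to the reduce_window_sum primitive.
y = lax.reduce_window(x, 0., lax.add,
                      window_dimensions=(2,),
                      window_strides=(1,),
                      padding="VALID")
print(y)  # [3. 5. 7.]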
23 changes: 12 additions & 11 deletions jax/experimental/jax2tf/jax2tf.py
@@ -35,6 +35,7 @@
 from jax._src.lax import fft as lax_fft
 from jax._src.lax import lax
 from jax._src.lax import linalg as lax_linalg
+from jax._src.lax import windowed_reductions
 import jax._src.prng
 import jax._src.random
 from jax.experimental import maps
@@ -1708,7 +1709,7 @@ def reducer(x, y):
   return snd(out)
 
 
-tf_impl_with_avals[lax.select_and_gather_add_p] = _select_and_gather_add
+tf_impl_with_avals[windowed_reductions.select_and_gather_add_p] = _select_and_gather_add
 
 
 def _get_shape_from_tensor_or_array(x):
@@ -1837,20 +1838,20 @@ def _get_min_identity(tf_dtype):
 
 
 # pylint: disable=protected-access
-tf_impl_with_avals[lax.reduce_window_sum_p] = (
+tf_impl_with_avals[windowed_reductions.reduce_window_sum_p] = (
    partial(_specialized_reduce_window, _add, lambda x: 0,
            name="reduce_window_sum"))
-tf_impl_with_avals[lax.reduce_window_min_p] = (
+tf_impl_with_avals[windowed_reductions.reduce_window_min_p] = (
    partial(_specialized_reduce_window,
            partial(_minmax_scalar, is_min=True),
            _get_min_identity,
            name="reduce_window_min"))
-tf_impl_with_avals[lax.reduce_window_max_p] = (
+tf_impl_with_avals[windowed_reductions.reduce_window_max_p] = (
    partial(_specialized_reduce_window,
            partial(_minmax_scalar, is_min=False),
            _get_max_identity,
            name="reduce_window_max"))
-tf_impl_with_avals[lax.reduce_window_p] = _reduce_window
+tf_impl_with_avals[windowed_reductions.reduce_window_p] = _reduce_window
 # pylint: enable=protected-access
 
 def _reduce(*operands: TfVal,
@@ -1894,12 +1895,12 @@ def reducer_computation(*args: TfVal) -> TfVal:
 # instead to favor different backends.
 tf_impl_with_avals[lax_control_flow.cummin_p] = _convert_jax_impl(
     partial(lax_control_flow._cumred_tpu_translation_rule,
-            lax._reduce_window_min),
+            windowed_reductions._reduce_window_min),
     multiple_results=False,
     extra_name_stack="cummin")
 tf_impl_with_avals[lax_control_flow.cummax_p] = _convert_jax_impl(
     partial(lax_control_flow._cumred_tpu_translation_rule,
-            lax._reduce_window_max),
+            windowed_reductions._reduce_window_max),
     multiple_results=False,
     extra_name_stack="cummin")
 # TODO(bchetioui): cumsum and cumprod can be converted using pure TF ops for
@@ -1909,12 +1910,12 @@ def reducer_computation(*args: TfVal) -> TfVal:
 # tests will crash.
 tf_impl_with_avals[lax_control_flow.cumsum_p] = _convert_jax_impl(
     partial(lax_control_flow._cumred_tpu_translation_rule,
-            lax._reduce_window_sum),
+            windowed_reductions._reduce_window_sum),
     multiple_results=False,
     extra_name_stack="cumsum")
 tf_impl_with_avals[lax_control_flow.cumprod_p] = _convert_jax_impl(
     partial(lax_control_flow._cumred_tpu_translation_rule,
-            lax._reduce_window_prod),
+            windowed_reductions._reduce_window_prod),
     multiple_results=False,
     extra_name_stack="cumprod")
 
@@ -1925,7 +1926,7 @@ def _select_and_scatter(operand, source, init_value, select_jaxpr,
   raise NotImplementedError("TODO: jax2tf can not convert _select_and_scatter")
 
 
-tf_impl[lax.select_and_scatter_p] = _select_and_scatter
+tf_impl[windowed_reductions.select_and_scatter_p] = _select_and_scatter
 
 
 @partial(bool_to_int8, argnums=(0, 1))
@@ -1943,7 +1944,7 @@ def _select_and_scatter_add(source, operand, *, select_prim, window_dimensions,
   return out
 
 
-tf_impl_with_avals[lax.select_and_scatter_add_p] = _select_and_scatter_add
+tf_impl_with_avals[windowed_reductions.select_and_scatter_add_p] = _select_and_scatter_add
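
The _cumred_tpu_translation_rule used above rewrites a cumulative reduction as one full-width reduce_window with asymmetric padding. A hedged sketch of that idea (cumsum_via_reduce_window is a hypothetical helper of ours, not code from this commit):

import jax.numpy as jnp
from jax import lax

def cumsum_via_reduce_window(x):
  # Pad n-1 identity elements on the left so window i covers exactly x[0..i].
  n = x.shape[0]
  return lax.reduce_window(x, 0., lax.add,
                           window_dimensions=(n,),
                           window_strides=(1,),
                           padding=((n - 1, 0),))

x = jnp.array([1., 2., 3.])
print(cumsum_via_reduce_window(x))  # [1. 3. 6.]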
16 changes: 7 additions & 9 deletions jax/experimental/jet.py
@@ -28,9 +28,7 @@
 import jax.linear_util as lu
 from jax.interpreters import xla
 from jax.custom_derivatives import custom_jvp_call_jaxpr_p
-from jax._src.lax import lax
-from jax._src.lax import control_flow as lax_control_flow
-from jax._src.lax import fft as lax_fft
+from jax import lax
 
 def jet(fun, primals, series):
   try:
@@ -241,24 +239,24 @@ def linear_prop(prim, primals_in, series_in, **params):
 deflinear(lax.slice_p)
 deflinear(lax.reduce_sum_p)
 deflinear(lax.reduce_window_sum_p)
-deflinear(lax_fft.fft_p)
+deflinear(lax.fft_p)
 deflinear(xla.device_put_p)
 
 def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                          combine_fn: Callable):
   # Irrespective of backend, we always use the parallel prefix scan
   # implementation when differentiating because reduce_window is not
   # arbitrarily differentiable.
-  return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis,
+  return jet(partial(lax.associative_scan, combine_fn, axis=axis,
                      reverse=reverse),
              primals_in, series_in)
 
-deflinear(lax_control_flow.cumsum_p)
-jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule,
+deflinear(lax.cumsum_p)
+jet_rules[lax.cumprod_p] = partial(_cumulative_jet_rule,
                                    combine_fn=lax.mul)
-jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule,
+jet_rules[lax.cummax_p] = partial(_cumulative_jet_rule,
                                    combine_fn=lax.max)
-jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
+jet_rules[lax.cummin_p] = partial(_cumulative_jet_rule,
                                    combine_fn=lax.min)
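
The comment inside _cumulative_jet_rule carries the key reasoning here: reduce_window is not arbitrarily differentiable, so jet always differentiates cumulative reductions through the parallel prefix scan. A hedged sketch showing the two formulations agree on values (example data is ours):

import jax.numpy as jnp
from jax import lax

x = jnp.array([3., 1., 2.])
# The prefix-scan path used by the jet rule...
print(lax.associative_scan(lax.max, x))  # [3. 3. 3.]
# ...matches the cumulative primitive it stands in for.
print(lax.cummax(x, axis=0))             # [3. 3. 3.]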
42 changes: 24 additions & 18 deletions jax/lax/__init__.py
@@ -215,12 +215,6 @@
     reduce_precision_p as reduce_precision_p,
     reduce_prod_p as reduce_prod_p,
     reduce_sum_p as reduce_sum_p,
-    reduce_window as reduce_window,
-    reduce_window_max_p as reduce_window_max_p,
-    reduce_window_min_p as reduce_window_min_p,
-    reduce_window_p as reduce_window_p,
-    reduce_window_shape_tuple as reduce_window_shape_tuple,
-    reduce_window_sum_p as reduce_window_sum_p,
     regularized_incomplete_beta_p as regularized_incomplete_beta_p,
     rem as rem,
     rem_p as rem_p,
@@ -247,9 +241,6 @@
     scatter_mul_p as scatter_mul_p,
     scatter_p as scatter_p,
     select as select,
-    select_and_gather_add_p as select_and_gather_add_p,
-    select_and_scatter_add_p as select_and_scatter_add_p,
-    select_and_scatter_p as select_and_scatter_p,
     select_p as select_p,
     shift_left as shift_left,
     shift_left_p as shift_left_p,
@@ -295,15 +286,30 @@
     xor_p as xor_p,
     zeros_like_array as zeros_like_array,
 )
-from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
-                              _reduce_and, _reduce_window_sum, _reduce_window_max,
-                              _reduce_window_min, _reduce_window_prod,
-                              _select_and_gather_add,
-                              _select_and_scatter_add, _float, _complex, _input_dtype,
-                              _const, _eq_meet, _broadcasting_select,
-                              _check_user_dtype_supported, _one, _zero,
-                              _upcast_fp16_for_computation, _broadcasting_shape_rule,
-                              _eye, _tri, _delta, _ones, _zeros, _dilate_shape)
+from jax._src.lax.lax import (
+    _reduce_sum, _reduce_max, _reduce_min, _reduce_or, _reduce_and,
+    _float, _complex, _input_dtype,
+    _const, _eq_meet, _broadcasting_select,
+    _check_user_dtype_supported, _one, _zero,
+    _upcast_fp16_for_computation, _broadcasting_shape_rule,
+    _eye, _tri, _delta, _ones, _zeros, _dilate_shape)
+from jax._src.lax.windowed_reductions import (
+    _reduce_window_sum,
+    _reduce_window_max,
+    _reduce_window_min,
+    _reduce_window_prod,
+    _select_and_gather_add,
+    _select_and_scatter_add,
+    reduce_window as reduce_window,
+    reduce_window_max_p as reduce_window_max_p,
+    reduce_window_min_p as reduce_window_min_p,
+    reduce_window_p as reduce_window_p,
+    reduce_window_shape_tuple as reduce_window_shape_tuple,
+    reduce_window_sum_p as reduce_window_sum_p,
+    select_and_gather_add_p as select_and_gather_add_p,
+    select_and_scatter_p as select_and_scatter_p,
+    select_and_scatter_add_p as select_and_scatter_add_p,
+)
 from jax._src.lax.control_flow import (
     associative_scan as associative_scan,
     cond as cond,
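
Because jax/lax/__init__.py re-exports the windowed-reduction names from their new module, code importing them through jax.lax should be unaffected by the split. A hedged sanity check, assuming the post-commit layout:

import jax.lax
from jax._src.lax import windowed_reductions

# Both spellings should name the same objects after the move.
assert jax.lax.reduce_window is windowed_reductions.reduce_window
assert jax.lax.reduce_window_sum_p is windowed_reductions.reduce_window_sum_p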
