Increase minimum jaxlib version to 0.1.74.
hawkinsp committed Nov 18, 2021
1 parent 52421da commit 3fd3c46
Showing 15 changed files with 24 additions and 155 deletions.
6 changes: 2 additions & 4 deletions jax/_src/api.py
@@ -109,8 +109,7 @@
"Set this to `False` only if it crashes otherwise and report "
"the error to the jax-team.")
flags.DEFINE_bool(
"experimental_cpp_pmap",
bool_env("JAX_CPP_PMAP", jax._src.lib._xla_extension_version >= 39),
"experimental_cpp_pmap", bool_env("JAX_CPP_PMAP", True),
"A flag enabling the C++ jax.pmap fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")
@@ -2014,8 +2013,7 @@ def cache_miss(*args, **kwargs):

return out, fastpath_data

# TODO(slebedev): Remove the ignore once jaxlib>=0.1.71.
cpp_mapped_f = pmap_lib.pmap(fun, cache_miss, # type: ignore[call-arg]
cpp_mapped_f = pmap_lib.pmap(fun, cache_miss,
static_broadcasted_tuple, pxla._shard_arg)

f_pmapped = wraps(fun)(cpp_mapped_f)
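Note: the api.py change above makes the C++ `jax.pmap` fast path opt-out rather than version-gated; with jaxlib 0.1.74 as the floor, the flag simply defaults to `bool_env("JAX_CPP_PMAP", True)`. As a rough sketch of how an environment-variable-backed boolean default like this is typically resolved (hypothetical helper, not necessarily JAX's exact `bool_env`):

```python
import os

def bool_env(varname: str, default: bool) -> bool:
    """Hypothetical stand-in for a bool_env-style helper: parse an environment
    variable as a boolean, falling back to `default` when it is unset."""
    val = os.getenv(varname)
    if val is None:
        return default
    return val.strip().lower() in ("1", "true", "t", "yes", "on")

# With the new default, an unset JAX_CPP_PMAP enables the C++ pmap fast path.
print(bool_env("JAX_CPP_PMAP", True))
```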
84 changes: 5 additions & 79 deletions jax/_src/lax/lax.py
@@ -55,7 +55,6 @@
from jax._src.lib import pytree
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import version as jaxlib_version

xb = xla_bridge
xc = xla_client
@@ -2801,21 +2800,7 @@ def _abs_jvp_rule(g, ans, x):
lambda g, ans, x:
mul(g, mul(_const(x, -0.5), div(ans, x))))

# TODO(phawkins): remove the fallback translation rule after the minimum jaxlib
# is 0.1.70 or newer.
if jax._src.lib._xla_extension_version >= 28:
_cbrt_translation_rule = None
else:
def _cbrt_translation_rule(ctx, avals_in, avals_out, x):
x_aval, = avals_in
return [xops.Mul(
xops.Sign(x),
xops.Pow(xops.Abs(x),
xla.pyval_to_ir_constant(ctx.builder,
np.array(1/3, dtype=x_aval.dtype))))]

cbrt_p = standard_unop(_float, 'cbrt',
translation_rule=_cbrt_translation_rule)
cbrt_p = standard_unop(_float, 'cbrt')
ad.defjvp2(cbrt_p,
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
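Note: the deleted fallback rule lowered `cbrt` as sign(x) · |x|^(1/3) for extension versions below 28, which the new floor rules out, so the primitive's default lowering is always used. A quick NumPy check of the identity the fallback relied on (illustration only, not the XLA lowering itself):

```python
import numpy as np

x = np.array([-8.0, -1.0, 0.0, 1.0, 27.0])

# Real-valued cube root via the identity used by the removed fallback rule:
# sign(x) * |x| ** (1/3) stays real for negative inputs.
fallback_cbrt = np.sign(x) * np.abs(x) ** (1.0 / 3.0)

np.testing.assert_allclose(fallback_cbrt, np.cbrt(x), rtol=1e-6)
print(fallback_cbrt)  # approximately [-2., -1., 0., 1., 3.]
```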

@@ -5900,15 +5885,6 @@ def reducer_fn(op_val_index, acc_val_index):
axes)
return res[1]

def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
axis, = axes
idxs = tie_in(a, broadcasted_iota(index_dtype, a.shape, axis))
maxval = np.array(dtypes.iinfo(index_dtype).max, dtype=index_dtype)
maxval = broadcast(tie_in(a, maxval), a.shape)
maxvals = expand_dims(op(a, (axis,)), (axis,))
mask_idxs = select(eq(a, maxvals) | ne(a, a), idxs, maxval)
return _reduce_min(mask_idxs, (axis,))

_argmin_translation_rule = xla.lower_fun(
partial(_compute_argminmax, lt, _get_min_identity),
multiple_results=False, new_style=True)
@@ -5922,28 +5898,12 @@ def _argminmax_gpu_translation_rule(op, a, *, axes, index_dtype):
weak_type_rule=_strip_weak_type)
batching.defreducer(argmin_p)
ad.defjvp_zero(argmin_p)
if jax._src.lib._xla_extension_version < 41:
xla.register_translation(
argmin_p,
xla.lower_fun(
partial(_argminmax_gpu_translation_rule, _reduce_min),
multiple_results=False,
new_style=True),
platform='gpu')

argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
'argmax', _argmax_translation_rule,
weak_type_rule=_strip_weak_type)
batching.defreducer(argmax_p)
ad.defjvp_zero(argmax_p)
if jax._src.lib._xla_extension_version < 41:
xla.register_translation(
argmax_p,
xla.lower_fun(
partial(_argminmax_gpu_translation_rule, _reduce_max),
multiple_results=False,
new_style=True),
platform='gpu')


def _reduce_logical_shape_rule(operand, *, axes):
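Note: the removed `_argminmax_gpu_translation_rule` was a GPU-only workaround for extension versions below 41: broadcast an index iota, replace every position that does not hold the extreme value with the largest representable index, and min-reduce. A hedged jax.numpy sketch of that masking strategy (illustrative names, argmax case only, not the actual lowering):

```python
import numpy as np
import jax.numpy as jnp

def masked_argmax(a, axis=0, index_dtype=np.int32):
    """Sketch of the deleted fallback: argmax via index masking + min-reduction."""
    shape = [1] * a.ndim
    shape[axis] = a.shape[axis]
    idxs = jnp.arange(a.shape[axis], dtype=index_dtype).reshape(shape)
    maxvals = jnp.max(a, axis=axis, keepdims=True)
    sentinel = jnp.iinfo(index_dtype).max
    # Keep the index where the maximum (or a NaN) sits; push everything else to
    # the sentinel so a min-reduction yields the first matching position.
    masked = jnp.where((a == maxvals) | jnp.isnan(a), idxs, sentinel)
    return jnp.min(masked, axis=axis)

x = jnp.array([3.0, 7.0, 7.0, 1.0])
print(masked_argmax(x), jnp.argmax(x))  # both 1
```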
@@ -6927,7 +6887,6 @@ def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm):
def _rng_bit_generator_translation_rule(
ctx, avals_in, avals_out, key, *, shape, dtype, algorithm):
c = ctx.builder
backend_is_gpu = ctx.platform == "gpu"
key_shape, key_dtype = c.get_shape(key).dimensions(), c.get_shape(key).numpy_dtype()
# While the RngBitGenerator HLO accepts a u64[2] key on all backends, we
# typically represent the key argument to this primitive as a u32[4] so as to
@@ -6938,48 +6897,15 @@ def _rng_bit_generator_translation_rule(
(key_shape == (2,) and key_dtype == np.dtype('uint64'))), (key_shape, key_dtype)
xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
if key_dtype == np.dtype('uint32'):
# TODO(mattjj): the BitcastConvertType segfaults on GPU
# TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
else:
key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
out_key, out_vals = xla.xla_destructure(
c, xops.RngBitGenerator(algorithm, key, xla_shape))
if key_dtype == np.dtype('uint32'):
if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
else:
out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
return [out_key, out_vals]

def _convert_4xU32_to_2xU64_without_bitcast(c, key):
u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
new_key = xops.Constant(c, np.zeros(2, dtype=np.dtype('uint64')))
_32 = xops.Constant(c, np.array(32, np.uint64))
for i in [0, 2]:
hi = xops.ConvertElementType(xops.Slice(key, [i] , [i+1], [1]), u64_etype)
lo = xops.ConvertElementType(xops.Slice(key, [i+1], [i+2], [1]), u64_etype)
elt = xops.Xor(xops.ShiftLeft(hi, _32), lo)
new_key = xops.DynamicUpdateSlice(new_key, elt,
[xla.pyval_to_ir_constant(c, i // 2)])
return new_key

def _convert_2xU64_to_4xU32_without_bitcast(c, key):
u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
new_key = xops.Constant(c, np.zeros(4, dtype=np.dtype('uint32')))
_32 = xops.Constant(c, np.array(32, np.uint64))
for i in [0, 1]:
elt = xops.Slice(key, [i], [i+1], [1])
hi = xops.ConvertElementType(xops.ShiftRightLogical(elt, _32), u32_etype)
lo = xops.ConvertElementType(elt, u32_etype)
new_key = xops.DynamicUpdateSlice(new_key, hi,
[xla.pyval_to_ir_constant(c, 2 * i)])
new_key = xops.DynamicUpdateSlice(new_key, lo,
[xla.pyval_to_ir_constant(c, 2 * i + 1)])
return new_key

def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
return [key.named_shape, key.named_shape]
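Note: both deleted helpers emulated a bitcast between the u32[4] key JAX passes around and the u64[2] key RngBitGenerator consumes, using shifts and slices because (per the removed TODO) BitcastConvertType used to segfault on GPU for this pattern; the reshape-plus-bitcast path is now used unconditionally, presumably because newer jaxlibs fixed that. A NumPy sketch of the packing the helpers computed (high word in the even slot, matching the shift-left-by-32 above; illustration only, not the XLA lowering):

```python
import numpy as np

key_u32 = np.array([1, 2, 3, 4], dtype=np.uint32)

# Packing, as in the removed _convert_4xU32_to_2xU64_without_bitcast.
hi = key_u32[0::2].astype(np.uint64)
lo = key_u32[1::2].astype(np.uint64)
packed = (hi << np.uint64(32)) | lo                       # u64[2]

# Unpacking, as in the removed _convert_2xU64_to_4xU32_without_bitcast.
unpacked = np.empty(4, dtype=np.uint32)
unpacked[0::2] = (packed >> np.uint64(32)).astype(np.uint32)
unpacked[1::2] = packed.astype(np.uint32)                 # low 32 bits survive

assert np.array_equal(unpacked, key_u32)
print(packed)  # [ 4294967298 12884901892]
```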
6 changes: 0 additions & 6 deletions jax/_src/lax/linalg.py
@@ -41,7 +41,6 @@
from jax._src.lib import rocsolver

from jax._src.lib import xla_client
from jax._src.lib import version as jaxlib_version

xops = xla_client.ops

@@ -1538,11 +1537,6 @@ def _schur_cpu_translation_rule(ctx, avals_in, avals_out, operand, *,
batch_dims = operand_aval.shape[:-2]
c = ctx.builder

if jaxlib_version < (0, 1, 72):
raise NotImplementedError(
"The Schur primitive is only implemented for jaxlib versions >= 0.1.72"
)

_cpu_gees = lapack.gees

if sort_eig_vals:
8 changes: 2 additions & 6 deletions jax/experimental/compilation_cache/compilation_cache.py
@@ -120,10 +120,7 @@ def _hash_compile_options(hash_obj, compile_options_obj):
hash_obj.update(compile_options_obj.device_assignment.serialize())

def _hash_executable_build_options(hash_obj, executable_obj):
if jax._src.lib.version >= (0, 1, 72):
expected_options = 31
else:
expected_options = 30
expected_options = 31
assert len(dir(executable_obj)) == expected_options, (
f"Unexpected number of executable_build_options fields: "
f"{len(dir(executable_obj))}. This likely means that an extra "
@@ -136,8 +133,7 @@ def _hash_executable_build_options(hash_obj, executable_obj):
if executable_obj.device_assignment is not None:
hash_obj.update(executable_obj.device_assignment.serialize())
_hash_bool(hash_obj, executable_obj.use_spmd_partitioning)
if jax._src.lib.version >= (0, 1, 72):
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)
_hash_bool(hash_obj, executable_obj.allow_spmd_sharding_propagation_to_output)

def _hash_debug_options(hash_obj, debug_obj):
_hash_bool(hash_obj, debug_obj.xla_cpu_enable_fast_math)
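Note: with `allow_spmd_sharding_propagation_to_output` guaranteed to exist, the expected attribute count is fixed at 31 and the flag is always folded into the cache key. A minimal sketch of the boolean-hashing pattern this relies on (hypothetical `_hash_bool`; the real helper in this module may encode values differently):

```python
import hashlib

def _hash_bool(hash_obj, b: bool) -> None:
    # Fold a boolean build option into the running cache-key hash.
    hash_obj.update(bytes([int(b)]))

h = hashlib.sha256()
_hash_bool(h, True)    # e.g. use_spmd_partitioning
_hash_bool(h, False)   # e.g. allow_spmd_sharding_propagation_to_output
print(h.hexdigest()[:16])
```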
15 changes: 4 additions & 11 deletions jax/interpreters/mlir.py
@@ -1674,20 +1674,13 @@ def _rng_uniform_lowering(ctx, avals_in, avals_out, a, b, *, shape):
# xla_shape = xc.Shape.array_shape(np.dtype(dtype), shape)
# if key_dtype == np.dtype('uint32'):
# # TODO(mattjj): the BitcastConvertType segfaults on GPU
# # TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
# if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
# u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
# key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
# else:
# key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
# u64_etype = xla.dtype_to_primitive_type(np.dtype('uint64'))
# key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
# out_key, out_vals = xla.xla_destructure(
# c, xops.RngBitGenerator(algorithm, key, xla_shape))
# if key_dtype == np.dtype('uint32'):
# if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
# u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
# out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
# else:
# out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
# u32_etype = xla.dtype_to_primitive_type(np.dtype('uint32'))
# out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
# return [out_key, out_vals]


23 changes: 7 additions & 16 deletions jax/interpreters/pxla.py
@@ -34,14 +34,13 @@
import itertools as it
import operator as op
import threading
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional,
from typing import (Any, Callable, Dict, List, Optional,
Sequence, Set, Tuple, Type, Union, Iterable)
import sys

from absl import logging
import numpy as np

import jax
from .._src.config import config
from .. import core
from .. import linear_util as lu
@@ -55,7 +54,6 @@
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib import _xla_extension_version
from ..tree_util import tree_flatten, tree_map
from . import batching
from . import partial_eval as pe
@@ -97,9 +95,7 @@ class WeakRefList(list):
# mypy will consider this constant to be True at type check time.
MYPY = False

# TODO(jblespiau): Remove the version check when jaxlib 0.1.70 is the minimal
# version.
if MYPY or (not TYPE_CHECKING and _xla_extension_version < 30):
if MYPY:
class ShardingSpec:
"""Describes the sharding of an ndarray.
@@ -503,8 +499,8 @@ def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):

### lazy device-memory persistence and result handling

# TODO(jblespiau): Remove when jaxlib 0.1.72 is the minimal version.
_USE_CPP_SDA = _xla_extension_version >= 38
# TODO(jblespiau): Consider removing this option.
_USE_CPP_SDA = True


def make_sharded_device_array(
@@ -539,9 +535,8 @@ def make_sharded_device_array(
if (_USE_CPP_SDA and
(not device_buffers or
isinstance(device_buffers[0], xb.xla_client.Buffer))):
# TODO(slebedev): Remove the ignore once jaxlib>=0.1.71.
return pmap_lib.ShardedDeviceArray.make(
aval, sharding_spec, device_buffers, # type: ignore[arg-type, call-arg]
aval, sharding_spec, device_buffers,
indices, aval.weak_type)

return _ShardedDeviceArray(aval, sharding_spec, device_buffers, indices)
@@ -1783,12 +1778,8 @@ def __init__(self,
use_spmd_partitioning=spmd_lowering,
)
compile_options.parameter_is_tupled_arguments = tuple_args
if jax._src.lib.version >= (0, 1, 72):
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
_allow_propagation_to_outputs
elif _allow_propagation_to_outputs:
raise RuntimeError("Propagation of SPMD sharding specs to outputs is only supported "
"in jaxlib 0.1.72+. Please update your JAX version.")
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
_allow_propagation_to_outputs

local_sharding_spec = mesh_sharding_specs(local_axis_sizes, mesh.axis_names)
local_input_specs = [local_sharding_spec(aval, aval_in_axes)
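Note: the pxla.py hunks drop the last `_xla_extension_version` branches: the C++ ShardedDeviceArray path is always taken (`_USE_CPP_SDA = True`) and the Python `ShardingSpec` definition survives only as a mypy stub, since (per the comment in the hunk above) mypy treats a constant named `MYPY` as true at type-check time. A generic sketch of that stub-at-type-check-time pattern (hypothetical import path, kept runnable without jaxlib):

```python
MYPY = False  # mypy treats this constant as True when type checking.

if MYPY:
    class ShardingSpec:
        """Pure-Python stand-in so the type checker has a class body to analyze."""
        def __init__(self, sharding, mesh_mapping):
            self.sharding = sharding
            self.mesh_mapping = mesh_mapping
else:
    try:
        # At runtime, prefer the C++ implementation (hypothetical import path).
        from jaxlib.xla_extension.pmap_lib import ShardingSpec  # type: ignore
    except ImportError:
        class ShardingSpec:  # type: ignore[no-redef]
            def __init__(self, sharding, mesh_mapping):
                self.sharding = sharding
                self.mesh_mapping = mesh_mapping

print(ShardingSpec(sharding=(), mesh_mapping=()).mesh_mapping)
```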
2 changes: 1 addition & 1 deletion jax/version.py
@@ -13,4 +13,4 @@
# limitations under the License.

__version__ = "0.2.26"
_minimum_jaxlib_version = "0.1.69"
_minimum_jaxlib_version = "0.1.74"
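
Note: bumping `_minimum_jaxlib_version` to 0.1.74 is what makes every `jax._src.lib.version >= (0, 1, 7x)` and `_xla_extension_version` guard deleted above unconditionally true. A small sketch of the tuple comparison such guards rely on (hypothetical helper, not the actual jax._src.lib code):

```python
def parse_version(v: str) -> tuple:
    """Hypothetical helper: '0.1.74' -> (0, 1, 74), compared lexicographically."""
    return tuple(int(part) for part in v.split("."))

# With 0.1.74 as the floor, checks like `version >= (0, 1, 72)` are always true.
assert parse_version("0.1.74") >= (0, 1, 72)
assert not parse_version("0.1.69") >= (0, 1, 72)  # the old floor would have failed
```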
6 changes: 0 additions & 6 deletions tests/api_test.py
@@ -82,16 +82,12 @@ def jit(self):
# Tensorflow.
return api._cpp_jit

@unittest.skipIf(jax._src.lib._xla_extension_version < 40,
"Test requires jaxlib 0.1.73")
def test_jit_repr(self):
def my_function():
return
jitted = jit(my_function)
self.assertEqual(repr(jitted), f"<CompiledFunction of {repr(my_function)}>")

@unittest.skipIf(jax._src.lib._xla_extension_version < 40,
"Test requires jaxlib 0.1.73")
def test_jit_repr_errors(self):
class Callable:
def __call__(self): pass
@@ -692,8 +688,6 @@ def f(*args):
np.testing.assert_allclose(f_pruned(*args), 3)
self.assertEqual(count[0], 1)

@unittest.skipIf(jax._src.lib._xla_extension_version <= 36,
"Test requires jaxlib 0.1.71")
def testBuffersAreFreedPromptly(self):
# Regression test for a bug where garbage collection was delayed too long
# for NumPy buffers that are aliased zero-copy by the runtime.
4 changes: 1 addition & 3 deletions tests/debug_nans_test.py
@@ -97,9 +97,7 @@ def f(x):
f(1)

def testPmap(self):
pmap_funcs = [api._python_pmap]
if jax._src.lib._xla_extension_version >= 36:
pmap_funcs.append(api._cpp_pmap)
pmap_funcs = [api._python_pmap, api._cpp_pmap]

for pmap in pmap_funcs:
f = pmap(lambda x: 0. / x)
3 changes: 0 additions & 3 deletions tests/lax_numpy_indexing_test.py
@@ -19,7 +19,6 @@
import itertools
import typing
from typing import Any, Optional, Tuple
import unittest
import warnings

from absl.testing import absltest
@@ -1241,8 +1240,6 @@ def testIndexSequenceDeprecation(self, idx, idx_type):
with self.assertNoWarnings():
x.at[normalize(idx)].set(0)

@unittest.skipIf(jax._src.lib.version < (0, 1, 72),
"Bug fixed in jaxlib 0.1.72")
def testIndexedUpdateAliasingBug(self):
# https://github.com/google/jax/issues/7461
fn = lambda x: x.at[1:].set(1 + x[:-1])
4 changes: 0 additions & 4 deletions tests/linalg_test.py
@@ -1453,8 +1453,6 @@ def test_tridiagonal_solve(self, dtype):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSchur(self, shape, dtype):
if jax._src.lib.version < (0, 1, 72):
self.skipTest("Schur LAPACK wrapper only implemented for jaxlib versions >= 0.1.72")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]

@@ -1470,8 +1468,6 @@ def testSchur(self, shape, dtype):
for dtype in float_types + complex_types))
@jtu.skip_on_devices("gpu", "tpu")
def testSchurBatching(self, shape, dtype):
if jax._src.lib.version < (0, 1, 72):
self.skipTest("Schur LAPACK wrapper only implemented for jaxlib versions >= 0.1.72")
rng = jtu.rand_default(self.rng())
batch_size = 10
shape = (batch_size, ) + shape
4 changes: 0 additions & 4 deletions tests/pickle_test.py
@@ -35,8 +35,6 @@
class CloudpickleTest(jtu.JaxTestCase):

@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@unittest.skipIf(jax._src.lib._xla_extension_version < 31,
"Requires jaxlib 0.1.71")
def testPickleOfJittedFunctions(self):

@jax.jit
@@ -56,8 +54,6 @@ def g(z):
self.assertEqual(expected, actual)

@unittest.skipIf(cloudpickle is None, "Requires cloudpickle")
@unittest.skipIf(jax._src.lib._xla_extension_version < 39,
"Requires jaxlib 0.1.72")
def testPickleOfPmappedFunctions(self):

@jax.pmap
1 change: 0 additions & 1 deletion tests/pjit_test.py
@@ -266,7 +266,6 @@ def testNested(self):
self.assertTrue(hasattr(y, "sharding_spec"))

@check_1d_2d_mesh(set_mesh=True)
@unittest.skipIf(jax._src.lib.version < (0, 1, 72), "Needs jaxlib 0.1.72+")
def testAutodiff(self, mesh, resources):
if len(mesh) != 2: return
assert resources == ('x', 'y')
