From 03ce8ca0ca2618804664804f7211318fadf0d60c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 17 Jan 2024 12:53:24 -0800 Subject: [PATCH] jax.random: deprecate passing of batched keys to APIs --- CHANGELOG.md | 3 + jax/_src/random.py | 111 ++++++++++-------- .../jax2tf/tests/shape_poly_test.py | 2 +- tests/key_reuse_test.py | 2 +- tests/random_lax_test.py | 28 +++++ tests/shape_poly_test.py | 2 +- 6 files changed, 96 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2deb275b6e0f..97700970dec0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,9 @@ Remember to align the itemized text with the first line of an item within a list * Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be removed in the future. Use the "stablehlo" dialect instead. + * {mod}`jax.random`: passing batched keys directly to random number generation functions, + such as {func}`~jax.random.bits`, {func}`~jax.random.gamma`, and others, is deprecated + and will emit a `FutureWarning`. Use `jax.vmap` for explicit batching. ## jaxlib 0.4.24 diff --git a/jax/_src/random.py b/jax/_src/random.py index bf16cd80b335..01bdcc6a0958 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -69,12 +69,18 @@ def _isnan(x: ArrayLike) -> Array: return lax.ne(x, x) -def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]: +# TODO(jakevdp) Finalize batched input deprecation by setting error_on_batched=True. +# FutureWarning Added 2024-01-17 +def _check_prng_key(name: str, key: KeyArrayLike, *, + allow_batched: bool = False, + error_on_batched: bool = False) -> tuple[KeyArray, bool]: if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key): - return key, False + wrapped_key = key + wrapped = False elif _arraylike(key): # Call random_wrap here to surface errors for invalid keys. wrapped_key = prng.random_wrap(key, impl=default_prng_impl()) + wrapped = True if config.legacy_prng_key.value == 'error': raise ValueError( 'Legacy uint32 key array passed as key to jax.random function. ' @@ -91,10 +97,20 @@ def _check_prng_key(key: KeyArrayLike) -> tuple[KeyArray, bool]: 'Raw arrays as random keys to jax.random functions are deprecated. ' 'Assuming valid threefry2x32 key for now.', FutureWarning) - return wrapped_key, True else: raise TypeError(f'unexpected PRNG key type {type(key)}') + if (not allow_batched) and wrapped_key.ndim: + msg = (f"{name} accepts a single key, but was given a key array of " + f"shape {np.shape(key)} != (). Use jax.vmap for batching.") + if error_on_batched: + raise ValueError(msg) + else: + warnings.warn(msg + " In a future JAX version, this will be an error.", + FutureWarning, stacklevel=3) + + return wrapped_key, wrapped + def _return_prng_keys(was_wrapped, key): # TODO(frostig): remove once we always enable_custom_prng @@ -245,10 +261,7 @@ def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray: A new PRNG key that is a deterministic function of the inputs and is statistically safe for producing a stream of new pseudo-random values. """ - key, wrapped = _check_prng_key(key) - if np.ndim(key): - raise TypeError("fold_in accepts a single key, but was given a key array of" - f"shape {np.shape(key)} != (). Use jax.vmap for batching.") + key, wrapped = _check_prng_key("fold_in", key, error_on_batched=True) if np.ndim(data): raise TypeError("fold_in accepts a scalar, but was given an array of" f"shape {np.shape(data)} != (). Use jax.vmap for batching.") @@ -262,7 +275,7 @@ def _split(key: KeyArray, num: int | tuple[int, ...] = 2) -> KeyArray: # to always enable_custom_prng assert jnp.issubdtype(key.dtype, dtypes.prng_key) if key.ndim: - raise TypeError("split accepts a single key, but was given a key array of" + raise TypeError("split accepts a single key, but was given a key array of " f"shape {key.shape} != (). Use jax.vmap for batching.") shape = tuple(num) if isinstance(num, Sequence) else (num,) return prng.random_split(key, shape=shape) @@ -278,7 +291,7 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray: Returns: An array-like object of `num` new PRNG keys. """ - typed_key, wrapped = _check_prng_key(key) + typed_key, wrapped = _check_prng_key("split", key, error_on_batched=True) return _return_prng_keys(wrapped, _split(typed_key, num)) @@ -288,7 +301,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl: return keys_dtype._impl def key_impl(keys: KeyArrayLike) -> Hashable: - typed_keys, _ = _check_prng_key(keys) + typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) return PRNGSpec(_key_impl(typed_keys)) @@ -298,7 +311,7 @@ def _key_data(keys: KeyArray) -> Array: def key_data(keys: KeyArrayLike) -> Array: """Recover the bits of key data underlying a PRNG key array.""" - keys, _ = _check_prng_key(keys) + keys, _ = _check_prng_key("key_data", keys, allow_batched=True) return _key_data(keys) @@ -350,7 +363,7 @@ def bits(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("bits", key) if dtype is None: dtype = dtypes.canonicalize_dtype(jnp.uint) else: @@ -383,7 +396,7 @@ def uniform(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("uniform", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): @@ -452,7 +465,7 @@ def randint(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("randint", key) dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) @@ -535,7 +548,7 @@ def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array: msg = ("jax.random.shuffle is deprecated and will be removed in a future release. " "Use jax.random.permutation with independent=True.") warnings.warn(msg, FutureWarning) - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("shuffle", key) return _shuffle(key, x, axis) # type: ignore @@ -556,7 +569,7 @@ def permutation(key: KeyArrayLike, Returns: A shuffled version of x or array range """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("permutation", key) check_arraylike("permutation", x) axis = canonicalize_axis(axis, np.ndim(x) or 1) if not np.ndim(x): @@ -630,7 +643,7 @@ def choice(key: KeyArrayLike, Returns: An array of shape `shape` containing samples from `a`. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("choice", key) if not isinstance(shape, Sequence): raise TypeError("shape argument of jax.random.choice must be a sequence, " f"got {shape}") @@ -697,7 +710,7 @@ def normal(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("normal", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `normal` must be a float or complex dtype, " @@ -764,7 +777,7 @@ def multivariate_normal(key: KeyArrayLike, ``shape + mean.shape[-1:]`` if ``shape`` is not None, or else ``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("multivariate_normal", key) dtypes.check_user_dtype_supported(dtype) mean, cov = promote_dtypes_inexact(mean, cov) if method not in {'svd', 'eigh', 'cholesky'}: @@ -843,7 +856,7 @@ def truncated_normal(key: KeyArrayLike, ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``. Returns values in the open interval ``(lower, upper)``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("truncated_normal", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `truncated_normal` must be a float " @@ -901,7 +914,7 @@ def bernoulli(key: KeyArrayLike, A random array with boolean dtype and shape given by ``shape`` if ``shape`` is not None, or else ``p.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("bernoulli", key) dtype = dtypes.canonicalize_dtype(lax.dtype(p)) if shape is not None: shape = core.as_named_shape(shape) @@ -952,7 +965,7 @@ def beta(key: KeyArrayLike, A random array with the specified dtype and shape given by ``shape`` if ``shape`` is not None, or else by broadcasting ``a`` and ``b``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("beta", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `beta` must be a float " @@ -1005,7 +1018,7 @@ def cauchy(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("cauchy", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `cauchy` must be a float " @@ -1057,7 +1070,7 @@ def dirichlet(key: KeyArrayLike, ``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else ``alpha.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("dirichlet", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `dirichlet` must be a float " @@ -1116,7 +1129,7 @@ def exponential(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("exponential", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `exponential` must be a float " @@ -1297,7 +1310,7 @@ def gamma(key: KeyArrayLike, loggamma : sample gamma values in log-space, which can provide improved accuracy for small values of ``a``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("gamma", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " @@ -1339,7 +1352,7 @@ def loggamma(key: KeyArrayLike, See Also: gamma : standard gamma sampler. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("loggamma", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gamma` must be a float " @@ -1475,7 +1488,7 @@ def poisson(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape is not None, or else by ``lam.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("poisson", key) dtypes.check_user_dtype_supported(dtype) # TODO(frostig): generalize underlying poisson implementation and # remove this check @@ -1515,7 +1528,7 @@ def gumbel(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("gumbel", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `gumbel` must be a float " @@ -1550,7 +1563,7 @@ def categorical(key: KeyArrayLike, A random array with int dtype and shape given by ``shape`` if ``shape`` is not None, or else ``np.delete(logits.shape, axis)``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("categorical", key) check_arraylike("categorical", logits) logits_arr = jnp.asarray(logits) @@ -1593,7 +1606,7 @@ def laplace(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("laplace", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `laplace` must be a float " @@ -1630,7 +1643,7 @@ def logistic(key: KeyArrayLike, Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("logistic", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `logistic` must be a float " @@ -1673,7 +1686,7 @@ def pareto(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``b.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("pareto", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `pareto` must be a float " @@ -1722,7 +1735,7 @@ def t(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``df.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("t", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `t` must be a float " @@ -1775,7 +1788,7 @@ def chisquare(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``df.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("chisquare", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `chisquare` must be a float " @@ -1833,7 +1846,7 @@ def f(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``df.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("f", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `f` must be a float " @@ -1885,7 +1898,7 @@ def rademacher(key: KeyArrayLike, a 50% change of being 1 or -1. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("rademacher", key) dtypes.check_user_dtype_supported(dtype) dtype = dtypes.canonicalize_dtype(dtype) shape = core.canonicalize_shape(shape) @@ -1921,7 +1934,7 @@ def maxwell(key: KeyArrayLike, """ # Generate samples using: # sqrt(X^2 + Y^2 + Z^2), X,Y,Z ~N(0,1) - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("maxwell", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `maxwell` must be a float " @@ -1964,7 +1977,7 @@ def double_sided_maxwell(key: KeyArrayLike, A jnp.array of samples. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("double_sided_maxwell", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `double_sided_maxwell` must be a float" @@ -2016,7 +2029,7 @@ def weibull_min(key: KeyArrayLike, A jnp.array of samples. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("weibull_min", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError(f"dtype argument to `weibull_min` must be a float " @@ -2055,7 +2068,7 @@ def orthogonal( Returns: A random array of shape `(*shape, n, n)` and specified dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("orthogonal", key) dtypes.check_user_dtype_supported(dtype) _check_shape("orthogonal", shape) n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()") @@ -2090,7 +2103,7 @@ def generalized_normal( Returns: A random array with the specified shape and dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("generalized_normal", key) dtypes.check_user_dtype_supported(dtype) _check_shape("generalized_normal", shape) keys = split(key) @@ -2120,7 +2133,7 @@ def ball( Returns: A random array of shape `(*shape, d)` and specified dtype. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("ball", key) dtypes.check_user_dtype_supported(dtype) _check_shape("ball", shape) d = core.concrete_or_error(index, d, "The error occurred in jax.random.ball()") @@ -2158,7 +2171,7 @@ def rayleigh(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``scale.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("rayleigh", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `rayleigh` must be a float " @@ -2212,7 +2225,7 @@ def wald(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``mean.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("wald", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `wald` must be a float " @@ -2268,7 +2281,7 @@ def geometric(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``p.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("geometric", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.integer): raise ValueError("dtype argument to `geometric` must be an int " @@ -2330,7 +2343,7 @@ def triangular(key: KeyArrayLike, A random array with the specified dtype and with shape given by ``shape`` if ``shape`` is not None, or else by ``left.shape``, ``mode.shape`` and ``right.shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("triangular", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): raise ValueError("dtype argument to `triangular` must be a float " @@ -2384,7 +2397,7 @@ def lognormal(key: KeyArrayLike, Returns: A random array with the specified dtype and with shape given by ``shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("lognormal", key) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.inexact): raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, " @@ -2597,7 +2610,7 @@ def binomial( A random array with the specified dtype and with shape given by ``np.broadcast(n, p).shape``. """ - key, _ = _check_prng_key(key) + key, _ = _check_prng_key("binomial", key) check_arraylike("binomial", n, p) dtypes.check_user_dtype_supported(dtype) if not dtypes.issubdtype(dtype, np.floating): diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 44d54fae23b1..e5fb13454b6a 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2108,7 +2108,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] # non-partitionable), and unsafe_rbg. [ PolyHarness("random_gamma", f"{flags_name}", - lambda key, a: jax.random.gamma(key, a), + lambda key, a: jax.vmap(jax.random.gamma)(key, a), arg_descriptors=[RandArg((3, key_size), np.uint32), RandArg((3, 4, 5), _f32)], polymorphic_shapes=["b, ...", "b, w, ..."], tol=1E-5, override_jax_config_flags=override_jax_config_flags), # type: ignore diff --git a/tests/key_reuse_test.py b/tests/key_reuse_test.py index 705dc7baa7e6..d42217a4c78f 100644 --- a/tests/key_reuse_test.py +++ b/tests/key_reuse_test.py @@ -616,7 +616,7 @@ def test_reuse_after_broadcast(self): def f(): key = jax.random.key(0) key2 = key[None] - return jax.random.bits(key) + jax.random.bits(key2) + return jax.random.bits(key) + jax.vmap(jax.random.bits)(key2) with self.assertRaisesRegex(KeyReuseError, self.random_bits_error): self.check_key_reuse(f) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index ef18ebe96919..f7361f74330c 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -1247,6 +1247,34 @@ def testBinomialCornerCases(self): self.assertArraysAllClose(samples2, jnp.array([jnp.nan, 0., jnp.nan, jnp.nan]), check_dtypes=False) self.assertArraysAllClose(samples3, jnp.array([jnp.nan, jnp.nan, jnp.nan]), check_dtypes=False) + def test_batched_key_warnings(self): + keys = jax.random.split(self.make_key(0)) + msg = "{} accepts a single key, but was given a key array of shape.*" + + # Check a handful of functions that are expected to warn. + with self.assertWarnsRegex(FutureWarning, msg.format('bits')): + jax.random.bits(keys, shape=(2,)) + with self.assertWarnsRegex(FutureWarning, msg.format('chisquare')): + jax.random.chisquare(keys, 1.0, shape=(2,)) + with self.assertWarnsRegex(FutureWarning, msg.format('dirichlet')): + jax.random.dirichlet(keys, jnp.arange(2.0), shape=(2,)) + with self.assertWarnsRegex(FutureWarning, msg.format('gamma')): + jax.random.gamma(keys, 1.0, shape=(2,)) + with self.assertWarnsRegex(FutureWarning, msg.format('loggamma')): + jax.random.loggamma(keys, 1.0, shape=(2,)) + + # Other functions should error; test a few cases. + with self.assertRaisesRegex(ValueError, msg.format('fold_in')): + jax.random.fold_in(keys, 0) + with self.assertRaisesRegex(ValueError, msg.format('split')): + jax.random.split(keys) + + # Some shouldn't error or warn + with self.assertNoWarnings(): + jax.random.key_data(keys) + jax.random.key_impl(keys) + + threefry_seed = prng_internal.threefry_seed threefry_split = prng_internal.threefry_split threefry_random_bits = prng_internal.threefry_random_bits diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 9495d4672e24..59ee6d3e030b 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -1917,7 +1917,7 @@ def test_vmap_error(self): # non-partitionable), and unsafe_rbg. [ PolyHarness("random_gamma", f"{flags_name}", - lambda key, a: jax.random.gamma( + lambda key, a: jax.vmap(jax.random.gamma)( jax.random.wrap_key_data(key), a), arg_descriptors=[RandArg((3, key_size), np.uint32), RandArg((3, 4, 5), _f32)],