Skip to content

Commit

Permalink
Added a few missing compute capability checks to Pallas:GPU tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 608348004
  • Loading branch information
superbobry authored and jax authors committed Feb 19, 2024
1 parent dcc65e6 commit 46ec581
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
1 change: 0 additions & 1 deletion tests/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ jax_test(
disable_configs = [
"gpu",
"gpu_a100",
"gpu_p100",
],
enable_configs = [
"gpu_x32",
Expand Down
6 changes: 6 additions & 0 deletions tests/pallas/pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,8 @@ def f(x):
(2, 1, 1),
])
def test_atomic_cas(self, init_value, cmp, new_value):
if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest("requires a GPU with compute capability >= sm70")

@functools.partial(
self.pallas_call, out_shape=(
Expand All @@ -789,6 +791,10 @@ def swap(_, lock_ref, out_ref):
def test_atomic_counter(self, num_threads):
if self.INTERPRET:
self.skipTest("While loop not supported in interpreter mode.")

if not self.check_gpu_capability_at_least(70):
raise unittest.SkipTest("requires a GPU compute capability >= sm70")

@functools.partial(
self.pallas_call, out_shape=(
jax.ShapeDtypeStruct((), jnp.int32),
Expand Down

0 comments on commit 46ec581

Please sign in to comment.