Skip to content

Commit

Permalink
[PJRT C API] Bump the minimum support libtpu version because there is…
Browse files Browse the repository at this point in the history
… a breaking change (openxla/xla@075d25e).

Also remove skip condition that are no longer needed because of this bump.

PiperOrigin-RevId: 611288492
  • Loading branch information
Jieying Luo authored and jax authors committed Feb 29, 2024
1 parent 550ce44 commit 4c57d09
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
tpu-type: ["v3-8", "v4-8", "v5e-4"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20231030
LIBTPU_OLDEST_VERSION_DATE: 20240228
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
timeout-minutes: 120
Expand Down
3 changes: 0 additions & 3 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import functools
import math
from absl.testing import absltest
Expand Down Expand Up @@ -1080,8 +1079,6 @@ class ActivationOffloadingTest(jtu.JaxTestCase):
def setUp(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest("Memories do not work on CPU and GPU backends yet.")
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
self.skipTest("Memories do not work on Cloud TPU older than 2024/02/23.")
super().setUp()

def test_remat_jaxpr_offloadable(self):
Expand Down
11 changes: 0 additions & 11 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

"""Test TPU-specific extensions to pallas_call."""

import datetime
import functools
from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -49,8 +48,6 @@ def setUp(self):
super().setUp()
if not self.interpret and jtu.device_under_test() != 'tpu':
self.skipTest('Only interpret mode supported on non-TPU')
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 10)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/10.')

def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
Expand Down Expand Up @@ -346,8 +343,6 @@ def dynamic_kernel(steps):

# TODO(apaszke): Add tests for scalar_prefetch too
def test_dynamic_grid_scalar_input(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 14)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/14.')
shape = (8, 128)
result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)

Expand Down Expand Up @@ -441,9 +436,6 @@ def dynamic_kernel(x, steps):
)

def test_num_programs(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')

def kernel(y_ref):
y_ref[0, 0] = pl.num_programs(0)

Expand All @@ -459,9 +451,6 @@ def dynamic_kernel(steps):
self.assertEqual(dynamic_kernel(4), 8)

def test_num_programs_block_spec(self):
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 27)):
self.skipTest('Does not work on Cloud TPU older than 2024/02/27.')

def kernel(x_ref, y_ref):
y_ref[...] = x_ref[...]

Expand Down
3 changes: 0 additions & 3 deletions tests/shard_alike_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

import jax
Expand Down Expand Up @@ -67,8 +66,6 @@ def setUp(self):
super().setUp()
if xla_extension_version < 227:
self.skipTest('Requires xla_extension_version >= 227')
if not jtu.if_cloud_tpu_at_least(datetime.date(2024, 2, 23)):
self.skipTest("Requires Cloud TPU older than 2024/02/23.")

def test_basic(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
Expand Down

0 comments on commit 4c57d09

Please sign in to comment.