Skip to content

Commit

Permalink
setup ci for jax-metal plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
shuhand0 committed Mar 6, 2024
1 parent fc8dc83 commit 9bdad4d
Show file tree
Hide file tree
Showing 3 changed files with 1,727 additions and 0 deletions.
42 changes: 42 additions & 0 deletions .github/workflows/metal_plugin_ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# JAX-Metal plugin CI

name: Jax-Metal CI
on:
workflow_dispatch: # allows triggering the workflow run manually

jobs:
jax-metal-plugin-test:

strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["plugin_latest"]
name: "Jax-Metal plugin test (jaxlib=${{ matrix.jaxlib-version }})"
runs-on: [self-hosted, macOS, ARM64]

steps:
- name: Get repo
uses: actions/checkout@v4
with:
path: jax
- name: Setup build and test enviroment
run: |
rm -rf ${GITHUB_WORKSPACE}/jax-metal-venv
python3 -m venv ${GITHUB_WORKSPACE}/jax-metal-venv
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
pip install -U pip numpy wheel
pip install jax-metal absl-py pytest
if [[ "${{ matrix.jaxlib-version }}" == "nightly" ]]; then
pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
fi;
cd jax
pip install .
- name: Run test
run: |
source ${GITHUB_WORKSPACE}/jax-metal-venv/bin/activate
export ENABLE_PJRT_COMPATIBILITY=1
cd jax
pytest tests/lax_metal_test.py
4 changes: 4 additions & 0 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ def supported_dtypes():
elif device_under_test() == "iree":
types = {np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16,
np.uint32, np.float32}
elif device_under_test() == "METAL":
types = {np.int32, np.uint32, np.float32}
else:
types = {np.bool_, np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
Expand Down Expand Up @@ -423,6 +425,8 @@ def _get_device_tags():
device_tags = {device_under_test(), "rocm"}
elif is_device_cuda():
device_tags = {device_under_test(), "cuda"}
elif device_under_test() == "METAL":
device_tags = {device_under_test(), "gpu"}
else:
device_tags = {device_under_test()}
return device_tags
Expand Down
Loading

0 comments on commit 9bdad4d

Please sign in to comment.