Skip to content

Commit

Permalink
Merge pull request jax-ml#293 from superbobry:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675154784
  • Loading branch information
The jax_triton Authors committed Sep 16, 2024
2 parents c5027d6 + 50f4af1 commit 973e106
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 31 deletions.
18 changes: 10 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ on:
push:
branches:
- main
permissions:
contents: write
pull_request:
branches:
- main

jobs:
deploy:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
with:
python-version: 3.x
- run: pip install -r docs/requirements.txt
- run: mkdocs gh-deploy --force
python-version: '3.10'
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/[email protected]
20 changes: 20 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: docs
on:
push:
branches:
- main

permissions:
contents: write

jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # ratchet:actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5
with:
python-version: '3.10'
- run: pip install -r docs/requirements.txt
- run: mkdocs gh-deploy --force
31 changes: 31 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Install the pre-commit hooks below with
# 'pre-commit install'

# Auto-update the version of the hooks with
# 'pre-commit autoupdate'

# Run the hooks on all files with
# 'pre-commit run --all'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0
hooks:
- id: check-ast
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
# only include python files
files: \.py$
- id: debug-statements
# only include python files
files: \.py$
- id: trailing-whitespace
# only include python files
files: \.py$

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1
hooks:
- id: ruff
5 changes: 2 additions & 3 deletions examples/block_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import jax
import jax.numpy as jnp
from jax import lax
from jax import random
import jax_triton as jt
from jax_triton import pallas as pl
Expand Down Expand Up @@ -178,8 +177,8 @@ def main(unused_argv):
k = random.normal(k_key, shape, dtype=dtype)
v = random.normal(v_key, shape, dtype=dtype)

o = mha(q, k, v).block_until_ready()
o_ref = mha_reference(q, k, v).block_until_ready()
mha(q, k, v).block_until_ready()
mha_reference(q, k, v).block_until_ready()

if __name__ == "__main__":
from absl import app
Expand Down
5 changes: 2 additions & 3 deletions examples/pallas/blocksparse_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from jax import random
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np

import jax_triton as jt
Expand Down Expand Up @@ -87,7 +86,7 @@ def tree_unflatten(cls, data, xs):
return BlockELL(blocks, blocks_per_row, indices, shape=shape)

def _validate(self):
nblocks, n, m = self.blocks.shape
_nblocks, n, m = self.blocks.shape
nrows = self.blocks_per_row.shape[0]
assert self.indices.shape[0] == nrows
assert len(self.shape) == 2
Expand Down Expand Up @@ -168,7 +167,7 @@ def sdd_matmul(x_ell, y, num_warps: int = 8, num_stages: int = 3, bn: int = 64,
grid = (jt.cdiv(m, bm), jt.cdiv(n, bn))

kernel = functools.partial(sdd_kernel, bm=bm, bn=bn)
out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=x.dtype)
out_shape = jax.ShapeDtypeStruct(shape=(m, n), dtype=x_ell.dtype)
return pl.pallas_call(kernel, num_warps=num_warps, num_stages=num_stages,
grid=grid, out_shape=out_shape,
debug=debug)(x_ell.blocks, x_ell.indices,
Expand Down
19 changes: 9 additions & 10 deletions examples/pallas/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@
import functools
import timeit

from typing import Optional, Tuple

import jax.numpy as jnp
from jax import random
import jax
from jax import lax
from jax._src.lax.control_flow import for_loop
import jax.numpy as jnp
import numpy as np

import jax_triton as jt
Expand Down Expand Up @@ -188,13 +185,15 @@ def main(unused_argv):
x = random.normal(x_key, (batch_size, feature_size), dtype)
h = random.normal(h_key, (batch_size, hidden_size), dtype)
c = random.normal(c_key, (batch_size, hidden_size), dtype)
lstm_cell = jax.jit(functools.partial(lstm_cell,
block_batch=block_batch,
block_hidden=block_hidden,
block_features=block_features,
num_warps=num_warps,
num_stages=num_stages))
y, c_next = jax.block_until_ready(lstm_cell(weights, x, h, c))
lstm_cell_fn = jax.jit(functools.partial(
lstm_cell,
block_batch=block_batch,
block_hidden=block_hidden,
block_features=block_features,
num_warps=num_warps,
num_stages=num_stages,
))
y, c_next = jax.block_until_ready(lstm_cell_fn(weights, x, h, c))
y_ref, c_next_ref = lstm_cell_reference(weights, x, h, c)
np.testing.assert_allclose(y, y_ref, atol=0.05, rtol=0.05)
np.testing.assert_allclose(c_next, c_next_ref, atol=0.05, rtol=0.05)
Expand Down
11 changes: 11 additions & 0 deletions jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
# limitations under the License.

"""Library for JAX-Triton integrations."""

__all__ = [
"utils",
"triton_call",
"cdiv",
"next_power_of_2",
"strides_from_shape",
"__version__",
"__version_info__",
]

import jaxlib
from jax._src.lib import gpu_triton
from jax_triton import utils
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
jax.nn.sigmoid = sigmoid
del sigmoid, oryx, jax

from jax_triton.experimental.fusion.lowering import jit
from jax_triton.experimental.fusion.lowering import jit as jit
3 changes: 1 addition & 2 deletions jax_triton/experimental/fusion/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from jax import lax
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import core
from jax._src import util
from jax._src.lax.control_flow import for_loop
Expand Down Expand Up @@ -251,6 +250,7 @@ def _matmul_elementwise_lowering_rule(x, y, *args, left_ops, right_ops, out_ops,
bias, = args
else:
bias = None
del bias # TODO(sharadmv): Please fix or remove `bias` above.
lhs_dim, rhs_dim = contract_dims
M, N, K = x.shape[1 - lhs_dim], y.shape[1 - rhs_dim], x.shape[lhs_dim]
assert x.shape[lhs_dim] == y.shape[rhs_dim]
Expand Down Expand Up @@ -340,4 +340,3 @@ def _dot_general_lowering_rule(x, y, dimension_numbers, **_):
out_ops=[], contract_dims=(lhs_dim,
rhs_dim))
lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule

2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/jaxpr_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import dataclasses
import itertools as it

from typing import Any, Callable, Dict, List, Set, Tuple, Union
from typing import Any, Callable, List, Tuple, Union

from jax._src import core as jax_core
import jax.numpy as jnp
Expand Down
2 changes: 0 additions & 2 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,6 @@ def compile_ttir_to_hsaco_inplace(
amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options)
hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options)

if hip_options.debug:
print(x)
name = metadata["name"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
Expand Down
7 changes: 6 additions & 1 deletion jax_triton/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
# limitations under the License.

"""Contains utilities for writing and calling Triton functions."""


__all__ = ["cdiv", "strides_from_shape", "next_power_of_2"]


from jax.experimental.pallas import cdiv
from jax.experimental.pallas import strides_from_shape
from jax.experimental.pallas import next_power_of_2
from jax.experimental.pallas import next_power_of_2
22 changes: 22 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,25 @@ packages = ["jax_triton"]

[tool.setuptools.dynamic]
version = {attr = "jax_triton.version.__version__"}

[tool.ruff]
preview = true
exclude = [
".git",
"build",
"__pycache__",
"*.ipynb",
]
line-length = 88
indent-width = 2
target-version = "py310"

[tool.ruff.lint]
ignore = [
# Do not assign a `lambda` expression, use a `def`
"E731",
# Module level import not at top of file
"E402",
# Ambiguous variable name
"E741",
]

0 comments on commit 973e106

Please sign in to comment.