-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add experimental code for custom attention op #81
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright 2024 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from functools import partial | ||
|
||
from jax._src import core | ||
from jax._src import dispatch | ||
from jax._src.core import ShapedArray | ||
from jax._src.interpreters import mlir as jax_mlir | ||
from jax._src.typing import Array | ||
from jax.interpreters import mlir | ||
from jax.interpreters.mlir import ir | ||
from jaxlib.hlo_helpers import custom_call | ||
|
||
|
||
######################################### | ||
# Created Primitives for IREE attention # | ||
######################################### | ||
|
||
# JAX primitive for the IREE attention custom op. Binding this primitive
# (see _iree_attention below) dispatches to the lowering / abstract-eval
# rules registered later in this module.
iree_attention_p = core.Primitive('iree_attention')
# Concrete (eager) evaluation goes through JAX's standard primitive
# application machinery.
iree_attention_p.def_impl(partial(dispatch.apply_primitive, iree_attention_p))

# Flag baked into the custom call as its `transpose_v` attribute.
# NOTE(review): a module-level configuration global is fragile — consider
# threading it through iree_attention as a parameter instead.
transpose_v = False
|
||
|
||
def _check_rank(x, rank): | ||
if x.ndim != rank: | ||
raise ValueError(f'Expected {rank} dimensions, got {x.ndim}') | ||
|
||
|
||
def _iree_attention(
    query,
    key,
    value,
    scale,
):
    """Validate operand ranks and bind the iree_attention primitive.

    All three attention operands must be rank-3; `scale` is passed through
    unchecked to the primitive.
    """
    for operand in (query, key, value):
        _check_rank(operand, 3)
    return iree_attention_p.bind(query, key, value, scale)
|
||
#################### | ||
# Lowering to MLIR # | ||
#################### | ||
|
||
def iree_attention_lowering(
    ctx,
    query,
    key,
    value,
    scale,
):
    """Builds a custom IREE attentionOp.

    Lowers the `iree_attention` primitive to a StableHLO `custom_call`
    targeting the external `iree_attention` symbol.

    Args:
        ctx: lowering-rule context supplied by JAX (unused here).
        query: MLIR value for the query operand.
        key: MLIR value for the key operand.
        value: MLIR value for the value operand.
        scale: MLIR value for the softmax scale operand.

    Returns:
        The custom call's result values.
    """
    # The result type is taken to be identical to `query`'s ranked tensor
    # type — this mirrors the abstract-eval rule in this module, but should
    # be confirmed against the IREE op's actual output shape contract.
    rw = custom_call(
        'iree_attention',
        result_types=[ir.RankedTensorType(query.type)],
        operands=[query, key, value, scale],
        # The module-level `transpose_v` flag is attached as an op attribute.
        extra_attributes={'transpose_v': ir.BoolAttr.get(transpose_v)},
    )
    return rw.results
|
||
|
||
# Register the lowering rule for the experimental 'iree_cpu' platform only;
# other platforms have no lowering for this primitive.
mlir.register_lowering(
    iree_attention_p, iree_attention_lowering, platform='iree_cpu'
)
|
||
####################### | ||
# Abstract evaluation # | ||
####################### | ||
|
||
|
||
def _iree_attention_abstract_eval_rule(query, key, value, scale): | ||
return ShapedArray(query.shape, query.dtype) | ||
|
||
iree_attention_p.def_abstract_eval(_iree_attention_abstract_eval_rule) | ||
|
||
###################### | ||
# Top-level interface# | ||
###################### | ||
|
||
|
||
def iree_attention(
    query,
    key,
    value,
    scale,
) -> Array:
    """Public entry point for the IREE attention custom op.

    NOTE(review): models must call this explicitly; consider wrapping it in
    a dispatcher that falls back to a generic attention implementation so
    code can still run without IREE.
    """
    result = _iree_attention(query, key, value, scale)
    return result
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright 2022 The IREE Authors | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import attention | ||
from jax._src.interpreters import mlir as jax_mlir | ||
from jax._src.lib.mlir import ir | ||
from jax.experimental import export | ||
|
||
|
||
def export_iree_attention(query, key, value, scale):
    """Exports attention.iree_attention for the given inputs as MLIR text.

    Args:
        query: rank-3 query array.
        key: rank-3 key array.
        value: rank-3 value array.
        scale: scalar softmax scale.

    Returns:
        The serialized StableHLO module for the lowered attention op.
    """
    # BUG FIX: previously this read the module-level query_in/key_in/
    # value_in/scale_in globals, silently ignoring its own parameters.
    inputs = (query, key, value, scale)
    input_shapes = [
        jax.ShapeDtypeStruct(x.shape, x.dtype) for x in inputs
    ]
    # NOTE(review): consider a generic IREE export entry point (with an
    # allow-list of IREE custom calls) instead of calling
    # jax.experimental.export directly here.
    att = export.export(
        attention.iree_attention,
        lowering_platforms=['iree_cpu'],
        # 'iree_attention' is not a registered stable custom call, so its
        # export safety check must be disabled explicitly.
        disabled_checks=[
            export.DisabledSafetyCheck.custom_call('iree_attention')
        ],
    )(*input_shapes).mlir_module()
    return att
|
||
def get_asm(module_str):
    """Parses ``module_str`` as MLIR and returns its pretty-printed ASM.

    Elements of large constants beyond 20 are elided to keep output short.
    """
    # BUG FIX: the original created two independent IR contexts — one for
    # the `with` block and a second fresh one passed to `parse` — leaving
    # the parsed module owned by a context other than the active one. Parse
    # in the single context that stays alive while printing.
    ctx = jax_mlir.make_ir_context()
    with ctx:
        stablehlo_module = ir.Module.parse(module_str, context=ctx)
        return stablehlo_module.operation.get_asm(large_elements_limit=20)
|
||
# Demo driver: example inputs with batch 1, sequence length 3, feature dim 4.
query_in = jnp.array(
    [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]
)
key_in = jnp.array(
    [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]
)
value_in = jnp.array(
    [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]
)
# Scalar softmax scale.
scale_in = jnp.float32(0.5)

# Export the attention op for these inputs and print the resulting MLIR ASM.
# NOTE(review): consider moving these module-level globals and this print
# under an `if __name__ == "__main__":` guard.
print(get_asm(export_iree_attention(query_in, key_in, value_in, scale_in)))
|
There was a problem hiding this comment. Choose a reason for hiding this comment; the reason will be displayed to describe this comment to others. Learn more.
Should this be a global here?