From 292bbf33574b6960f23c7b7c40ff443885cc7c70 Mon Sep 17 00:00:00 2001
From: Natasha Kononenko
Date: Thu, 6 Jun 2024 19:56:27 +0000
Subject: [PATCH 1/2] Add experimental folder with custom attention op
 skeleton

---
 experimental/attention.py      | 93 ++++++++++++++++++++++++++++++++++
 experimental/attention_call.py | 48 ++++++++++++++++++
 2 files changed, 141 insertions(+)
 create mode 100644 experimental/attention.py
 create mode 100644 experimental/attention_call.py

diff --git a/experimental/attention.py b/experimental/attention.py
new file mode 100644
index 0000000..c4a5ead
--- /dev/null
+++ b/experimental/attention.py
@@ -0,0 +1,93 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from functools import partial
+
+from jax._src import core
+from jax._src import dispatch
+from jax._src.core import ShapedArray
+from jax._src.interpreters import mlir as jax_mlir
+from jax._src.typing import Array
+from jax.interpreters import mlir
+from jax.interpreters.mlir import ir
+from jaxlib.hlo_helpers import custom_call
+
+
+#########################################
+# Created Primitives for IREE attention #
+#########################################
+
+iree_attention_p = core.Primitive('iree_attention')
+iree_attention_p.def_impl(partial(dispatch.apply_primitive, iree_attention_p))
+
+transpose_v = False
+
+
+def _check_rank(x, rank):
+  if x.ndim != rank:
+    raise ValueError(f'Expected {rank} dimensions, got {x.ndim}')
+
+
+def _iree_attention(
+    query,
+    key,
+    value,
+    scale,
+):
+  for x in [query, key, value]:
+    _check_rank(x, 3)
+  out = iree_attention_p.bind(query, key, value, scale)
+  return out
+
+
+####################
+# Lowering to MLIR #
+####################
+
+
+def iree_attention_lowering(
+    ctx,
+    query,
+    key,
+    value,
+    scale,
+):
+
+  """Builds a custom IREE attention op."""
+  rw = custom_call(
+      'iree_attention',
+      result_types=[ir.RankedTensorType(query.type)],
+      operands=[query, key, value, scale],
+      extra_attributes={'transpose_v': ir.BoolAttr.get(transpose_v)},
+  )
+  return rw.results
+
+mlir.register_lowering(
+    iree_attention_p, iree_attention_lowering, platform='iree_cpu'
+)  # Should this be iree?
+
+#######################
+# Abstract evaluation #
+#######################
+
+
+def _iree_attention_abstract_eval_rule(query, key, value, scale):
+  return ShapedArray(query.shape, query.dtype)
+
+iree_attention_p.def_abstract_eval(_iree_attention_abstract_eval_rule)
+
+#######################
+# Top-level interface #
+#######################
+
+
+def iree_attention(
+    query,
+    key,
+    value,
+    scale,
+) -> Array:
+  return _iree_attention(query, key, value, scale)
\ No newline at end of file
diff --git a/experimental/attention_call.py b/experimental/attention_call.py
new file mode 100644
index 0000000..79477f8
--- /dev/null
+++ b/experimental/attention_call.py
@@ -0,0 +1,48 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import jax
+import jax.numpy as jnp
+import attention
+from jax._src.interpreters import mlir as jax_mlir
+from jax._src.lib.mlir import ir
+from jax.experimental import export
+
+
+def export_iree_attention(query, key, value, scale):
+  inputs = (query, key, value, scale)
+  input_shapes = [
+      jax.ShapeDtypeStruct(input.shape, input.dtype) for input in inputs
+  ]
+  att = export.export(
+      attention.iree_attention,
+      lowering_platforms=['iree_cpu'],
+      disabled_checks=[
+          export.DisabledSafetyCheck.custom_call('iree_attention')
+      ],
+  )(*input_shapes).mlir_module()
+  return att
+
+
+def get_asm(module_str):
+  # Parse under the context opened by the `with` block, then pretty-print.
+  with jax_mlir.make_ir_context():
+    stablehlo_module = ir.Module.parse(module_str)
+    return stablehlo_module.operation.get_asm(large_elements_limit=20)
+
+
+query_in = jnp.array(
+    [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]
+)
+key_in = jnp.array(
+    [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]
+)
+value_in = jnp.array(
+    [[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]]]
+)
+scale_in = jnp.float32(0.5)
+
+print(get_asm(export_iree_attention(query_in, key_in, value_in, scale_in)))

From 15492b2d3e87598f1439dba29ce0393e2e4702ee Mon Sep 17 00:00:00 2001
From: Natasha Kononenko
Date: Thu, 6 Jun 2024 19:59:13 +0000
Subject: [PATCH 2/2] Add newline

---
 experimental/attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/experimental/attention.py b/experimental/attention.py
index c4a5ead..1f3a746 100644
--- a/experimental/attention.py
+++ b/experimental/attention.py
@@ -16,7 +16,6 @@
 from jaxlib.hlo_helpers import custom_call
 
-
 #########################################
 # Created Primitives for IREE attention #
 #########################################
 
@@ -67,7 +66,7 @@ def iree_attention_lowering(
 
 mlir.register_lowering(
     iree_attention_p, iree_attention_lowering, platform='iree_cpu'
-)  # Should this be iree?
+)
 
 #######################
 # Abstract evaluation #
@@ -90,4 +89,5 @@ def iree_attention(
     value,
     scale,
 ) -> Array:
-  return _iree_attention(query, key, value, scale)
\ No newline at end of file
+  return _iree_attention(query, key, value, scale)
+
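
Note: for the sample 1x3x4 inputs above, the ASM printed by attention_call.py
should boil down to a single stablehlo.custom_call. A sketch of the expected
output, assuming jaxlib's custom_call helper emits a stablehlo.custom_call op
and attaches extra_attributes verbatim (the exact attribute layout, api_version,
and export wrapper details may differ):

  func.func public @main(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>,
      %arg2: tensor<1x3x4xf32>, %arg3: tensor<f32>) -> tensor<1x3x4xf32> {
    // Query, key, value, and scale are forwarded unchanged to the custom op;
    // the result takes query's shape, per the abstract eval rule.
    %0 = stablehlo.custom_call @iree_attention(%arg0, %arg1, %arg2, %arg3)
        {transpose_v = false}
        : (tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<1x3x4xf32>, tensor<f32>)
        -> tensor<1x3x4xf32>
    return %0 : tensor<1x3x4xf32>
  }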