Enable patching Fused SDPA (#44)
nirda7 authored Dec 5, 2024
1 parent 070591a commit 41ff369
Showing 2 changed files with 38 additions and 9 deletions.
12 changes: 3 additions & 9 deletions vllm_hpu_extension/ops.py
@@ -24,13 +24,6 @@
 except ImportError:
     logger.warning("Could not import HPU FusedRMSNorm kernel. "
                    "vLLM will use forward_native implementation of RMSNorm.")
-HPUFusedSDPA = None
-try:
-    from habana_frameworks.torch.hpex.kernels import FusedSDPA
-    HPUFusedSDPA = FusedSDPA
-except ImportError:
-    logger.warning("Could not import HPU FusedSDPA kernel. "
-                   "vLLM will use native implementation.")
 
 
 def grouped_max(block_max, batch_size, block_groups):
@@ -176,13 +169,14 @@ def prompt_attention(
     softmax_op=torch.softmax,
     matmul_av_op=torch.matmul,
     valid_seq_lengths: Optional[torch.Tensor] = None,
+    fsdpa_op = None,
 ) -> torch.Tensor:
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
     query_heads = query.size(1)
     kv_heads = key.size(1)
-    if attn_bias is not None or HPUFusedSDPA is None:
+    if attn_bias is not None or fsdpa_op is None:
         if query_heads != kv_heads:
             query = query.unflatten(1, (kv_heads, -1))
             key = key.unflatten(1, (kv_heads, 1))
@@ -209,7 +203,7 @@ def prompt_attention(
             value = repeat_kv(value, int(query_heads // kv_heads))
         softmax_mode = 'fast'
         recompute_mode = True
-        attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
+        attn_weights = fsdpa_op(query, key, value, None, 0.0, True,
                                 scale, softmax_mode, recompute_mode,
                                 valid_seq_lengths, 'right')
     attn_weights = attn_weights.transpose(1, 2)
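
With the module-level FusedSDPA import removed from ops.py, prompt_attention now receives the kernel through the new fsdpa_op argument and falls back to its pure-torch path when it is None. A minimal sketch of that fallback call (not part of this commit; it assumes prompt_attention also exposes query, key, value, attn_bias and scale as keyword arguments, since only the tail of the signature is visible in this hunk):

import torch
from vllm_hpu_extension.ops import prompt_attention

# Dummy inputs, (batch, seq_len, num_heads, head_dim); prompt_attention
# transposes to (batch, num_heads, seq_len, head_dim) internally.
q = torch.randn(2, 16, 8, 64)
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)

# fsdpa_op=None selects the matmul/softmax fallback, so no Habana kernel
# (and no HPU device) is needed for this call.
out = prompt_attention(query=q, key=k, value=v,
                       attn_bias=None,
                       scale=64 ** -0.5,
                       fsdpa_op=None)
print(out.shape)  # expected: torch.Size([2, 16, 8, 64])
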
35 changes: 35 additions & 0 deletions vllm_hpu_extension/utils.py
@@ -67,3 +67,38 @@ def fetch_from_cache(self, cache, blocks):
         else:
             return cache.index_select(0, blocks)
 
+
+class ModuleFusedSDPA(torch.nn.Module):
+    def __init__(self, fusedSDPA):
+        super().__init__()
+        assert fusedSDPA is not None, f'fusedSDPA kernel is None'
+        self._hpu_kernel_fsdpa = fusedSDPA
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        softmax_mode,
+        recompute_mode,
+        valid_sequence_lengths,
+        padding_side="left",
+    ):
+        return self._hpu_kernel_fsdpa.apply(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            softmax_mode,
+            recompute_mode,
+            valid_sequence_lengths,
+            padding_side,
+        )
+
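To exercise the fused path, a caller wraps the Habana kernel in ModuleFusedSDPA and hands it to prompt_attention. A hedged sketch of that wiring on an HPU host (not part of this diff; the try/except placement, device handling, and valid_seq_lengths dtype are assumptions of this example):

import torch
from vllm_hpu_extension.ops import prompt_attention
from vllm_hpu_extension.utils import ModuleFusedSDPA

try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
    fsdpa_op = ModuleFusedSDPA(FusedSDPA)  # forward() delegates to FusedSDPA.apply(...)
except ImportError:
    fsdpa_op = None  # prompt_attention will use its native matmul/softmax path

# The fused branch is taken only when attn_bias is None and fsdpa_op is set;
# tensors should live on the HPU device in that case.
device = "hpu" if fsdpa_op is not None else "cpu"
q = torch.randn(2, 16, 8, 64, device=device)
k = torch.randn(2, 16, 8, 64, device=device)
v = torch.randn(2, 16, 8, 64, device=device)
seq_lens = torch.full((2,), 16, dtype=torch.int32, device=device)  # dtype assumed

out = prompt_attention(query=q, key=k, value=v,
                       attn_bias=None,
                       scale=64 ** -0.5,
                       valid_seq_lengths=seq_lens,
                       fsdpa_op=fsdpa_op)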