Adding VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE and VLLM_REMOVE_REPEAT_KV_CACHE_SPLIT_GRAPHS (#43)
iboiko-habana authored Nov 29, 2024
1 parent 50e10ea commit bc01901
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions vllm_hpu_extension/ops.py
@@ -11,6 +11,7 @@
 import os
 import torch.nn.functional as F
 import math
+import habana_frameworks.torch.core as htcore
 
 from vllm.logger import init_logger
 from vllm_hpu_extension.capabilities import capabilities
@@ -146,6 +147,19 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     d = x.shape[-1] // 2
     return F.silu(x[..., :d]) * x[..., d:]
 
+# TODO: remove after fusedsdpa fix for query_heads != kv_heads
+def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(kv, dim=1, repeats=n_rep).
+    The kv tensor goes from (batch, num_key_value_heads, seqlen, head_dim) to
+    (batch, num_attention_heads, seqlen, head_dim).
+    """
+    batch, num_key_value_heads, slen, head_dim = kv.shape
+    if n_rep == 1:
+        return kv
+    kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
+                                     head_dim)
+    return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 def prompt_attention(
     query: torch.Tensor,
@@ -179,6 +193,15 @@ def prompt_attention(
         if query_heads != kv_heads:
             attn_weights = attn_weights.flatten(1, 2)
     else:
+        VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
+        VLLM_REMOVE_REPEAT_KV_CACHE_SPLIT_GRAPHS = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE_SPLIT_GRAPHS', '0') == '1'
+        # TODO: remove after fusedsdpa fix for query_heads != kv_heads
+        if query_heads != kv_heads:
+            if VLLM_REMOVE_REPEAT_KV_CACHE_SPLIT_GRAPHS:
+                htcore.mark_step()
+            if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
+                key = repeat_kv(key, int(query_heads // kv_heads))
+                value = repeat_kv(value, int(query_heads // kv_heads))
         softmax_mode = 'fast'
         recompute_mode = True
         attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
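Not part of the commit, but as a reading aid: a minimal standalone sketch checking that the expand-and-reshape in repeat_kv matches torch.repeat_interleave along the head dimension. The shapes are made up for illustration.

import torch

# Illustrative shapes: batch=2, kv_heads=4, seqlen=8, head_dim=16, n_rep=3
batch, kv_heads, seqlen, head_dim, n_rep = 2, 4, 8, 16, 3
kv = torch.randn(batch, kv_heads, seqlen, head_dim)

# expand + reshape, exactly as in the patched repeat_kv
expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seqlen, head_dim)
repeated = expanded.reshape(batch, kv_heads * n_rep, seqlen, head_dim)

# matches torch.repeat_interleave on the head dimension, with each KV head
# repeated n_rep times consecutively
assert torch.equal(repeated, torch.repeat_interleave(kv, repeats=n_rep, dim=1))
print(repeated.shape)  # torch.Size([2, 12, 8, 16])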

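Likewise a hedged sketch, not part of the commit, of the control flow the two new flags introduce inside prompt_attention. The tensor shapes and head counts are stand-ins, torch.repeat_interleave substitutes for the patch's repeat_kv (equivalent, as shown above), and the htcore.mark_step() call is stubbed out so the sketch runs without habana_frameworks installed.

import os
import torch

# Illustrative stand-ins for the values inside prompt_attention.
query_heads, kv_heads = 32, 8
key = torch.randn(1, kv_heads, 128, 64)
value = torch.randn(1, kv_heads, 128, 64)

# Same parsing as the patch: repeating defaults to on, graph splitting to off.
do_not_remove_repeat = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
split_graphs = os.environ.get('VLLM_REMOVE_REPEAT_KV_CACHE_SPLIT_GRAPHS', '0') == '1'

if query_heads != kv_heads:  # grouped-query attention case
    if split_graphs:
        # On Gaudi this is htcore.mark_step(), which breaks the HPU graph
        # before the expansion; omitted here to keep the sketch portable.
        pass
    if do_not_remove_repeat:
        key = torch.repeat_interleave(key, query_heads // kv_heads, dim=1)
        value = torch.repeat_interleave(value, query_heads // kv_heads, dim=1)

print(key.shape)  # torch.Size([1, 32, 128, 64]) when repeating is enabled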