forked from mosaicml/llm-foundry
Commit
Monkeypatch flash attention in for llama (mosaicml#520)
Showing 5 changed files with 404 additions and 1 deletion.
llmfoundry/models/layers/llama_attention_monkeypatch.py (283 changes: 283 additions & 0 deletions)
@@ -0,0 +1,283 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# This file is copied and modified from
# https://github.com/huggingface/transformers/blob/fe3c8ab1af558b95f67f5fafc0c55f09fd2b09db/src/transformers/models/llama/modeling_llama.py
# See the clearly denoted code blocks for the main modifications (there are a few others, like type ignores and error messages).

import logging
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F

from llmfoundry.models.layers.attention import (
    scaled_multihead_dot_product_attention, triton_flash_attn_fn)

log = logging.getLogger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                                 head_dim)


def rotate_half(x: torch.Tensor):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
                         sin: torch.Tensor, position_ids: torch.Tensor):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:
    if patch_fn_name == 'torch':
        return llama_attention_patch_torch
    elif patch_fn_name == 'triton':
        return llama_attention_patch_triton
    else:
        raise ValueError(
            f'Unrecognized llama attention patch function: {patch_fn_name}')


def llama_attention_patch_torch(
    self,  # type: ignore
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_cache:
        raise NotImplementedError(
            'use_cache is not yet supported when patching Llama attention.')

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads *
                             self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp,
            dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(  # type: ignore (thirdParty)
                hidden_states, query_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i])  # type: ignore (thirdParty)
            for i in range(self.config.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(  # type: ignore (thirdParty)
                hidden_states, value_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin,
        position_ids)  # type: ignore (thirdParty)

    ### MAIN MODIFICATIONS START HERE ###
    query_states = query_states.transpose(1, 2).view(
        bsz, q_len, self.num_heads * self.head_dim)
    key_states = key_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)
    value_states = value_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)
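    # NOTE: llm-foundry's attention functions consume flattened
    # (batch, seq_len, num_heads * head_dim) query/key/value tensors, which is
    # why the per-head dimension is merged back in above before the call below.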

    attn_output, attn_weights, _ = scaled_multihead_dot_product_attention(
        query=query_states,
        key=key_states,
        value=value_states,
        n_heads=self.num_heads,
        kv_n_heads=self.num_key_value_heads,
        past_key_value=None,
        softmax_scale=None,
        attn_bias=attention_mask,
        key_padding_mask=None,
        is_causal=False,  # The causal mask is propagated from LlamaForCausalLM
        dropout_p=0,
        training=self.training,
        needs_weights=False,
    )
    ### MAIN MODIFICATIONS END HERE ###

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size //
                                        self.config.pretraining_tp,
                                        dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size //
                                                 self.config.pretraining_tp,
                                                 dim=1)
        attn_output = sum([
            F.linear(  # type: ignore (thirdParty)
                attn_output[i], o_proj_slices[i])
            for i in range(self.config.pretraining_tp)
        ])
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, None  # type: ignore (thirdParty)


def llama_attention_patch_triton(
    self,  # type: ignore
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_cache:
        raise NotImplementedError(
            'use_cache is not yet supported when patching Llama attention.')
    # output_attentions is not supported for triton attention
    if output_attentions:
        raise NotImplementedError(
            'output_attentions is not supported when patching Llama attention with triton attention.'
        )
    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads *
                             self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp,
            dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(  # type: ignore (thirdParty)
                hidden_states, query_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i])  # type: ignore (thirdParty)
            for i in range(self.config.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(  # type: ignore (thirdParty)
                hidden_states, value_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin,
        position_ids)  # type: ignore (thirdParty)

    ### MAIN MODIFICATIONS START HERE ###
    query_states = query_states.transpose(1, 2).view(
        bsz, q_len, self.num_heads * self.head_dim)
    key_states = key_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)
    value_states = value_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)

    attn_output, _, _ = triton_flash_attn_fn(
        query=query_states,
        key=key_states,
        value=value_states,
        n_heads=self.num_heads,
        kv_n_heads=self.num_key_value_heads,
        past_key_value=None,
        softmax_scale=None,
        attn_bias=attention_mask,
        key_padding_mask=None,
        is_causal=False,  # The causal mask is propagated from LlamaForCausalLM
        dropout_p=0,
        training=self.training,
        needs_weights=False,
    )
    ### MAIN MODIFICATIONS END HERE ###

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size //
                                        self.config.pretraining_tp,
                                        dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size //
                                                 self.config.pretraining_tp,
                                                 dim=1)
        attn_output = sum([
            F.linear(  # type: ignore (thirdParty)
                attn_output[i], o_proj_slices[i])
            for i in range(self.config.pretraining_tp)
        ])
    else:
        attn_output = self.o_proj(attn_output)

    return attn_output, None, None  # type: ignore (thirdParty)
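
The other changed files in this commit are not shown here. As a rough, hypothetical sketch only (the transformers module path and the place where the hook is installed are assumptions, not part of this diff), the patch function above might be wired into Hugging Face Llama like this, before the model is constructed:

# Hypothetical usage sketch; not part of this commit's diff.
from transformers.models.llama import modeling_llama

from llmfoundry.models.layers.llama_attention_monkeypatch import (
    get_llama_attention_patch_fn)

# Swap LlamaAttention.forward for the patched triton flash-attention version.
# The patch functions take `self` first, so plain attribute assignment works.
modeling_llama.LlamaAttention.forward = get_llama_attention_patch_fn('triton')

Any LlamaForCausalLM instantiated after this assignment would then route its attention through triton_flash_attn_fn instead of the stock eager implementation, subject to the use_cache and output_attentions limitations raised in the patch functions above.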