fav2 for gptneox
LorrinWWW committed Jul 17, 2023
1 parent d739ab5 commit d64bdef
Showing 1 changed file with 51 additions and 4 deletions.
training/modules/hf_gptneox_modules.py
@@ -22,7 +22,52 @@
     print('>>>>> using flash attention')
 except ImportError:
     flash_attn_installed = False
+
+try:
+    from fav2.fav2_interface import flash_attn_qkvpacked_func as fav2_qkvpacked_func
+    flash_attn_v2_installed = True
+    print('>>>>> using flash attention v2')
+
+    class FlashAttentionV2(nn.Module):
+        """Implement the scaled dot product attention with softmax.
+        Arguments
+        ---------
+            softmax_scale: The temperature to use for the softmax attention.
+                           (default: 1/sqrt(d_keys) where d_keys is computed at
+                           runtime)
+            attention_dropout: The dropout rate to apply to the attention
+                               (default: 0.0)
+        """
+        def __init__(self, softmax_scale=None, attention_dropout=0.0):
+            super().__init__()
+            self.softmax_scale = softmax_scale
+            self.dropout_p = attention_dropout
+
+        def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+                    max_s=None, need_weights=False):
+            """Implements the multihead softmax attention.
+            Arguments
+            ---------
+                qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+                    if unpadded: (nnz, 3, h, d)
+                key_padding_mask: a bool tensor of shape (B, S)
+            """
+            assert not need_weights
+            assert qkv.dtype in [torch.float16, torch.bfloat16]
+            assert qkv.is_cuda
+            assert key_padding_mask is None
+            assert cu_seqlens is None
+            assert max_s is None
+
+            output = fav2_qkvpacked_func(
+                qkv, self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=causal
+            )
+
+            return output, None
+except ImportError:
+    flash_attn_v2_installed = False



 def rotate_half(x):
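
For orientation, here is a minimal usage sketch of the new wrapper. It is illustrative only: it assumes the fav2 import above succeeded, that a CUDA device is available, and the tensor sizes below are made up.

import torch

# Packed layout the wrapper expects when key_padding_mask is None: (B, S, 3, H, D)
batch, seqlen, n_heads, head_dim = 2, 128, 12, 64   # illustrative sizes
qkv = torch.randn(batch, seqlen, 3, n_heads, head_dim,
                  dtype=torch.float16, device='cuda')

attn = FlashAttentionV2(softmax_scale=head_dim ** -0.5)
out, _ = attn(qkv, causal=True)   # out: (B, S, H, D); the second return value is always None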
@@ -76,8 +121,10 @@ def __init__(self, config):
         )
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-
-        if flash_attn_installed:
+
+        if flash_attn_v2_installed:
+            self.flash_attn = FlashAttentionV2(softmax_scale=1.0/self.norm_factor, attention_dropout = 0)
+        elif flash_attn_installed:
             self.flash_attn = FlashAttention(softmax_scale=1.0/self.norm_factor, attention_dropout = 0)

     def forward(
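
A note on the scale being passed: assuming norm_factor follows the upstream GPT-NeoX definition of sqrt(head_size), the wrapper receives the standard 1/sqrt(d) attention scale. A quick illustration with made-up sizes:

import math

head_size = 64                      # hidden_size // num_attention_heads, illustrative
norm_factor = math.sqrt(head_size)  # upstream GPT-NeoX definition (assumed here)
softmax_scale = 1.0 / norm_factor   # value handed to FlashAttentionV2 above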
@@ -145,8 +192,8 @@ def forward(
             value = torch.cat((past_value, value), dim=-2)
         present = None if use_cache else (key, value)

-        # Compute attention
-        if flash_attn_installed:
+        # Compute attention
+        if flash_attn_installed or flash_attn_v2_installed:

             query = query.permute(0, 2, 1, 3).half()
             key = key.permute(0, 2, 1, 3).half()
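The rest of this hunk is not shown above. Purely for context, and not part of the committed code, feeding the qkv-packed flash path from the permuted half-precision tensors typically looks like the following hypothetical sketch:

# Hypothetical continuation (not the committed code): pack q, k, v and call the wrapper.
qkv = torch.stack([query, key, value], dim=2)        # (B, S, 3, H, D)
attn_output, _ = self.flash_attn(qkv, causal=True)   # (B, S, H, D)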
