[Docs] Fix docstring about Q nheads being divisible by KV nheads
tridao committed Aug 1, 2023
1 parent a4f148b commit 8f4cd4c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions flash_attn/flash_attn_interface.py
@@ -316,7 +316,7 @@ def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=Fa
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of K, V.
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
@@ -346,7 +346,7 @@ def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                     return_attn_probs=False):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
@@ -416,7 +416,7 @@ def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqle
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of K, V.
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
@@ -456,7 +456,7 @@ def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ma
                            return_attn_probs=False):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
-    than Q. Note that the number of heads in K, V must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
     0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
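
For reference, here is a minimal sketch (not part of this commit) of the corrected constraint and of the head grouping the docstrings describe: with 6 Q heads and 2 KV heads, nheads_q must be divisible by nheads_kv, and each KV head serves a block of nheads_q // nheads_kv consecutive Q heads. It assumes flash-attn is installed and a supported CUDA GPU with fp16; the repeat_interleave comparison only illustrates the mapping, not how the kernel actually shares KV heads internally.

    import torch
    from flash_attn import flash_attn_func

    batch, seqlen, headdim = 2, 128, 64
    nheads_q, nheads_kv = 6, 2
    # The constraint the docstrings now state correctly: Q heads divisible by KV heads.
    assert nheads_q % nheads_kv == 0

    q = torch.randn(batch, seqlen, nheads_q, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)

    # GQA call: K, V are passed with fewer heads than Q.
    out_gqa = flash_attn_func(q, k, v)

    # Illustrative equivalent with explicitly repeated KV heads: Q heads 0-2 share
    # KV head 0 and Q heads 3-5 share KV head 1, matching the docstring example.
    group = nheads_q // nheads_kv
    k_rep = k.repeat_interleave(group, dim=2)  # (batch, seqlen, nheads_q, headdim)
    v_rep = v.repeat_interleave(group, dim=2)
    out_mha = flash_attn_func(q, k_rep, v_rep)

    # The two outputs should agree up to numerical tolerance.
    assert torch.allclose(out_gqa, out_mha, atol=2e-3)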
