Skip to content

Commit

Permalink
metal : use F32 prec for K*Q in vec FA (llama/9595)
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Sep 24, 2024
1 parent ba85632 commit 64f30f3
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3;

// load the queries from shared memory into local memory
half4 mq[D4];
float4 mq[D4];

for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
mq[i] = sq4[i];
mq[i] = (float4) sq4[i];
}

// pointer to the mask
Expand All @@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;

half4x4 mk;
mk[0] = pk4[i + 0*(nb11/8)];
mk[1] = pk4[i + 1*(nb11/8)];
mk[2] = pk4[i + 2*(nb11/8)];
mk[3] = pk4[i + 3*(nb11/8)];
float4x4 mk;
mk[0] = (float4) pk4[i + 0*(nb11/8)];
mk[1] = (float4) pk4[i + 1*(nb11/8)];
mk[2] = (float4) pk4[i + 2*(nb11/8)];
mk[3] = (float4) pk4[i + 3*(nb11/8)];

mqk += (float4) (mq[i] * mk);
}
Expand Down

0 comments on commit 64f30f3

Please sign in to comment.