Skip to content

Commit d6599b5

Browse files
authored
Enable multiple query for the next tokens (#3018)
* Enable multiple query for the next tokens
1 parent a6fd719 commit d6599b5

File tree

1 file changed

+38
-11
lines changed

1 file changed

+38
-11
lines changed

csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,17 +1473,44 @@ masked_multihead_self_attention_kernel_impl(
14731473
value_cache = new_value_cache;
14741474
beam_idx = new_beam_idx;
14751475
}
1476-
if (offset > 0) {
1477-
return zero_copy_kv_cache_masked_multihead_self_attention_kernel_impl(
1478-
query,
1479-
key,
1480-
value,
1481-
key_cache,
1482-
value_cache,
1483-
beam_idx,
1484-
offset,
1485-
scale_attn,
1486-
attention_mask_v);
1476+
if (offset != 0) {
1477+
auto cur_len = query.size(1);
1478+
if (cur_len == 1)
1479+
return zero_copy_kv_cache_masked_multihead_self_attention_kernel_impl(
1480+
query,
1481+
key,
1482+
value,
1483+
key_cache,
1484+
value_cache,
1485+
beam_idx,
1486+
offset,
1487+
scale_attn,
1488+
attention_mask_v);
1489+
// just a functionality path; needs to be optimized
1490+
auto tokens_outs = std::vector<at::Tensor>(cur_len);
1491+
for (auto i = 0; i < cur_len; i++) {
1492+
auto query_i = query.select(1, i).unsqueeze(1);
1493+
;
1494+
auto key_i = key.select(1, i).unsqueeze(1);
1495+
;
1496+
auto value_i = value.select(1, i).unsqueeze(1);
1497+
;
1498+
auto next_outs =
1499+
zero_copy_kv_cache_masked_multihead_self_attention_kernel_impl(
1500+
query_i,
1501+
key_i,
1502+
value_i,
1503+
key_cache,
1504+
value_cache,
1505+
beam_idx,
1506+
offset,
1507+
scale_attn,
1508+
attention_mask_v);
1509+
tokens_outs[i] = std::get<0>(next_outs);
1510+
}
1511+
auto attn_outs = at::cat(tokens_outs, 2);
1512+
return std::make_tuple(
1513+
attn_outs, at::Tensor(), key_cache, value_cache, beam_idx);
14871514
} else {
14881515
return first_token_masked_mha(
14891516
query,

0 commit comments

Comments
 (0)