Commit
Enable multiple query for the next tokens (#3018)
liangan1 authored Jun 28, 2024
1 parent a6fd719 commit d6599b5
Showing 1 changed file with 38 additions and 11 deletions:
csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp
@@ -1473,17 +1473,44 @@ masked_multihead_self_attention_kernel_impl(
     value_cache = new_value_cache;
     beam_idx = new_beam_idx;
   }
-  if (offset > 0) {
-    return zero_copy_kv_cache_masked_multihead_self_attention_kernel_impl(
-        query,
-        key,
-        value,
-        key_cache,
-        value_cache,
-        beam_idx,
-        offset,
-        scale_attn,
-        attention_mask_v);
+  if (offset != 0) {
+    auto cur_len = query.size(1);
+    if (cur_len == 1)
+      return zero_copy_kv_cache_masked_multihead_self_attention_kernel_impl(
+          query,
+          key,
+          value,
+          key_cache,
+          value_cache,
+          beam_idx,
+          offset,
+          scale_attn,
+          attention_mask_v);
+    // Multiple next tokens: a functionality-only path for now (needs
+    // optimization). Run the single-token kernel once per query position.
+    auto tokens_outs = std::vector<at::Tensor>(cur_len);
+    for (auto i = 0; i < cur_len; i++) {
+      auto query_i = query.select(1, i).unsqueeze(1);
+      auto key_i = key.select(1, i).unsqueeze(1);
+      auto value_i = value.select(1, i).unsqueeze(1);
+      auto next_outs =
+          zero_copy_kv_cache_masked_multihead_self_attention_kernel_impl(
+              query_i,
+              key_i,
+              value_i,
+              key_cache,
+              value_cache,
+              beam_idx,
+              offset,
+              scale_attn,
+              attention_mask_v);
+      tokens_outs[i] = std::get<0>(next_outs);
+    }
+    // Concatenate the per-token attention outputs along dim 2.
+    auto attn_outs = at::cat(tokens_outs, 2);
+    return std::make_tuple(
+        attn_outs, at::Tensor(), key_cache, value_cache, beam_idx);
   } else {
     return first_token_masked_mha(
         query,
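For readers following the diff: when offset != 0 and the step carries more than one query token, the new path decomposes the call into per-token invocations of the single-token zero-copy kernel and concatenates the resulting attention outputs along dim 2. Below is a minimal, self-contained sketch of that decomposition using plain libtorch tensors; single_token_attention and multi_query_next_tokens are hypothetical stand-ins for illustration (not IPEX APIs), and the tensor shapes are illustrative assumptions only.

// Sketch of the per-token decomposition, under the assumptions stated above.
#include <torch/torch.h>
#include <iostream>
#include <vector>

at::Tensor single_token_attention(const at::Tensor& query_i) {
  // Hypothetical stand-in: a real kernel would attend against the KV cache.
  return query_i;
}

at::Tensor multi_query_next_tokens(const at::Tensor& query) {
  auto cur_len = query.size(1); // number of query tokens in this decode step
  std::vector<at::Tensor> tokens_outs(cur_len);
  for (int64_t i = 0; i < cur_len; i++) {
    // select(1, i) drops dim 1, so unsqueeze(1) restores the token dimension.
    auto query_i = query.select(1, i).unsqueeze(1);
    tokens_outs[i] = single_token_attention(query_i);
  }
  // Concatenate the per-token outputs along dim 2, as the kernel above does.
  return at::cat(tokens_outs, 2);
}

int main() {
  // Illustrative shape only: [batch=2, cur_len=3, heads=4, head_dim=8].
  auto query = torch::randn({2, 3, 4, 8});
  auto out = multi_query_next_tokens(query);
  std::cout << out.sizes() << std::endl; // [2, 1, 12, 8] with the echo stub
  return 0;
}

As the in-code comment notes, the loop is a functionality-only path; the commit flags it as still needing optimization.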
