Skip to content

Commit

Permalink
Fix iakv regression (#2900)
Browse files Browse the repository at this point in the history
* Fix iakv regression

* Remove unuse loop
  • Loading branch information
liangan1 authored May 22, 2024
1 parent 1f68851 commit 21b5030
Showing 1 changed file with 125 additions and 125 deletions.
250 changes: 125 additions & 125 deletions csrc/cpu/aten/kernels/MaskedMultiHeadAttentionKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,10 +557,14 @@ scale_dot_product_for_indirect_access_kv_cache(
auto thread_numbers = omp_get_max_threads();
auto max_parallel_parts = thread_numbers * 4;

auto target_block_size = 32L;
if (bs <= 32 and seq_len < 65536) {
target_block_size = 1L;
}
auto kv_block_size = bs * head_num >= max_parallel_parts
? seq_len
: std::max(seq_len / max_parallel_parts, 1L);
kv_block_size = std::min(kv_block_size, 32L);
kv_block_size = std::min(kv_block_size, target_block_size);
auto kv_block_count = (seq_len + kv_block_size - 1) / kv_block_size;
auto need_update_beam_idx = offset > 0 and bs > 1;
auto b_ptr = beam_idx.data_ptr<long>();
Expand All @@ -585,37 +589,48 @@ scale_dot_product_for_indirect_access_kv_cache(
for (auto hi = 0; hi < head_num; hi++) {
auto k_start = block_id * kv_block_size;
auto block_size = std::min(kv_block_size, seq_len - k_start);
auto query_ti = 0;
for (auto ti = k_start; ti < k_start + block_size; ti++) {
for (auto query_ti = 0; query_ti < cur_len; query_ti++) {
auto kv_hi = hi / group_size; // maping the query head to
// key/value head to support MGA/MQA
auto q_ptr_start = q_ptr +
(bi * cur_len + query_ti) * head_num * head_size +
hi * head_size;
auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
auto attn_w_pos =
attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
attn_w_pos[0] = 0.0f;
auto kc_token_start = ti * kc_token_stride;
auto kc_t_beam_start = kc_token_start;
auto beam = need_update_beam_idx ? new_beam_idx[bi][ti] : 0;
if (ti >
query_ti + offset) { // only caculate the innerproduct for
// the past token and current token
attn_w_pos[0] = -10000.0f;
} else if (ti == query_ti + offset) { // caculate the innerproduct
// for the current token and
// store the key
if (cur_len > 1) { // this may occur for processing the promt
auto beam_size = beam_batch / bs;
// need to store key accross beam
kc_t_beam_start =
kc_t_beam_start + bi * beam_size * kv_head * head_size;
} else {
kc_t_beam_start = kc_t_beam_start + bi * kv_head * head_size;
}
auto kc_head_start =
k_cache_ptr + kc_t_beam_start + kv_hi * head_size;
auto kv_hi = hi / group_size; // maping the query head to
// key/value head to support MGA/MQA
auto q_ptr_start = q_ptr +
(bi * cur_len + query_ti) * head_num * head_size +
hi * head_size;
auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
auto attn_w_pos =
attn_w_ptr + attn_w_stride + query_ti * seq_len + ti;
attn_w_pos[0] = 0.0f;
auto kc_token_start = ti * kc_token_stride;
auto kc_t_beam_start = kc_token_start;
auto beam = need_update_beam_idx ? new_beam_idx[bi][ti] : 0;
if (ti > query_ti + offset) { // only caculate the innerproduct for
// the past token and current token
attn_w_pos[0] = -10000.0f;
} else if (ti == query_ti + offset) { // caculate the innerproduct
// for the current token and
// store the key
if (cur_len > 1) { // this may occur for processing the promt
auto beam_size = beam_batch / bs;
// need to store key accross beam
kc_t_beam_start =
kc_t_beam_start + bi * beam_size * kv_head * head_size;
} else {
kc_t_beam_start = kc_t_beam_start + bi * kv_head * head_size;
}
auto kc_head_start =
k_cache_ptr + kc_t_beam_start + kv_hi * head_size;
auto k_ptr_start = k_ptr +
(bi * cur_len + ti - offset) * kv_head * head_size +
kv_hi * head_size;
reduce_head<QT>(
q_ptr_start,
k_ptr_start,
attn_w_pos,
head_size,
true,
kc_head_start);
} else { // caculate the innerproduct for the past token
if (ti >= offset) {
auto k_ptr_start = k_ptr +
(bi * cur_len + ti - offset) * kv_head * head_size +
kv_hi * head_size;
Expand All @@ -624,38 +639,24 @@ scale_dot_product_for_indirect_access_kv_cache(
k_ptr_start,
attn_w_pos,
head_size,
true,
kc_head_start);
} else { // caculate the innerproduct for the past token
if (ti >= offset) {
auto k_ptr_start = k_ptr +
(bi * cur_len + ti - offset) * kv_head * head_size +
kv_hi * head_size;
reduce_head<QT>(
q_ptr_start,
k_ptr_start,
attn_w_pos,
head_size,
false,
nullptr);
} else {
false,
nullptr);
} else {
kc_t_beam_start = kc_t_beam_start + beam * kv_head * head_size;
if (cur_len > 1) {
auto beam_size = beam_batch / bs;
kc_t_beam_start =
kc_t_beam_start + beam * kv_head * head_size;
if (cur_len > 1) {
auto beam_size = beam_batch / bs;
kc_t_beam_start =
kc_t_beam_start + bi * beam_size * kv_head * head_size;
}
auto kc_head_start =
k_cache_ptr + kc_t_beam_start + kv_hi * head_size;
reduce_head<QT>(
q_ptr_start,
kc_head_start,
attn_w_pos,
head_size,
false,
nullptr);
kc_t_beam_start + bi * beam_size * kv_head * head_size;
}
auto kc_head_start =
k_cache_ptr + kc_t_beam_start + kv_hi * head_size;
reduce_head<QT>(
q_ptr_start,
kc_head_start,
attn_w_pos,
head_size,
false,
nullptr);
}
}
}
Expand Down Expand Up @@ -742,85 +743,84 @@ scale_dot_product_for_indirect_access_kv_cache(
thread_id = omp_get_thread_num();
auto v_start = block_id * kv_block_size;
auto block_size = std::min(kv_block_size, seq_len - v_start);
auto query_ti = 0;
for (auto vi = v_start; vi < v_start + block_size; vi++) {
for (auto query_ti = 0; query_ti < cur_len; query_ti++) {
auto kv_hi = hi / group_size; // maping the query head to
// key/value head to support MGA/MQA
auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
auto attn_w_query_start =
attn_w_ptr + attn_w_stride + query_ti * seq_len;
// calculate weighted value and store the result to attn_outs[bs,
// head_num, cur_len, head_size]
auto attn_out_head_stride = thread_id * attn_outs_stride_priv +
(bi * head_num + hi) * cur_len * head_size;
auto attn_out_start = private_attn_out_ptr +
attn_out_head_stride + query_ti * head_size;
auto kv_hi = hi / group_size; // maping the query head to
// key/value head to support MGA/MQA
auto attn_w_stride = (bi * head_num + hi) * cur_len * seq_len;
auto attn_w_query_start =
attn_w_ptr + attn_w_stride + query_ti * seq_len;
// calculate weighted value and store the result to attn_outs[bs,
// head_num, cur_len, head_size]
auto attn_out_head_stride = thread_id * attn_outs_stride_priv +
(bi * head_num + hi) * cur_len * head_size;
auto attn_out_start = private_attn_out_ptr + attn_out_head_stride +
query_ti * head_size;

auto vc_token_start = vi * kc_token_stride;
auto beam = need_update_beam_idx ? new_beam_idx[bi][vi] : 0;
if (vi == query_ti + offset) { // caculate the attention values
// for the current token
auto vc_t_beam_start = vc_token_start;
if (cur_len > 1) { // this may occur for processing the promt
auto vc_token_start = vi * kc_token_stride;
auto beam = need_update_beam_idx ? new_beam_idx[bi][vi] : 0;
if (vi == query_ti + offset) { // caculate the attention values
// for the current token
auto vc_t_beam_start = vc_token_start;
if (cur_len > 1) { // this may occur for processing the promt
auto beam_size = beam_batch / bs;
// removed the redundant computation, need to store key
// accross beam
vc_t_beam_start =
vc_t_beam_start + bi * beam_size * kv_head * head_size;
} else {
vc_t_beam_start = vc_t_beam_start + bi * kv_head * head_size;
}
auto v_cache_head_start =
v_cache_ptr + vc_t_beam_start + kv_hi * head_size;
auto v_ptr_start = v_ptr +
(bi * cur_len + vi - offset) * kv_head * head_size +
kv_hi * head_size;
mul_attenion_weights_and_value_of_head<VT, float>(
attn_w_query_start[vi],
v_ptr_start,
attn_out_start,
head_size,
true,
v_cache_head_start,
flag_access[thread_id][bi][hi]);
} else if (vi < query_ti + offset) { // caculate attention
// values for the past
// token
if (vi >= offset) {
auto v_ptr_start = v_ptr +
(bi * cur_len + vi - offset) * kv_head * head_size +
kv_hi * head_size;
mul_attenion_weights_and_value_of_head<VT, float>(
attn_w_query_start[vi],
v_ptr_start,
attn_out_start,
head_size,
false,
nullptr,
flag_access[thread_id][bi][hi]);
} else {
auto vc_t_beam_start =
vc_token_start + beam * kv_head * head_size;
if (cur_len > 1) {
auto beam_size = beam_batch / bs;
// removed the redundant computation, need to store key
// accross beam
vc_t_beam_start =
vc_t_beam_start + bi * beam_size * kv_head * head_size;
} else {
vc_t_beam_start = vc_t_beam_start + bi * kv_head * head_size;
}
auto v_cache_head_start =
v_cache_ptr + vc_t_beam_start + kv_hi * head_size;
auto v_ptr_start = v_ptr +
(bi * cur_len + vi - offset) * kv_head * head_size +
kv_hi * head_size;
mul_attenion_weights_and_value_of_head<VT, float>(
attn_w_query_start[vi],
v_ptr_start,
v_cache_head_start,
attn_out_start,
head_size,
true,
v_cache_head_start,
false,
nullptr,
flag_access[thread_id][bi][hi]);
} else if (vi < query_ti + offset) { // caculate attention
// values for the past
// token
if (vi >= offset) {
auto v_ptr_start = v_ptr +
(bi * cur_len + vi - offset) * kv_head * head_size +
kv_hi * head_size;
mul_attenion_weights_and_value_of_head<VT, float>(
attn_w_query_start[vi],
v_ptr_start,
attn_out_start,
head_size,
false,
nullptr,
flag_access[thread_id][bi][hi]);
} else {
auto vc_t_beam_start =
vc_token_start + beam * kv_head * head_size;
if (cur_len > 1) {
auto beam_size = beam_batch / bs;
vc_t_beam_start =
vc_t_beam_start + bi * beam_size * kv_head * head_size;
}
auto v_cache_head_start =
v_cache_ptr + vc_t_beam_start + kv_hi * head_size;
mul_attenion_weights_and_value_of_head<VT, float>(
attn_w_query_start[vi],
v_cache_head_start,
attn_out_start,
head_size,
false,
nullptr,
flag_access[thread_id][bi][hi]);
}
}
if (flag_access[thread_id][bi][hi] == 0)
flag_access[thread_id][bi][hi] = 1;
}
if (flag_access[thread_id][bi][hi] == 0)
flag_access[thread_id][bi][hi] = 1;
}
}
}
Expand Down

0 comments on commit 21b5030

Please sign in to comment.