Optimize MQA Kernel (vllm-project#452)
zhuohan123 authored Jul 15, 2023
1 parent dbed690 commit 96853af
Showing 5 changed files with 84 additions and 72 deletions.
1 change: 1 addition & 0 deletions csrc/attention.cpp
@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
+ torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
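The binding change above just threads one extra tensor, head_mapping, through to the kernel. As a rough illustration (example head counts assumed, not taken from the commit), the mapping that vllm/model_executor/layers/attention.py builds further down gives every query head the index of the KV head whose cache entries it should read; with num_kv_heads == num_heads it degenerates to the identity mapping, so plain multi-head attention is unaffected:

```python
import torch

# Illustrative sketch only: head counts below are assumed example values.
num_heads = 8       # query heads
num_kv_heads = 2    # KV heads; 1 would be pure multi-query attention
num_queries_per_kv = num_heads // num_kv_heads

# Same construction as PagedAttention.__init__ in this commit (on CPU here).
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)
print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```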
31 changes: 22 additions & 9 deletions csrc/attention/attention_kernels.cu
@@ -74,14 +74,17 @@ template<
__global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
- const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
- const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
+ const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+ const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
+ const int* __restrict__ head_mapping, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
- const int q_stride) {
+ const int q_stride,
+ const int kv_block_stride,
+ const int kv_head_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -91,6 +94,7 @@ __global__ void single_query_cached_kv_attention_kernel(

const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
+ const int kv_head_idx = head_mapping[head_idx];
const int seq_idx = blockIdx.y;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

@@ -158,8 +162,8 @@ __global__ void single_query_cached_kv_attention_kernel(

#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
- const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-   + head_idx * HEAD_SIZE * BLOCK_SIZE
+ const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+   + kv_head_idx * kv_head_stride
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
@@ -246,8 +250,8 @@ __global__ void single_query_cached_kv_attention_kernel(
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));

- const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-   + head_idx * HEAD_SIZE * BLOCK_SIZE;
+ const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+   + kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -328,12 +332,15 @@ __global__ void single_query_cached_kv_attention_kernel(
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
+ head_mapping_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
- query_stride);
+ q_stride, \
+ kv_block_stride, \
+ kv_head_stride);

// TODO(woosuk): Tune NUM_THREADS.
template<
@@ -345,6 +352,7 @@ void single_query_cached_kv_attention_launcher(
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
+ torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
@@ -354,7 +362,9 @@ void single_query_cached_kv_attention_launcher(
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
- int query_stride = query.stride(0);
+ int q_stride = query.stride(0);
+ int kv_block_stride = key_cache.stride(0);
+ int kv_head_stride = key_cache.stride(1);

int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
@@ -368,6 +378,7 @@ void single_query_cached_kv_attention_launcher(
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+ int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -422,6 +433,7 @@ void single_query_cached_kv_attention_launcher(
query, \
key_cache, \
value_cache, \
+ head_mapping, \
scale, \
block_tables, \
context_lens, \
@@ -469,6 +481,7 @@ void single_query_cached_kv_attention(
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
+ torch::Tensor& head_mapping, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
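Two things change inside the kernel: the KV cache is indexed with kv_head_idx = head_mapping[head_idx] instead of head_idx, and the hard-coded num_heads * HEAD_SIZE * BLOCK_SIZE offsets are replaced by kv_block_stride and kv_head_stride taken from the cache tensors in the launcher. A sketch with made-up sizes (not from the commit) of why those strides reduce to the old constants, now expressed in terms of num_kv_heads:

```python
import torch

# Assumed example sizes for illustration only.
num_blocks, num_kv_heads, head_size, block_size, x = 16, 2, 64, 8, 8
num_heads = 8

# Contiguous key cache: [num_blocks, num_kv_heads, head_size/x, block_size, x]
key_cache = torch.empty(num_blocks, num_kv_heads, head_size // x, block_size, x)

kv_block_stride = key_cache.stride(0)  # == num_kv_heads * head_size * block_size
kv_head_stride = key_cache.stride(1)   # == head_size * block_size
assert kv_block_stride == num_kv_heads * head_size * block_size
assert kv_head_stride == head_size * block_size

# Element offset of the K block read by query head 5 from physical block 3,
# mirroring the k_ptr arithmetic (the physical_block_offset * x term still
# gets added on top inside the kernel).
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_heads // num_kv_heads)
kv_head_idx = int(head_mapping[5])
k_offset = 3 * kv_block_stride + kv_head_idx * kv_head_stride
```

Because the strides come from the tensors themselves, the kernel no longer assumes one KV head per query head, which is exactly the assumption MQA/GQA caches break.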
7 changes: 7 additions & 0 deletions vllm/config.py
@@ -94,6 +94,13 @@ def get_head_size(self) -> int:
return self.hf_config.hidden_size // self.hf_config.num_attention_heads

def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
+ # For GPTBigCode:
+ if getattr(self.hf_config, "multi_query", False):
+ # Multi-query attention, only one KV head.
+ return 1
+ # For Falcon:
+ if getattr(self.hf_config, "n_head_kv", None) is not None:
+ return self.hf_config.n_head_kv
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size

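With this change, get_num_heads reports how many KV heads the cache needs rather than the number of query heads: one for GPTBigCode-style multi-query models, n_head_kv for Falcon, and the tensor-parallel share of num_attention_heads otherwise. A standalone approximation of that decision logic (hypothetical helper and made-up config values, not vLLM's actual classes):

```python
from types import SimpleNamespace

def num_kv_heads(hf_config, tensor_parallel_size: int) -> int:
    """Hypothetical mirror of the new get_num_heads logic, for illustration."""
    if getattr(hf_config, "multi_query", False):            # e.g. GPTBigCode
        return 1
    if getattr(hf_config, "n_head_kv", None) is not None:   # e.g. Falcon
        return hf_config.n_head_kv
    return hf_config.num_attention_heads // tensor_parallel_size

print(num_kv_heads(SimpleNamespace(multi_query=True, num_attention_heads=48), 2))  # 1
print(num_kv_heads(SimpleNamespace(n_head_kv=8, num_attention_heads=64), 2))       # 8
print(num_kv_heads(SimpleNamespace(num_attention_heads=32), 4))                    # 8
```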
43 changes: 32 additions & 11 deletions vllm/model_executor/layers/attention.py
@@ -44,12 +44,23 @@ class PagedAttention(nn.Module):
5. Output a flattened 1D tensor.
"""

- def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
+ def __init__(self,
+ num_heads: int,
+ head_size: int,
+ scale: float,
+ num_kv_heads: Optional[int] = None) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()
+ self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

+ assert self.num_heads % self.num_kv_heads == 0
+ self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+ self.head_mapping = torch.repeat_interleave(
+ torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
+ self.num_queries_per_kv)

if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -76,10 +87,18 @@ def multi_query_kv_attention(
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
- key: shape = [num_prompt_tokens, num_heads, head_size]
- value: shape = [num_prompt_tokens, num_heads, head_size]
+ key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+ value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""

+ if self.num_kv_heads != self.num_heads:
+ # Project the key and value tensors to the desired number of heads.
+ key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+ value = torch.repeat_interleave(value,
+ self.num_queries_per_kv,
+ dim=1)

# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
@@ -107,9 +126,9 @@ def single_query_cached_kv_attention(
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
- key_cache: shape = [num_blocks, num_heads, head_size/x,
+ key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
- value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+ value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
@@ -118,6 +137,7 @@ def single_query_cached_kv_attention(
query,
key_cache,
value_cache,
+ self.head_mapping,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
@@ -143,11 +163,12 @@ def forward(
Args:
query: shape = [num_tokens, num_heads * head_size]
- key: shape = [num_tokens, num_heads * head_size]
- value: shape = [num_tokens, num_heads * head_size]
- key_cache: shape = [num_blocks, num_heads, head_size/x,
+ key: shape = [num_tokens, num_kv_heads * head_size]
+ value: shape = [num_tokens, num_kv_heads * head_size]
+ key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
- value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+ value_cache: shape = [num_blocks, num_kv_heads, head_size,
+ block_size]
input_metadata: metadata for paged attention.
cache_event: event to wait for the cache operations to finish.
@@ -157,8 +178,8 @@ def forward(

# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
- key = key.view(-1, self.num_heads, self.head_size)
- value = value.view(-1, self.num_heads, self.head_size)
+ key = key.view(-1, self.num_kv_heads, self.head_size)
+ value = value.view(-1, self.num_kv_heads, self.head_size)

# Pre-allocate the output tensor.
output = torch.empty_like(query)
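PagedAttention now tracks separate query and KV head counts. During prompt processing the KV tensors are expanded with repeat_interleave so the xformers kernel sees equal head counts, while during decoding the cache stays un-replicated and head_mapping is handed to the new CUDA kernel. A shape-only sketch of the prompt-path expansion (assumed sizes, not from the commit):

```python
import torch

# Assumed example sizes for illustration only.
num_tokens, num_heads, num_kv_heads, head_size = 10, 8, 1, 64
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

if num_kv_heads != num_heads:
    # Same replication as multi_query_kv_attention in this commit.
    key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
    value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

assert key.shape == value.shape == (num_tokens, num_heads, head_size)
```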
74 changes: 22 additions & 52 deletions vllm/model_executor/models/gpt_bigcode.py
@@ -26,7 +26,6 @@

import torch
from torch import nn
- import numpy as np
from transformers import GPTBigCodeConfig

from vllm.model_executor.input_metadata import InputMetadata
@@ -55,10 +54,12 @@ def __init__(self, config: GPTBigCodeConfig):
assert total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
+ self.num_kv_heads = 1 if config.multi_query else self.num_heads
+ self.kv_dim = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5

self.c_attn = ColumnParallelLinear(self.hidden_size,
- 3 * self.hidden_size,
+ self.hidden_size + 2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
@@ -69,7 +70,8 @@ def __init__(self, config: GPTBigCodeConfig):
perform_initialization=False)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
- scale=self.scale)
+ scale=self.scale,
+ num_kv_heads=self.num_kv_heads)

def forward(
self,
@@ -79,7 +81,8 @@ def forward(
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
- q, k, v = qkv.chunk(chunks=3, dim=-1)
+ q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim],
+ dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
@@ -263,67 +266,34 @@ def load_weights(self,
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

- def _expand_mqa_mha(qkv_array, n_head, head_dim):
- """manipulates along axis=0 from MQA to MHA
- inputs: qkv_array.shape=((n_heads + 2) * head_dim, hidden_dim)
- with n_heads for q, then 1 for k, 1 for 1 v, times head dim
- return: qkv_array.shape=(3 * n_heads * head_dim, hidden_dim)
- TODO: this function is no longer needed once vllm supports MQA.
- """
- qkv_array = qkv_array.numpy()
-
- dims_q = n_head * head_dim
- # pylint: disable=unbalanced-tuple-unpacking
- q, k, v = np.split(qkv_array, (dims_q, dims_q + head_dim),
- axis=0)
- # q is fine, but k & v have not replicated shape along the first
- # axis as long as MQA is not nativly supported, increase memory
- # and replicated (head_dim, hidden_dim) to
- # (n_heads * head_dim, hidden_dim)
- if k.ndim == 2 and v.ndim == 2:
- replication = (n_head, 1) # weights
- else:
- replication = n_head # biases
- # replicate n_head times for q, v
- k, v = np.tile(k, replication), np.tile(v, replication)
- # concat q, k, v along the first axis
- # (n_heads * head_dim, hidden_dim)
- # to (3 * n_heads * head_dim, hidden_dim)
- qkv_array = np.concatenate((q, k, v), axis=0)
- return torch.from_numpy(qkv_array)

# For the fused QKV linear layer, manually shard the weights.
if "c_attn" in name:
# GPT-2's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor parallelism is used, we shard the weights along
# the head dimension.
total_num_heads = self.config.num_attention_heads
+ total_num_kv_heads = (1 if self.config.multi_query else
+ total_num_heads)
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
+ total_kv_size = head_size * total_num_kv_heads
num_heads = total_num_heads // tensor_model_parallel_world_size
head_start = tensor_model_parallel_rank * num_heads
head_end = (tensor_model_parallel_rank + 1) * num_heads

if name.endswith(".weight"):
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = _expand_mqa_mha(loaded_weight,
n_head=total_num_heads,
head_dim=head_size)
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
wq, wk, wv = torch.split(
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
dim=0)

+ wq = wq[head_size * head_start:head_size * head_end]
+ if not self.config.multi_query:
+ # Split the heads when using normal multi-head attention
+ wk = wk[head_size * head_start:head_size * head_end]
+ wv = wv[head_size * head_start:head_size * head_end]
+ # Else, keep the weights as is for multi-query attention

+ loaded_weight = torch.cat([wq, wk, wv], dim=0)

load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
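Since the attention layer now accepts num_kv_heads directly, the weight loader no longer inflates the single KV head with _expand_mqa_mha. Instead it splits the fused c_attn parameter, of shape [hidden_size + 2 * kv_size, hidden_size] for weights, into query, key, and value parts and shards only the query heads across tensor-parallel ranks (K/V stay whole when multi_query is set). A toy-sized sketch of that splitting (illustrative values, not the actual loader):

```python
import torch

# Assumed example sizes for illustration only.
hidden_size, total_num_heads, tp_size, rank = 256, 8, 2, 1
multi_query = True

head_size = hidden_size // total_num_heads
total_num_kv_heads = 1 if multi_query else total_num_heads
kv_size = head_size * total_num_kv_heads
loaded_weight = torch.randn(hidden_size + 2 * kv_size, hidden_size)

num_heads = total_num_heads // tp_size
head_start, head_end = rank * num_heads, (rank + 1) * num_heads

wq, wk, wv = torch.split(loaded_weight, [hidden_size, kv_size, kv_size], dim=0)
wq = wq[head_size * head_start:head_size * head_end]
if not multi_query:
    wk = wk[head_size * head_start:head_size * head_end]
    wv = wv[head_size * head_start:head_size * head_end]

shard = torch.cat([wq, wk, wv], dim=0)
# This rank keeps its slice of the query heads plus the full K/V head.
assert shard.shape == (num_heads * head_size + 2 * kv_size, hidden_size)
```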
