add support for starcoder2 (ModelTC#347)
WuSiYu authored Mar 7, 2024
1 parent adba526 commit 486f647
Showing 8 changed files with 441 additions and 0 deletions.
Empty file.
Empty file.
174 changes: 174 additions & 0 deletions lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,174 @@
import torch
import triton

from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd

from lightllm.models.starcoder2.layer_weights.transformer_layer_weight import Starcoder2TransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer

from lightllm.models.mistral.infer_struct import MistralInferStateInfo
from lightllm.models.mistral.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_att1 import token_att_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd


class Starcoder2TransformerLayerInfer(LlamaTransformerLayerInfer):
def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
self._bind_func()

def _bind_func(self):
self._token_attention_kernel = self._token_decode_attention_normal
self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal

return

def _att_norm(
self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
) -> torch.Tensor:
return layernorm_forward(
input.view(-1, self.embed_dim_),
weight=layer_weight.att_norm_weight_,
bias=layer_weight.att_norm_bias_,
eps=self.eps_,
)

def _ffn_norm(
self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
) -> torch.Tensor:
return layernorm_forward(
input.view(-1, self.embed_dim_),
weight=layer_weight.ffn_norm_weight_,
bias=layer_weight.ffn_norm_bias_,
eps=self.eps_,
)

def _get_qkv(
self, input, cache_kv, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
) -> torch.Tensor:
q = torch.addmm(layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_)
torch.addmm(
layer_weight.kv_bias_,
input.view(-1, self.embed_dim_),
layer_weight.kv_weight_,
out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
)
rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, 0 : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
return q, cache_kv

def _get_o(
self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
) -> torch.Tensor:
o_tensor = torch.addmm(
layer_weight.o_bias_,
input.view(-1, self.tp_o_head_num_ * self.head_dim_),
layer_weight.o_weight_,
beta=1.0 / self.world_size_,
)
return o_tensor

def _ffn(
self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
) -> torch.Tensor:
ffn1_out = torch.addmm(layer_weight.ffn_1_bias_, input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_)
input = None
gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")
ffn1_out = None
ffn2_out = torch.addmm(
layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.world_size_
)
gelu_out = None
return ffn2_out

    # reuse the sliding-window attention code from mistral
def _context_attention_kernel(
self, q, kv, infer_state: MistralInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
o_tensor = torch.empty_like(q) if out is None else out
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
infer_state.sliding_window,
)
return o_tensor

def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, layer_weight, out=None):
total_token_num = infer_state.total_cache_num
batch_size = infer_state.batch_size
calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)

att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda")

token_att_fwd(
q.view(calcu_shape1),
infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
att_m_tensor,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
infer_state.sliding_window,
)

o_tensor = torch.empty_like(q) if out is None else out

if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
token_softmax_fwd(
att_m_tensor, infer_state.b_att_start_loc, infer_state.b_att_seq_len, prob, infer_state.sliding_window
)
att_m_tensor = None
token_att_fwd2(
prob,
infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
],
o_tensor.view(calcu_shape1),
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
)
prob = None
return o_tensor
elif triton.__version__ >= "2.1.0":
from lightllm.models.mistral.triton_kernel.token_attention_softmax_and_reducev import (
token_softmax_reducev_fwd,
)

token_softmax_reducev_fwd(
att_m_tensor,
infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
],
o_tensor.view(calcu_shape1),
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_start_loc_window,
infer_state.b_att_start_loc,
infer_state.b_att_seq_len,
infer_state.other_kv_index,
)
return o_tensor
else:
            raise Exception("unsupported triton version")
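
A note on the tensor-parallel bias handling above: _get_o and _ffn call torch.addmm with beta=1.0 / self.world_size_, so each rank adds only its share of the bias and the later sum across tensor-parallel ranks (presumably an all-reduce in the shared base-layer code) counts the bias exactly once. The sketch below is not part of the commit; shapes and world_size are made-up values that only illustrate the arithmetic.

import torch

world_size = 2
hidden = 8
x = torch.randn(4, hidden)
w_full = torch.randn(hidden, hidden)
bias = torch.randn(hidden)

# single-rank reference: y = x @ w_full + bias
ref = torch.addmm(bias, x, w_full)

# simulate row-parallel ranks: each holds a slice of the input features
# and the matching rows of w_full, and pre-scales the bias by 1/world_size
partials = []
for rank in range(world_size):
    rows = slice(rank * hidden // world_size, (rank + 1) * hidden // world_size)
    partials.append(
        torch.addmm(bias, x[:, rows], w_full[rows, :], beta=1.0 / world_size)
    )

# summing the partial results (the "all-reduce") recovers the reference output
assert torch.allclose(sum(partials), ref, atol=1e-5)
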
Empty file.
@@ -0,0 +1,37 @@
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight


class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight):
def __init__(self, tp_rank, world_size, data_type, network_config, mode):
super().__init__(tp_rank, world_size, data_type, network_config, mode)
return

def load_hf_weights(self, weights):
vob_size = self.network_config_["vocab_size"]
split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
split_start = split_indexes[self.tp_rank_]
split_end = split_indexes[self.tp_rank_ + 1]
if "model.embed_tokens.weight" in weights:
self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])

        # starcoder2-3b and 7b tie word embeddings and do not ship a separate lm_head.weight
self.lm_head_weight_ = self.wte_weight_

if "lm_head.weight" in weights:
self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])

if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

if "model.norm.bias" in weights:
self.final_norm_bias_ = self._cuda(weights["model.norm.bias"])

return

def verify_load(self):
errors = "weights load not ok"
weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_, self.final_norm_bias_]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return
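
For reference, a minimal sketch (not part of the commit) of how load_hf_weights above partitions the vocabulary across tensor-parallel ranks with np.linspace. The vocab size and world size here are made-up example values; the real ones come from network_config["vocab_size"] and the launch configuration.

import numpy as np

vocab_size = 49152   # hypothetical example value
world_size = 4

split_indexes = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
for tp_rank in range(world_size):
    start, end = split_indexes[tp_rank], split_indexes[tp_rank + 1]
    # each rank keeps rows [start, end) of model.embed_tokens.weight (and lm_head.weight)
    print(f"rank {tp_rank}: rows [{start}, {end})  ({end - start} tokens)")
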
141 changes: 141 additions & 0 deletions lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,141 @@
from lightllm.common.basemodel import TransformerLayerWeight


class Starcoder2TransformerLayerWeight(TransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
assert network_config["num_attention_heads"] % self.world_size_ == 0

def init_static_params(self):
pass

def load_hf_weights(self, weights):
self._load_qkvo_weights(weights)
self._load_ffn_weights(weights)
return

def verify_load(self):
errors = "weights load not ok"
weights = [
self.att_norm_weight_,
self.att_norm_bias_,
self.q_weight_,
self.kv_weight_,
self.q_bias_,
self.kv_bias_,
self.o_weight_,
self.o_bias_,
self.ffn_norm_weight_,
self.ffn_norm_bias_,
self.ffn_1_weight_,
self.ffn_1_bias_,
self.ffn_2_weight_,
self.ffn_2_bias_,
]
for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return

def _load_qkvo_weights(self, weights):
# input layernorm params
if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

if f"model.layers.{self.layer_num_}.input_layernorm.bias" in weights:
self.att_norm_bias_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.bias"])

n_embed = self.network_config_["hidden_size"]
q_split_n_embed = n_embed // self.world_size_
kv_split_n_embed = (
n_embed
// self.network_config_["num_attention_heads"]
* self.network_config_["num_key_value_heads"]
// self.world_size_
)
        # q, k, v projection weights (same layout as the llama loader)
if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))
if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights:
self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][
q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)
]
self.q_bias_ = self._cuda(self.q_bias_)
if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights:
k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
self.k_weight_ = k_weight_.transpose(0, 1)
if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights:
self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][
kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
]
if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
self.v_weight_ = v_weight_.transpose(0, 1)
if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights:
self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][
kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
]

self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)

self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0)

# attention output dense params
if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))
if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights:
self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"]
self.o_bias_ = self._cuda(self.o_bias_)
return

def _load_ffn_weights(self, weights):
if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
# post attention layernorm params
self.ffn_norm_weight_ = self._cuda(
weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
)
self.ffn_norm_bias_ = self._cuda(weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.bias"])

# ffn params
n_embed = self.network_config_["hidden_size"]
intermediate_size = n_embed * 4
split_inter_size = intermediate_size // self.world_size_
if f"model.layers.{self.layer_num_}.mlp.c_fc.weight" in weights:
self.ffn_1_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_fc.weight"].to(self.data_type_)
self.ffn_1_weight_ = (
self.ffn_1_weight_[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :]
.transpose(0, 1)
.contiguous()
.cuda()
)

if f"model.layers.{self.layer_num_}.mlp.c_fc.bias" in weights:
self.ffn_1_bias_ = (
weights[f"model.layers.{self.layer_num_}.mlp.c_fc.bias"][
split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
]
.to(self.data_type_)
.contiguous()
.cuda()
)

if f"model.layers.{self.layer_num_}.mlp.c_proj.weight" in weights:
self.ffn_2_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_proj.weight"].to(self.data_type_)
self.ffn_2_weight_ = (
self.ffn_2_weight_[:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)]
.transpose(0, 1)
.contiguous()
.cuda()
)

if f"model.layers.{self.layer_num_}.mlp.c_proj.bias" in weights:
self.ffn_2_bias_ = (
weights[f"model.layers.{self.layer_num_}.mlp.c_proj.bias"].to(self.data_type_).contiguous().cuda()
)

return
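
To make the slicing above concrete, here is a small sketch (not part of the commit) of the per-rank slice sizes computed in _load_qkvo_weights. The config values are hypothetical and only illustrate the grouped-query-attention arithmetic.

network_config = {
    "hidden_size": 4096,          # hypothetical
    "num_attention_heads": 32,    # hypothetical
    "num_key_value_heads": 4,     # hypothetical; fewer k/v heads than q heads (GQA)
}
world_size = 2

n_embed = network_config["hidden_size"]
head_dim = n_embed // network_config["num_attention_heads"]      # 128

q_split_n_embed = n_embed // world_size                          # 2048 -> 16 q heads per rank
kv_split_n_embed = (
    n_embed
    // network_config["num_attention_heads"]
    * network_config["num_key_value_heads"]
    // world_size
)                                                                # 256 -> 2 k/v heads per rank

print(head_dim, q_split_n_embed, kv_split_n_embed)
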