forked from ModelTC/lightllm
Commit
add support for starcoder2 (ModelTC#347)
Showing 8 changed files with 441 additions and 0 deletions.
Empty file.
Empty file.
174 changes: 174 additions & 0 deletions
lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,174 @@
import torch
import triton

from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd

from lightllm.models.starcoder2.layer_weights.transformer_layer_weight import Starcoder2TransformerLayerWeight
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer

from lightllm.models.mistral.infer_struct import MistralInferStateInfo
from lightllm.models.mistral.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_att1 import token_att_fwd
from lightllm.models.mistral.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd


class Starcoder2TransformerLayerInfer(LlamaTransformerLayerInfer):
    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self._bind_func()

    def _bind_func(self):
        self._token_attention_kernel = self._token_decode_attention_normal
        self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal

        return

    def _att_norm(
        self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
    ) -> torch.Tensor:
        return layernorm_forward(
            input.view(-1, self.embed_dim_),
            weight=layer_weight.att_norm_weight_,
            bias=layer_weight.att_norm_bias_,
            eps=self.eps_,
        )

    def _ffn_norm(
        self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
    ) -> torch.Tensor:
        return layernorm_forward(
            input.view(-1, self.embed_dim_),
            weight=layer_weight.ffn_norm_weight_,
            bias=layer_weight.ffn_norm_bias_,
            eps=self.eps_,
        )

    def _get_qkv(
        self, input, cache_kv, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
    ) -> torch.Tensor:
        q = torch.addmm(layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_)
        torch.addmm(
            layer_weight.kv_bias_,
            input.view(-1, self.embed_dim_),
            layer_weight.kv_weight_,
            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
        )
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, 0 : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        return q, cache_kv

    def _get_o(
        self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
    ) -> torch.Tensor:
        o_tensor = torch.addmm(
            layer_weight.o_bias_,
            input.view(-1, self.tp_o_head_num_ * self.head_dim_),
            layer_weight.o_weight_,
            beta=1.0 / self.world_size_,
        )
        return o_tensor

    def _ffn(
        self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
    ) -> torch.Tensor:
        ffn1_out = torch.addmm(layer_weight.ffn_1_bias_, input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_)
        input = None
        gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")
        ffn1_out = None
        ffn2_out = torch.addmm(
            layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.world_size_
        )
        gelu_out = None
        return ffn2_out

    # use sliding_window code from mistral
    def _context_attention_kernel(
        self, q, kv, infer_state: MistralInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        o_tensor = torch.empty_like(q) if out is None else out
        context_attention_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            kv[:, 0 : self.tp_k_head_num_, :],
            kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
            o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.max_len_in_batch,
            infer_state.sliding_window,
        )
        return o_tensor

    def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, layer_weight, out=None):
        total_token_num = infer_state.total_cache_num
        batch_size = infer_state.batch_size
        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)

        att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda")

        token_att_fwd(
            q.view(calcu_shape1),
            infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :],
            att_m_tensor,
            infer_state.req_manager.req_to_token_indexs,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.b_start_loc_window,
            infer_state.b_att_start_loc,
            infer_state.b_att_seq_len,
            infer_state.sliding_window,
        )

        o_tensor = torch.empty_like(q) if out is None else out

        if triton.__version__ == "2.0.0":
            prob = torch.empty_like(att_m_tensor)
            token_softmax_fwd(
                att_m_tensor, infer_state.b_att_start_loc, infer_state.b_att_seq_len, prob, infer_state.sliding_window
            )
            att_m_tensor = None
            token_att_fwd2(
                prob,
                infer_state.mem_manager.kv_buffer[self.layer_num_][
                    :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
                ],
                o_tensor.view(calcu_shape1),
                infer_state.req_manager.req_to_token_indexs,
                infer_state.b_req_idx,
                infer_state.b_start_loc,
                infer_state.b_seq_len,
                infer_state.b_start_loc_window,
                infer_state.b_att_start_loc,
                infer_state.b_att_seq_len,
            )
            prob = None
            return o_tensor
        elif triton.__version__ >= "2.1.0":
            from lightllm.models.mistral.triton_kernel.token_attention_softmax_and_reducev import (
                token_softmax_reducev_fwd,
            )

            token_softmax_reducev_fwd(
                att_m_tensor,
                infer_state.mem_manager.kv_buffer[self.layer_num_][
                    :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
                ],
                o_tensor.view(calcu_shape1),
                infer_state.req_manager.req_to_token_indexs,
                infer_state.b_req_idx,
                infer_state.b_start_loc,
                infer_state.b_seq_len,
                infer_state.b_start_loc_window,
                infer_state.b_att_start_loc,
                infer_state.b_att_seq_len,
                infer_state.other_kv_index,
            )
            return o_tensor
        else:
            raise Exception("not support triton version")
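A note on the beta=1.0 / self.world_size_ argument used in _get_o and _ffn above: torch.addmm(bias, a, b, beta=beta) computes beta * bias + a @ b, so scaling the bias by 1 / world_size on every tensor-parallel rank lets the later all-reduce sum the partial matmuls while adding the bias exactly once. The following standalone sketch (not part of the commit, CPU only, made-up shapes) checks that identity:

import torch

# Row-parallel linear: the reduction dimension is split across two "ranks",
# mirroring how o_proj and mlp.c_proj are partitioned in the layer above.
world_size = 2
x = torch.randn(4, 8)    # activations
w = torch.randn(8, 6)    # full weight
bias = torch.randn(6)    # full bias

x_shards = x.chunk(world_size, dim=1)   # each rank's slice of the input features
w_shards = w.chunk(world_size, dim=0)   # the matching rows of the weight

# Each rank computes beta * bias + x_i @ w_i with beta = 1 / world_size.
partials = [
    torch.addmm(bias, xs, ws, beta=1.0 / world_size)
    for xs, ws in zip(x_shards, w_shards)
]

# The all-reduce (modeled here as a plain sum) adds the bias exactly once overall.
reduced = torch.stack(partials).sum(dim=0)
assert torch.allclose(reduced, torch.addmm(bias, x, w), atol=1e-5)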
Empty file.
37 changes: 37 additions & 0 deletions
lightllm/models/starcoder2/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,37 @@
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight


class Starcoder2PreAndPostLayerWeight(PreAndPostLayerWeight):
    def __init__(self, tp_rank, world_size, data_type, network_config, mode):
        super().__init__(tp_rank, world_size, data_type, network_config, mode)
        return

    def load_hf_weights(self, weights):
        vob_size = self.network_config_["vocab_size"]
        split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64)
        split_start = split_indexes[self.tp_rank_]
        split_end = split_indexes[self.tp_rank_ + 1]
        if "model.embed_tokens.weight" in weights:
            self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :])

            # for starcoder2-3b and 7b which didn't use lm_head.weight (tie_word_embeddings)
            self.lm_head_weight_ = self.wte_weight_

        if "lm_head.weight" in weights:
            self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])

        if "model.norm.weight" in weights:
            self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

        if "model.norm.bias" in weights:
            self.final_norm_bias_ = self._cuda(weights["model.norm.bias"])

        return

    def verify_load(self):
        errors = "weights load not ok"
        weights = [self.wte_weight_, self.lm_head_weight_, self.final_norm_weight_, self.final_norm_bias_]
        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return
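For reference, the np.linspace split in load_hf_weights above gives each tensor-parallel rank a contiguous, near-equal slice of the vocabulary rows, and the tied assignment lm_head_weight_ = wte_weight_ reuses that same slice as the output projection on checkpoints without a separate lm_head.weight. A small standalone illustration of the index arithmetic (the vocabulary size and world size below are made up for the example):

import numpy as np

vocab_size = 49152   # illustrative value, not read from any config
world_size = 4

split_indexes = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
# -> array([    0, 12288, 24576, 36864, 49152])

for tp_rank in range(world_size):
    split_start, split_end = split_indexes[tp_rank], split_indexes[tp_rank + 1]
    print(f"rank {tp_rank}: embedding rows [{split_start}, {split_end}) "
          f"-> {split_end - split_start} rows")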
141 changes: 141 additions & 0 deletions
lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,141 @@
from lightllm.common.basemodel import TransformerLayerWeight


class Starcoder2TransformerLayerWeight(TransformerLayerWeight):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
        assert network_config["num_attention_heads"] % self.world_size_ == 0

    def init_static_params(self):
        pass

    def load_hf_weights(self, weights):
        self._load_qkvo_weights(weights)
        self._load_ffn_weights(weights)
        return

    def verify_load(self):
        errors = "weights load not ok"
        weights = [
            self.att_norm_weight_,
            self.att_norm_bias_,
            self.q_weight_,
            self.kv_weight_,
            self.q_bias_,
            self.kv_bias_,
            self.o_weight_,
            self.o_bias_,
            self.ffn_norm_weight_,
            self.ffn_norm_bias_,
            self.ffn_1_weight_,
            self.ffn_1_bias_,
            self.ffn_2_weight_,
            self.ffn_2_bias_,
        ]
        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return

    def _load_qkvo_weights(self, weights):
        # input layernorm params
        if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

        if f"model.layers.{self.layer_num_}.input_layernorm.bias" in weights:
            self.att_norm_bias_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.bias"])

        n_embed = self.network_config_["hidden_size"]
        q_split_n_embed = n_embed // self.world_size_
        kv_split_n_embed = (
            n_embed
            // self.network_config_["num_attention_heads"]
            * self.network_config_["num_key_value_heads"]
            // self.world_size_
        )
        # q k v weights for llama
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
            self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
            self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights:
            self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][
                q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)
            ]
            self.q_bias_ = self._cuda(self.q_bias_)
        if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights:
            k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
            k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.k_weight_ = k_weight_.transpose(0, 1)
        if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights:
            self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
            ]
        if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
            v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
            v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.v_weight_ = v_weight_.transpose(0, 1)
        if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights:
            self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
            ]

        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)

        self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0)

        # attention output dense params
        if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))
        if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights:
            self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"]
            self.o_bias_ = self._cuda(self.o_bias_)
        return

    def _load_ffn_weights(self, weights):
        if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
            # post attention layernorm params
            self.ffn_norm_weight_ = self._cuda(
                weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
            )
            self.ffn_norm_bias_ = self._cuda(weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.bias"])

        # ffn params
        n_embed = self.network_config_["hidden_size"]
        intermediate_size = n_embed * 4
        split_inter_size = intermediate_size // self.world_size_
        if f"model.layers.{self.layer_num_}.mlp.c_fc.weight" in weights:
            self.ffn_1_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_fc.weight"].to(self.data_type_)
            self.ffn_1_weight_ = (
                self.ffn_1_weight_[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :]
                .transpose(0, 1)
                .contiguous()
                .cuda()
            )

        if f"model.layers.{self.layer_num_}.mlp.c_fc.bias" in weights:
            self.ffn_1_bias_ = (
                weights[f"model.layers.{self.layer_num_}.mlp.c_fc.bias"][
                    split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
                ]
                .to(self.data_type_)
                .contiguous()
                .cuda()
            )

        if f"model.layers.{self.layer_num_}.mlp.c_proj.weight" in weights:
            self.ffn_2_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_proj.weight"].to(self.data_type_)
            self.ffn_2_weight_ = (
                self.ffn_2_weight_[:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)]
                .transpose(0, 1)
                .contiguous()
                .cuda()
            )

        if f"model.layers.{self.layer_num_}.mlp.c_proj.bias" in weights:
            self.ffn_2_bias_ = (
                weights[f"model.layers.{self.layer_num_}.mlp.c_proj.bias"].to(self.data_type_).contiguous().cuda()
            )

        return
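The slice widths computed in _load_qkvo_weights follow the head geometry: q_proj is split by hidden size across ranks, while k_proj and v_proj are split by head_dim * num_key_value_heads, since grouped-query attention keeps far fewer key/value heads than query heads. A quick standalone check of that arithmetic with hypothetical config values (not taken from any real checkpoint):

# Hypothetical config values, chosen only to illustrate the slice arithmetic.
network_config = {"hidden_size": 3072, "num_attention_heads": 24, "num_key_value_heads": 2}
world_size = 2

n_embed = network_config["hidden_size"]
head_dim = n_embed // network_config["num_attention_heads"]   # 128

q_split_n_embed = n_embed // world_size                        # 1536 q_proj rows per rank
kv_split_n_embed = (
    n_embed
    // network_config["num_attention_heads"]
    * network_config["num_key_value_heads"]
    // world_size
)                                                              # 128 k_proj / v_proj rows per rank

assert q_split_n_embed == head_dim * network_config["num_attention_heads"] // world_size
assert kv_split_n_embed == head_dim * network_config["num_key_value_heads"] // world_size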