add vocab padding for LLama(Support WizardLM) (vllm-project#411)
esmeetu authored Jul 14, 2023
1 parent c6dfc3c commit 7b6ae94
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions vllm/model_executor/models/llama.py
@@ -187,10 +187,9 @@ def __init__(self, config: LlamaConfig):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            perform_initialization=False)
+            vocab_size, config.hidden_size, perform_initialization=False)
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
         ])
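
The new vocab_size expression rounds the vocabulary size up to the next multiple of 64 before building the embedding. A minimal sketch of that arithmetic, assuming (as the commit title suggests) a WizardLM-style checkpoint whose vocabulary is 32001 entries, LLaMA's 32000 plus an added pad token:

def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
    # Same integer arithmetic as the diff: round up to the next multiple of 64.
    return ((vocab_size + multiple - 1) // multiple) * multiple

# Stock LLaMA: 32000 is already a multiple of 64, so the padding is a no-op.
assert pad_vocab_size(32000) == 32000
# Assumed WizardLM vocab of 32001: padded up to 32064 (63 extra rows).
assert pad_vocab_size(32001) == 32064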
@@ -228,8 +227,9 @@ def __init__(self, config):
         super().__init__()
         self.config = config
         self.model = LlamaModel(config)
+        vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            config.vocab_size,
+                                            vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             perform_initialization=False)
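
The same padded size is reused for the lm_head projection, keeping the embedding and output matrices shape-compatible. Padding to a multiple of 64 also appears to be what lets the vocabulary dimension split evenly across common tensor-parallel world sizes; a small sketch of the per-rank row counts, reusing the assumed 32001/32064 WizardLM figures from above:

def rows_per_rank(padded_vocab_size: int, world_size: int) -> int:
    # Each tensor-parallel rank owns an equal, contiguous slice of the vocab rows.
    assert padded_vocab_size % world_size == 0, "vocab must split evenly"
    return padded_vocab_size // world_size

for world_size in (1, 2, 4, 8):
    print(world_size, rows_per_rank(32064, world_size))
# -> 32064, 16032, 8016, 4008 rows per rank; an unpadded 32001 would fail
#    the divisibility check for world sizes 2, 4, and 8.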
@@ -259,6 +259,8 @@ def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

@@ -267,6 +269,17 @@ def load_weights(self,
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            if "embed_tokens" in name or "lm_head" in name:
+                param = state_dict[name]
+                # Consider padding in the vocab size.
+                padded_vocab_size = (param.shape[0] *
+                                     tensor_model_parallel_world_size)
+                num_extra_rows = padded_vocab_size - self.config.vocab_size
+                extra_rows = torch.empty(num_extra_rows,
+                                         loaded_weight.shape[1])
+                extra_rows = extra_rows.to(loaded_weight)
+                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+
             is_attention_weight = False
             for stride_id, att_weight_name in enumerate(
                     ["q_proj", "k_proj", "v_proj"]):
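
Because the checkpoint tensors for embed_tokens and lm_head still have config.vocab_size rows, load_weights appends uninitialized rows until the weight matches the padded parameter shape (per-rank rows times world size) before the usual per-rank sharding. A standalone sketch of that concatenation on a dummy tensor, with illustrative toy shapes rather than the real model sizes:

import torch

vocab_size, hidden_size, world_size = 32001, 8, 2    # toy sizes; 32001 is the assumed WizardLM vocab
padded_vocab_size = ((vocab_size + 63) // 64) * 64   # 32064

loaded_weight = torch.randn(vocab_size, hidden_size)      # stand-in for a checkpoint tensor
num_extra_rows = padded_vocab_size - vocab_size           # 63 uninitialized rows
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]).to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
assert loaded_weight.shape == (padded_vocab_size, hidden_size)

# Each rank then keeps its contiguous slice of the padded rows.
rank, shard = 0, padded_vocab_size // world_size
local_weight = loaded_weight[rank * shard:(rank + 1) * shard]
assert local_weight.shape == (shard, hidden_size)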
