Fix the KeyError when loading bloom-based models (vllm-project#441)
HermitSun authored Jul 14, 2023
1 parent 7b6ae94 commit dbed690
Showing 1 changed file with 10 additions and 3 deletions.
vllm/model_executor/models/bloom.py
@@ -284,10 +284,17 @@ def load_weights(self,
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, use_np_cache):
-            if not name.startswith("transformer."):
-                name = "transformer." + name
+            if name == "lm_head.weight":
+                # Since hidden_states are parallelized, we need to
+                # load lm_head.weight in parallel.
+                self._column_parallel_weights.append(name)
+                # If lm_head is provided, use it instead.
+                param = self.lm_head_weight
+            else:
+                if not name.startswith("transformer."):
+                    name = "transformer." + name
+                param = state_dict[name]

-            param = state_dict[name]
             if "query_key_value" in name:
                 # NOTE(woosuk): BLOOM's fused QKV has the shape of
                 # [num_heads * 3 * head_size, hidden_size], while the
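A minimal sketch (not the vLLM code itself) of why the old loop raised a KeyError and how the committed branch avoids it: the old code unconditionally prefixed every checkpoint key with `"transformer."`, so a checkpoint that shipped `lm_head.weight` produced a lookup for `"transformer.lm_head.weight"`, which is absent from the model's state dict. The `state_dict` contents and weight values below are illustrative stand-ins.

```python
def resolve_param(name, state_dict, lm_head_weight):
    """Mirror of the patched branching in load_weights."""
    if name == "lm_head.weight":
        # Use the dedicated lm_head weight instead of the
        # state_dict lookup that used to raise KeyError.
        return lm_head_weight
    if not name.startswith("transformer."):
        name = "transformer." + name
    return state_dict[name]


state_dict = {"transformer.word_embeddings.weight": "embed-param"}
lm_head_weight = "lm-head-param"

# Old behavior: "lm_head.weight" became "transformer.lm_head.weight",
# a key missing from state_dict -> KeyError.
# New behavior: the lm_head branch short-circuits the lookup.
print(resolve_param("lm_head.weight", state_dict, lm_head_weight))
# lm-head-param
print(resolve_param("word_embeddings.weight", state_dict, lm_head_weight))
# embed-param
```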
