Skip to content

Commit

Permalink
Fix parallel ckpt conversion
Browse files (browse the repository at this point in the history)
  • Loading branch information
Stanislas0 committed Feb 20, 2023
1 parent b0544da commit 140a5ea
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions codegeex/megatron/convert_ckpt_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def main():
word_emb_dict = get_element_from_dict_by_path(
output_state_dict[i], "module.language_model.embedding.word_embeddings"
)
- word_emb_dict["weight"] = out_word_embeddings[i]
+ word_emb_dict["weight"] = out_word_embeddings[i].clone()

print("Converting QueryEmbedding layers...")
query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
Expand All @@ -82,7 +82,7 @@ def main():
query_emb_dict = get_element_from_dict_by_path(
output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
)
- query_emb_dict["weight"] = out_query_embeddings[i]
+ query_emb_dict["weight"] = out_query_embeddings[i].clone()

print("Converting Transformer layers...")
for layer_name in state_dict['module']['language_model']['transformer'].keys():
Expand All @@ -109,10 +109,10 @@ def main():
for i in range(args.target_tensor_model_parallel_size):
params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
if type(params) is tuple:
- params_dict[layer_name] = params[i]
+ params_dict[layer_name] = params[i].clone()
else:
params_dict[layer_name] = params

os.makedirs(args.save_ckpt_path, exist_ok=True)
for rank in range(args.target_tensor_model_parallel_size):
save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
Expand Down

0 comments on commit 140a5ea

Please sign in to comment.