
Commit 14493cb
Merge branch 'main' of github.com:THUDM/CodeGeeX
Stanislas0 committed Feb 20, 2023
2 parents 4758dce + 54052c3 commit 14493cb
Showing 2 changed files with 13 additions and 5 deletions.
8 changes: 4 additions & 4 deletions codegeex/megatron/convert_ckpt_parallel.py
@@ -72,7 +72,7 @@ def main():
         word_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "module.language_model.embedding.word_embeddings"
         )
-        word_emb_dict["weight"] = out_word_embeddings[i]
+        word_emb_dict["weight"] = out_word_embeddings[i].clone()
 
     print("Converting QueryEmbedding layers...")
     query_embeddings = state_dict['module']['language_model']['topQueryEmbedding']['top_query_embeddings']['weight']
@@ -82,7 +82,7 @@ def main():
         query_emb_dict = get_element_from_dict_by_path(
             output_state_dict[i], "module.language_model.topQueryEmbedding.top_query_embeddings"
         )
-        query_emb_dict["weight"] = out_query_embeddings[i]
+        query_emb_dict["weight"] = out_query_embeddings[i].clone()
 
     print("Converting Transformer layers...")
     for layer_name in state_dict['module']['language_model']['transformer'].keys():
@@ -109,10 +109,10 @@ def main():
         for i in range(args.target_tensor_model_parallel_size):
             params_dict = get_element_from_dict_by_path(output_state_dict[i], "module.language_model.transformer")
             if type(params) is tuple:
-                params_dict[layer_name] = params[i]
+                params_dict[layer_name] = params[i].clone()
             else:
                 params_dict[layer_name] = params
 
     os.makedirs(args.save_ckpt_path, exist_ok=True)
     for rank in range(args.target_tensor_model_parallel_size):
         save_ckpt_path = os.path.join(args.save_ckpt_path, f"mp_rank_{rank:02d}_model_states.pt")
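The .clone() calls likely address a standard PyTorch serialization pitfall: torch.chunk and tensor slicing return views, and torch.save pickles a view's entire underlying storage, so each per-rank file would otherwise contain the full unsharded tensor. Cloning gives each shard its own compact storage. A minimal standalone sketch (not part of this repo) of the effect:

    import io
    import torch

    full = torch.zeros(1024, 1024)          # one large parameter, ~4 MB of float32
    shard = torch.chunk(full, 4, dim=0)[0]  # a view onto full's storage, ~1 MB of data

    def saved_bytes(t):
        # Serialize to an in-memory buffer and report how many bytes were written.
        buf = io.BytesIO()
        torch.save(t, buf)
        return buf.tell()

    print(saved_bytes(shard))          # ~4 MB: the view's whole backing storage is saved
    print(saved_bytes(shard.clone()))  # ~1 MB: only the shard's own data is saved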
10 changes: 9 additions & 1 deletion codegeex/megatron/merge_ckpt_parallel.py
@@ -25,6 +25,11 @@ def get_change_ckpt_args(parser):
         required=True,
         help='path to save ".pt" checkpoint.',
     )
+    group.add_argument(
+        '--save-name',
+        type=str,
+        help='name of checkpoint.',
+    )
     group.add_argument(
         '--source-tensor-model-parallel-size',
         type=int,
@@ -123,7 +128,10 @@ def main():
         save_ckpt_path = args.save_ckpt_path
     else:
         os.makedirs(args.save_ckpt_path, exist_ok=True)
-        save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
+        if args.save_name:
+            save_ckpt_path = os.path.join(args.save_ckpt_path, args.save_name)
+        else:
+            save_ckpt_path = os.path.join(args.save_ckpt_path, "mp_rank_00_model_states.pt")
 
     torch.save(sd, save_ckpt_path)
     print(f"Converted checkpoint saved in {save_ckpt_path}.")
