Skip to content

Commit

Permalink
Add pt to pdparams convert script
Browse files Browse the repository at this point in the history
  • Loading branch information
saas1600 committed Dec 9, 2022
1 parent 82d2c02 commit c2db09b
Showing 1 changed file with 55 additions and 0 deletions.
55 changes: 55 additions & 0 deletions codegeex/paddle/pt_to_pdparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import argparse
import paddle
import torch

# Sub-module name fragments identifying linear layers. WalkDict transposes any
# tensor whose key contains ".<fragment>.weight" for one of these entries.
# NOTE(review): presumably needed because PyTorch and Paddle store Linear
# weights with opposite axis order -- confirm against the Paddle model code.
linear_layer = [
"mlp.dense_h_to_4h",
"mlp.dense_4h_to_h",
"attention.query",
"attention.key",
"attention.value",
"attention.dense",
]


def WalkDict(x):
    """Convert every ``torch.Tensor`` in a nested dict to NumPy, in place.

    Nested dicts are walked recursively. Each tensor value is replaced by a
    ``numpy.ndarray``. When the key is a string containing
    ``.<layer>.weight`` for one of the names in the module-level
    ``linear_layer`` list, the array is transposed before being stored.
    Values that are neither dicts nor tensors are left untouched.

    Args:
        x: A (possibly nested) dict, such as a checkpoint returned by
            ``torch.load``. It is modified in place.

    Returns:
        None.
    """
    # Only existing keys are reassigned, so the dict size never changes and
    # mutating while iterating is safe.
    for key, value in x.items():
        if isinstance(value, dict):
            WalkDict(value)
            continue
        if not isinstance(value, torch.Tensor):
            continue

        print(f"Converting '{key}' from 'torch.Tensor' to 'numpy.ndarray'.")
        # detach() first: Tensor.numpy() raises on tensors that require grad
        # (e.g. nn.Parameter objects pickled directly into the checkpoint).
        npy = value.detach().cpu().numpy()

        # Only string keys can name a layer; the isinstance guard keeps
        # non-string keys (e.g. integer ids) from raising on the `in` test.
        is_linear_weight = isinstance(key, str) and any(
            f".{layer}.weight" in key for layer in linear_layer
        )
        if is_linear_weight:
            print(f"Transposing linear layer weight '{key}'.")
            x[key] = npy.T
        else:
            x[key] = npy


def parse_opt():
    """Parse the command-line options of the conversion script.

    Both options are required strings:

    * ``--pt``: path to the pt checkpoint.
    * ``--pdparams``: path to the pdparams checkpoint.

    Returns:
        The ``argparse.Namespace`` with ``pt`` and ``pdparams`` attributes.
    """
    cli = argparse.ArgumentParser()
    required_paths = (
        ("--pt", "Path to pt checkpoint."),
        ("--pdparams", "Path to pdparams checkpoint."),
    )
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)
    return cli.parse_args()


def main(opt):
    """Convert the PyTorch checkpoint at ``opt.pt`` to a Paddle pdparams file.

    The checkpoint is loaded, every tensor in it is converted to NumPy by
    ``WalkDict`` (in place), and the result is written with ``paddle.save``.

    Args:
        opt: Parsed options exposing ``pt`` (input checkpoint path) and
            ``pdparams`` (output checkpoint path).
    """
    # map_location="cpu": WalkDict moves every tensor to the CPU anyway, so
    # load straight onto the CPU. This lets checkpoints saved from CUDA
    # devices convert on machines without a GPU and avoids using GPU memory.
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only run this on checkpoints from a trusted source.
    state_dict = torch.load(opt.pt, map_location="cpu")
    WalkDict(state_dict)
    paddle.save(state_dict, opt.pdparams)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the conversion.
    main(parse_opt())

0 comments on commit c2db09b

Please sign in to comment.