Skip to content

Commit

Permalink
Merge pull request karpathy#19 from nat/patch-1
Browse files Browse the repository at this point in the history
Strip unwanted prefix from state keys when loading model in sample.py
  • Loading branch information
karpathy authored Jan 5, 2023
2 parents d562b3e + 2b9e168 commit 529c967
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
model.load_state_dict(checkpoint['model'])
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval()
model.to(device)
if compile:
Expand Down

0 comments on commit 529c967

Please sign in to comment.