Skip to content

Commit

Permalink
Fixed unavailable import in checkpoint script
Browse files Browse the repository at this point in the history
  • Loading branch information
taidopurason committed Nov 21, 2022
1 parent 05183bf commit d42f9f2
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions scripts/create_mix_match_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import argparse
import ast
import gc
from collections import OrderedDict

import torch

from fairseq.dataclass.utils import eval_dict
def eval_dict(x, key_type=str, value_type=str):
if x is None:
return None
if isinstance(x, str):
if len(x) == 0:
return {}
x = ast.literal_eval(x)

return {key_type(k): value_type(v) for k, v in x.items()}


def load_state_dict(path, keep_prefix=None, rename_prefix=None):
Expand Down Expand Up @@ -44,7 +53,7 @@ def rename_prefix(name, prefix, rename):
parser.add_argument("--encoder-rename-prefix", required=False, default=None)
parser.add_argument("--decoder-rename-prefix", required=False, default=None)
parser.add_argument(
"--extra-rename-prefixes", required=False, default=None, type=eval_dict
"--extra-rename-prefixes", required=False, default=None, type=lambda x: eval_dict(x, str, str)
)

args = parser.parse_args()
Expand Down

0 comments on commit d42f9f2

Please sign in to comment.