option to export embeddings in PyTorch format
Guillaume Lample committed Mar 26, 2018
1 parent 93a2ae4 commit 6142387
Showing 5 changed files with 47 additions and 34 deletions.
2 changes: 1 addition & 1 deletion evaluate.py
@@ -20,8 +20,8 @@
 parser = argparse.ArgumentParser(description='Evaluation')
 parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
 parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
-parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
 parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
+parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
 # data
 parser.add_argument("--src_lang", type=str, default="", help="Source language")
 parser.add_argument("--tgt_lang", type=str, default="", help="Target language")
13 changes: 8 additions & 5 deletions src/trainer.py
@@ -242,22 +242,25 @@ def reload_best(self):
 
     def export(self):
         """
-        Export embeddings to a text file.
+        Export embeddings.
         """
         params = self.params
 
         # load all embeddings
-        params.src_dico.id2word, src_emb = load_external_embeddings(params, source=True, full_vocab=True)
-        params.tgt_dico.id2word, tgt_emb = load_external_embeddings(params, source=False, full_vocab=True)
+        logger.info("Reloading all embeddings for mapping ...")
+        params.src_dico, src_emb = load_external_embeddings(params, source=True, full_vocab=True)
+        params.tgt_dico, tgt_emb = load_external_embeddings(params, source=False, full_vocab=True)
 
         # apply same normalization as during training
         normalize_embeddings(src_emb, params.normalize_embeddings, mean=params.src_mean)
         normalize_embeddings(tgt_emb, params.normalize_embeddings, mean=params.tgt_mean)
 
         # map source embeddings to the target space
         bs = 4096
-        for k in range(0, len(src_emb), bs):
+        logger.info("Map source embeddings to the target space ...")
+        for i, k in enumerate(range(0, len(src_emb), bs)):
             x = Variable(src_emb[k:k + bs], volatile=True)
             src_emb[k:k + bs] = self.mapping(x.cuda() if params.cuda else x).data.cpu()
 
-        export_embeddings(src_emb.numpy(), tgt_emb.numpy(), params)
+        # write embeddings to the disk
+        export_embeddings(src_emb, tgt_emb, params)
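Note: Variable(..., volatile=True) is the pre-0.4 PyTorch idiom for running the mapping without building an autograd graph, which matches the PyTorch version this repository targeted in 2018. As a point of reference only (not part of this commit), the same batched mapping on PyTorch >= 0.4, where volatile is gone, would look roughly like:

    # sketch only: a no_grad() equivalent of the loop above for PyTorch >= 0.4
    with torch.no_grad():
        for k in range(0, len(src_emb), bs):
            x = src_emb[k:k + bs]
            src_emb[k:k + bs] = self.mapping(x.cuda() if params.cuda else x).cpu()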
52 changes: 30 additions & 22 deletions src/utils.py
@@ -56,7 +56,7 @@ def initialize_exp(params):
         torch.cuda.manual_seed(params.seed)
 
     # dump parameters
-    params.exp_path = get_exp_path(params) if not params.exp_path else params.exp_path
+    params.exp_path = get_exp_path(params)
    with io.open(os.path.join(params.exp_path, 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)
 
@@ -217,7 +217,7 @@ def get_exp_path(params):
     Create a directory to store the experiment.
     """
     # create the main dump path if it does not exist
-    exp_folder = MAIN_DUMP_PATH
+    exp_folder = MAIN_DUMP_PATH if params.exp_path == '' else params.exp_path
     if not os.path.exists(exp_folder):
         subprocess.Popen("mkdir %s" % exp_folder, shell=True).wait()
     assert params.exp_name != ''
@@ -323,24 +323,32 @@ def normalize_embeddings(emb, types, mean=None):
 
 def export_embeddings(src_emb, tgt_emb, params):
     """
-    Export embeddings to a text file.
+    Export embeddings to a text or a PyTorch file.
     """
-    src_id2word = params.src_dico.id2word
-    tgt_id2word = params.tgt_dico.id2word
-    n_src = len(src_id2word)
-    n_tgt = len(tgt_id2word)
-    dim = src_emb.shape[1]
-    src_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.src_lang)
-    tgt_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.tgt_lang)
-    # source embeddings
-    logger.info('Writing source embeddings to %s ...' % src_path)
-    with io.open(src_path, 'w', encoding='utf-8') as f:
-        f.write(u"%i %i\n" % (n_src, dim))
-        for i in range(len(src_id2word)):
-            f.write(u"%s %s\n" % (src_id2word[i], " ".join('%.5f' % x for x in src_emb[i])))
-    # target embeddings
-    logger.info('Writing target embeddings to %s ...' % tgt_path)
-    with io.open(tgt_path, 'w', encoding='utf-8') as f:
-        f.write(u"%i %i\n" % (n_tgt, dim))
-        for i in range(len(tgt_id2word)):
-            f.write(u"%s %s\n" % (tgt_id2word[i], " ".join('%.5f' % x for x in tgt_emb[i])))
+    assert params.export in ["text", "pth"]
+
+    # text file
+    if params.export == "text":
+        src_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.src_lang)
+        tgt_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.tgt_lang)
+        # source embeddings
+        logger.info('Writing source embeddings to %s ...' % src_path)
+        with io.open(src_path, 'w', encoding='utf-8') as f:
+            f.write(u"%i %i\n" % src_emb.size())
+            for i in range(len(params.src_dico)):
+                f.write(u"%s %s\n" % (params.src_dico[i], " ".join('%.5f' % x for x in src_emb[i])))
+        # target embeddings
+        logger.info('Writing target embeddings to %s ...' % tgt_path)
+        with io.open(tgt_path, 'w', encoding='utf-8') as f:
+            f.write(u"%i %i\n" % tgt_emb.size())
+            for i in range(len(params.tgt_dico)):
+                f.write(u"%s %s\n" % (params.tgt_dico[i], " ".join('%.5f' % x for x in tgt_emb[i])))
+
+    # PyTorch file
+    if params.export == "pth":
+        src_path = os.path.join(params.exp_path, 'vectors-%s.pth' % params.src_lang)
+        tgt_path = os.path.join(params.exp_path, 'vectors-%s.pth' % params.tgt_lang)
+        logger.info('Writing source embeddings to %s ...' % src_path)
+        torch.save({'dico': params.src_dico, 'vectors': src_emb}, src_path)
+        logger.info('Writing target embeddings to %s ...' % tgt_path)
+        torch.save({'dico': params.tgt_dico, 'vectors': tgt_emb}, tgt_path)
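The pth branch stores the dictionary and the (already mapped and normalized) embedding tensor together in one file, so downstream code can recover word-to-vector lookups without re-parsing a text file. A minimal sketch of reading one back (the path is a placeholder; unpickling the 'dico' entry requires the repository's dictionary class to be importable):

    # sketch: reload a file written by the --export pth branch above
    import torch

    loaded = torch.load('vectors-en.pth')          # placeholder path
    dico, emb = loaded['dico'], loaded['vectors']  # dictionary + FloatTensor of shape (n_words, dim)
    print(emb.size(0), 'words of dimension', emb.size(1))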
7 changes: 4 additions & 3 deletions supervised.py
@@ -26,9 +26,9 @@
 parser.add_argument("--seed", type=int, default=-1, help="Initialization seed")
 parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
 parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
-parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
-parser.add_argument("--export", type=bool_flag, default=True, help="Export embeddings after training")
 parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
+parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
+parser.add_argument("--export", type=str, default="text", help="Export embeddings after training (text / pth)")
 
 # data
 parser.add_argument("--src_lang", type=str, default='en', help="Source language")
@@ -62,6 +62,7 @@
 assert params.dico_max_size == 0 or params.dico_max_size > params.dico_min_size
 assert os.path.isfile(params.src_emb)
 assert os.path.isfile(params.tgt_emb)
+assert params.export in ["", "text", "pth"]
 
 # build logger / model / trainer / evaluator
 logger = initialize_exp(params)
@@ -98,7 +99,7 @@
     logger.info('End of iteration %i.\n\n' % n_iter)
 
 
-# export embeddings to a text format
+# export embeddings
 if params.export:
     trainer.reload_best()
     trainer.export()
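The export flag is now a string rather than a boolean: "text" keeps the old behavior, "pth" triggers the new torch.save path, and an empty string skips the export entirely (which is why the assert also allows ""). An example invocation with placeholder embedding paths (the same flag works for unsupervised.py below):

    # example only: train a supervised mapping and export the result as .pth files
    python supervised.py --src_lang en --tgt_lang es \
        --src_emb data/wiki.en.vec --tgt_emb data/wiki.es.vec \
        --export pth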
7 changes: 4 additions & 3 deletions unsupervised.py
@@ -27,9 +27,9 @@
 parser.add_argument("--seed", type=int, default=-1, help="Initialization seed")
 parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
 parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
-parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
-parser.add_argument("--export", type=bool_flag, default=True, help="Export embeddings after training")
 parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
+parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
+parser.add_argument("--export", type=str, default="text", help="Export embeddings after training (text / pth)")
 # data
 parser.add_argument("--src_lang", type=str, default='en', help="Source language")
 parser.add_argument("--tgt_lang", type=str, default='es', help="Target language")
@@ -85,6 +85,7 @@
 assert 0 < params.lr_shrink <= 1
 assert os.path.isfile(params.src_emb)
 assert os.path.isfile(params.tgt_emb)
+assert params.export in ["", "text", "pth"]
 
 # build model / trainer / evaluator
 logger = initialize_exp(params)
@@ -176,7 +177,7 @@
         logger.info('End of refinement iteration %i.\n\n' % n_iter)
 
 
-# export embeddings to a text format
+# export embeddings
 if params.export:
     trainer.reload_best()
     trainer.export()
