
Commit 4884ad2
Merge branch 'master' of https://github.com/facebookresearch/MUSE
Alexis Conneau committed Apr 23, 2018
2 parents 0dc4e45 + 6e0b460
Showing 8 changed files with 195 additions and 56 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -73,6 +73,16 @@ python evaluate.py --src_lang en --src_emb data/wiki.en.vec --max_vocab 200000
python evaluate.py --src_lang en --tgt_lang es --src_emb data/wiki.en-es.en.vec --tgt_emb data/wiki.en-es.es.vec --max_vocab 200000
```

## Word embedding format
By default, the aligned embeddings are exported in a text format at the end of the experiment (`--export txt`). Exporting embeddings to a text file can take a while if you have a lot of them. For a very fast export, you can set `--export pth` to write the embeddings to a PyTorch binary file, or simply disable the export (`--export ""`).

When loading embeddings, the model supports:
* PyTorch binary files previously generated by MUSE (.pth files)
* fastText binary files previously generated by fastText (.bin files)
* text files (one word embedding per line)

The first two options are very fast and can load 1 million embeddings in a few seconds, while loading text files can take a while.
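As a quick sketch, a `.pth` file exported by MUSE can be reloaded directly with `torch.load`; the path below is a placeholder, and the MUSE source must be importable so the pickled `Dictionary` class can be resolved:

```python
import torch

# Placeholder path: MUSE writes vectors-<lang>.pth into the experiment folder
data = torch.load('path/to/vectors-en.pth')
dico, vectors = data['dico'], data['vectors']  # Dictionary object and embedding tensor
print(len(dico), vectors.size())               # vocabulary size, (n_words, emb_dim)
```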

## Download
We provide multilingual embeddings and ground-truth bilingual dictionaries.

5 changes: 3 additions & 2 deletions evaluate.py
@@ -20,15 +20,16 @@
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument("--verbose", type=int, default=2, help="Verbose level (2:debug, 1:info, 0:warning)")
parser.add_argument("--exp_path", type=str, default="", help="Where to store experiment logs and models")
parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
parser.add_argument("--cuda", type=bool_flag, default=True, help="Run on GPU")
# data
parser.add_argument("--src_lang", type=str, default="", help="Source language")
parser.add_argument("--tgt_lang", type=str, default="", help="Target language")
# reload pre-trained embeddings
parser.add_argument("--src_emb", type=str, default="", help="Reload source embeddings")
parser.add_argument("--tgt_emb", type=str, default="", help="Reload target embeddings")
parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size")
parser.add_argument("--max_vocab", type=int, default=200000, help="Maximum vocabulary size (-1 to disable)")
parser.add_argument("--emb_dim", type=int, default=300, help="Embedding dimension")
parser.add_argument("--normalize_embeddings", type=str, default="", help="Normalize embeddings before training")

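For reference, a hypothetical invocation combining the options touched here (the embedding path is the one already used in the README; the `--exp_path` and `--exp_id` values are made up, and `--max_vocab -1` disables the vocabulary cap as stated in the new help text):

```
python evaluate.py --src_lang en --src_emb data/wiki.en.vec --max_vocab -1 --exp_path ./dumped --exp_id my_eval
```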
9 changes: 9 additions & 0 deletions src/dictionary.py
@@ -61,3 +61,12 @@ def index(self, word):
Returns the index of the specified word.
"""
return self.word2id[word]

def prune(self, max_vocab):
"""
Limit the vocabulary size.
"""
assert max_vocab >= 1
self.id2word = {k: v for k, v in self.id2word.items() if k < max_vocab}
self.word2id = {v: k for k, v in self.id2word.items()}
self.check_valid()
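A minimal usage sketch of the new `prune` method, assuming the `Dictionary(id2word, word2id, lang)` constructor used elsewhere in this commit and an importable `src` package:

```python
from src.dictionary import Dictionary  # assumed import path

id2word = {0: "the", 1: "of", 2: "and", 3: "london"}
word2id = {w: i for i, w in id2word.items()}
dico = Dictionary(id2word, word2id, "en")

dico.prune(2)  # keep only the two lowest indices, i.e. the most frequent words
assert len(dico) == 2 and dico.word2id == {"the": 0, "of": 1}
```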
6 changes: 3 additions & 3 deletions src/models.py
@@ -8,7 +8,7 @@
import torch
from torch import nn

from .utils import load_external_embeddings, normalize_embeddings
from .utils import load_embeddings, normalize_embeddings


class Discriminator(nn.Module):
@@ -43,14 +43,14 @@ def build_model(params, with_dis):
Build all components of the model.
"""
# source embeddings
src_dico, _src_emb = load_external_embeddings(params, source=True)
src_dico, _src_emb = load_embeddings(params, source=True)
params.src_dico = src_dico
src_emb = nn.Embedding(len(src_dico), params.emb_dim, sparse=True)
src_emb.weight.data.copy_(_src_emb)

# target embeddings
if params.tgt_lang:
tgt_dico, _tgt_emb = load_external_embeddings(params, source=False)
tgt_dico, _tgt_emb = load_embeddings(params, source=False)
params.tgt_dico = tgt_dico
tgt_emb = nn.Embedding(len(tgt_dico), params.emb_dim, sparse=True)
tgt_emb.weight.data.copy_(_tgt_emb)
15 changes: 9 additions & 6 deletions src/trainer.py
@@ -13,7 +13,7 @@
from torch.autograd import Variable
from torch.nn import functional as F

from .utils import get_optimizer, load_external_embeddings, normalize_embeddings, export_embeddings
from .utils import get_optimizer, load_embeddings, normalize_embeddings, export_embeddings
from .utils import clip_parameters
from .dico_builder import build_dictionary
from .evaluation.word_translation import DIC_EVAL_PATH, load_identical_char_dico, load_dictionary
@@ -242,22 +242,25 @@ def reload_best(self):

def export(self):
"""
Export embeddings to a text file.
Export embeddings.
"""
params = self.params

# load all embeddings
params.src_dico.id2word, src_emb = load_external_embeddings(params, source=True, full_vocab=True)
params.tgt_dico.id2word, tgt_emb = load_external_embeddings(params, source=False, full_vocab=True)
logger.info("Reloading all embeddings for mapping ...")
params.src_dico, src_emb = load_embeddings(params, source=True, full_vocab=True)
params.tgt_dico, tgt_emb = load_embeddings(params, source=False, full_vocab=True)

# apply same normalization as during training
normalize_embeddings(src_emb, params.normalize_embeddings, mean=params.src_mean)
normalize_embeddings(tgt_emb, params.normalize_embeddings, mean=params.tgt_mean)

# map source embeddings to the target space
bs = 4096
for k in range(0, len(src_emb), bs):
logger.info("Map source embeddings to the target space ...")
for i, k in enumerate(range(0, len(src_emb), bs)):
x = Variable(src_emb[k:k + bs], volatile=True)
src_emb[k:k + bs] = self.mapping(x.cuda() if params.cuda else x).data.cpu()

export_embeddings(src_emb.numpy(), tgt_emb.numpy(), params)
# write embeddings to the disk
export_embeddings(src_emb, tgt_emb, params)
185 changes: 148 additions & 37 deletions src/utils.py
@@ -56,7 +56,7 @@ def initialize_exp(params):
torch.cuda.manual_seed(params.seed)

# dump parameters
params.exp_path = get_exp_path(params) if not params.exp_path else params.exp_path
params.exp_path = get_exp_path(params)
with io.open(os.path.join(params.exp_path, 'params.pkl'), 'wb') as f:
pickle.dump(params, f)

@@ -68,6 +68,18 @@
return logger


def load_fasttext_model(path):
"""
Load a binarized fastText model.
"""
try:
import fastText
except ImportError:
raise Exception("Unable to import fastText. Please install fastText for Python: "
"https://github.com/facebookresearch/fastText")
return fastText.load_model(path)


def bow(sentences, word_vec, normalize=False):
"""
Get sentence representations using average bag-of-words.
@@ -217,19 +229,23 @@ def get_exp_path(params):
Create a directory to store the experiment.
"""
# create the main dump path if it does not exist
exp_folder = MAIN_DUMP_PATH
exp_folder = MAIN_DUMP_PATH if params.exp_path == '' else params.exp_path
if not os.path.exists(exp_folder):
subprocess.Popen("mkdir %s" % exp_folder, shell=True).wait()
assert params.exp_name != ''
exp_folder = os.path.join(exp_folder, params.exp_name)
if not os.path.exists(exp_folder):
subprocess.Popen("mkdir %s" % exp_folder, shell=True).wait()
chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
while True:
exp_name = ''.join(random.choice(chars) for _ in range(10))
exp_path = os.path.join(exp_folder, exp_name)
if not os.path.isdir(exp_path):
break
if params.exp_id == '':
chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
while True:
exp_id = ''.join(random.choice(chars) for _ in range(10))
exp_path = os.path.join(exp_folder, exp_id)
if not os.path.isdir(exp_path):
break
else:
exp_path = os.path.join(exp_folder, params.exp_id)
assert not os.path.isdir(exp_path), exp_path
# create the dump folder
if not os.path.isdir(exp_path):
subprocess.Popen("mkdir %s" % exp_path, shell=True).wait()
@@ -245,15 +261,10 @@ def clip_parameters(model, clip):
x.data.clamp_(-clip, clip)


def load_external_embeddings(params, source, full_vocab=False):
def read_txt_embeddings(params, source, full_vocab):
"""
Reload pretrained embeddings from a text file.
- `full_vocab == False` means that we load the `params.max_vocab` most frequent words.
It is used at the beginning of the experiment.
- `full_vocab == True` means that we load the entire embedding text file,
before we export the embeddings at the end of the experiment.
"""
assert type(source) is bool
word2id = {}
vectors = []

@@ -290,19 +301,111 @@
break

assert len(word2id) == len(vectors)
logger.info("Loaded %i pre-trained word embeddings" % len(vectors))
logger.info("Loaded %i pre-trained word embeddings." % len(vectors))

# compute new vocabulary / embeddings
id2word = {v: k for k, v in word2id.items()}
dico = Dictionary(id2word, word2id, lang)
embeddings = np.concatenate(vectors, 0)
embeddings = torch.from_numpy(embeddings).float()
embeddings = embeddings.cuda() if (params.cuda and not full_vocab) else embeddings
assert embeddings.size() == (len(word2id), params.emb_dim), ((len(word2id), params.emb_dim, embeddings.size()))

assert embeddings.size() == (len(dico), params.emb_dim)
return dico, embeddings


def select_subset(word_list, max_vocab):
"""
Select a subset of words to consider, to deal with words having embeddings
available in different casings. In particular, we select the embeddings of
the most frequent words, which are usually of better quality.
"""
word2id = {}
indexes = []
for i, word in enumerate(word_list):
word = word.lower()
if word not in word2id:
word2id[word] = len(word2id)
indexes.append(i)
if max_vocab > 0 and len(word2id) >= max_vocab:
break
assert len(word2id) == len(indexes)
return word2id, torch.LongTensor(indexes)


def load_pth_embeddings(params, source, full_vocab):
"""
Reload pretrained embeddings from a PyTorch binary file.
"""
# reload PyTorch binary file
lang = params.src_lang if source else params.tgt_lang
data = torch.load(params.src_emb if source else params.tgt_emb)
dico = data['dico']
embeddings = data['vectors']
assert dico.lang == lang
assert embeddings.size() == (len(dico), params.emb_dim)
logger.info("Loaded %i pre-trained word embeddings." % len(dico))

# select a subset of word embeddings (to deal with casing)
if not full_vocab:
word2id, indexes = select_subset([dico[i] for i in range(len(dico))], params.max_vocab)
id2word = {v: k for k, v in word2id.items()}
dico = Dictionary(id2word, word2id, lang)
embeddings = embeddings[indexes]

assert embeddings.size() == (len(dico), params.emb_dim)
return dico, embeddings


def load_bin_embeddings(params, source, full_vocab):
"""
Reload pretrained embeddings from a fastText binary file.
"""
# reload fastText binary file
lang = params.src_lang if source else params.tgt_lang
model = load_fasttext_model(params.src_emb if source else params.tgt_emb)
words = model.get_labels()
assert model.get_dimension() == params.emb_dim
logger.info("Loaded binary model. Generating embeddings ...")
embeddings = torch.from_numpy(np.concatenate([model.get_word_vector(w)[None] for w in words], 0))
logger.info("Generated embeddings for %i words." % len(words))
assert embeddings.size() == (len(words), params.emb_dim)

# select a subset of word embeddings (to deal with casing)
if not full_vocab:
word2id, indexes = select_subset(words, params.max_vocab)
embeddings = embeddings[indexes]
else:
word2id = {w: i for i, w in enumerate(words)}
id2word = {i: w for w, i in word2id.items()}
dico = Dictionary(id2word, word2id, lang)

assert embeddings.size() == (len(dico), params.emb_dim)
return dico, embeddings


def load_embeddings(params, source, full_vocab=False):
"""
Reload pretrained embeddings.
- `full_vocab == False` means that we load the `params.max_vocab` most frequent words.
It is used at the beginning of the experiment.
In that setting, if two casings of the same word occur, we lowercase both and
only consider the most frequent one. For instance, if "London" and "london" are in
the embeddings file, we only consider the most frequent one (in that case, probably
London). This is done to deal with the lowercased dictionaries.
- `full_vocab == True` means that we load the entire embedding text file,
before we export the embeddings at the end of the experiment.
"""
assert type(source) is bool and type(full_vocab) is bool
emb_path = params.src_emb if source else params.tgt_emb
if emb_path.endswith('.pth'):
return load_pth_embeddings(params, source, full_vocab)
if emb_path.endswith('.bin'):
return load_bin_embeddings(params, source, full_vocab)
else:
return read_txt_embeddings(params, source, full_vocab)


def normalize_embeddings(emb, types, mean=None):
"""
Normalize embeddings by their norms / recenter them.
@@ -323,24 +426,32 @@ def normalize_embeddings(emb, types, mean=None):

def export_embeddings(src_emb, tgt_emb, params):
"""
Export embeddings to a text file.
"""
src_id2word = params.src_dico.id2word
tgt_id2word = params.tgt_dico.id2word
n_src = len(src_id2word)
n_tgt = len(tgt_id2word)
dim = src_emb.shape[1]
src_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.src_lang)
tgt_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.tgt_lang)
# source embeddings
logger.info('Writing source embeddings to %s ...' % src_path)
with io.open(src_path, 'w', encoding='utf-8') as f:
f.write(u"%i %i\n" % (n_src, dim))
for i in range(len(src_id2word)):
f.write(u"%s %s\n" % (src_id2word[i], " ".join('%.5f' % x for x in src_emb[i])))
# target embeddings
logger.info('Writing target embeddings to %s ...' % tgt_path)
with io.open(tgt_path, 'w', encoding='utf-8') as f:
f.write(u"%i %i\n" % (n_tgt, dim))
for i in range(len(tgt_id2word)):
f.write(u"%s %s\n" % (tgt_id2word[i], " ".join('%.5f' % x for x in tgt_emb[i])))
Export embeddings to a text or a PyTorch file.
"""
assert params.export in ["txt", "pth"]

# text file
if params.export == "txt":
src_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.src_lang)
tgt_path = os.path.join(params.exp_path, 'vectors-%s.txt' % params.tgt_lang)
# source embeddings
logger.info('Writing source embeddings to %s ...' % src_path)
with io.open(src_path, 'w', encoding='utf-8') as f:
f.write(u"%i %i\n" % src_emb.size())
for i in range(len(params.src_dico)):
f.write(u"%s %s\n" % (params.src_dico[i], " ".join('%.5f' % x for x in src_emb[i])))
# target embeddings
logger.info('Writing target embeddings to %s ...' % tgt_path)
with io.open(tgt_path, 'w', encoding='utf-8') as f:
f.write(u"%i %i\n" % tgt_emb.size())
for i in range(len(params.tgt_dico)):
f.write(u"%s %s\n" % (params.tgt_dico[i], " ".join('%.5f' % x for x in tgt_emb[i])))

# PyTorch file
if params.export == "pth":
src_path = os.path.join(params.exp_path, 'vectors-%s.pth' % params.src_lang)
tgt_path = os.path.join(params.exp_path, 'vectors-%s.pth' % params.tgt_lang)
logger.info('Writing source embeddings to %s ...' % src_path)
torch.save({'dico': params.src_dico, 'vectors': src_emb}, src_path)
logger.info('Writing target embeddings to %s ...' % tgt_path)
torch.save({'dico': params.tgt_dico, 'vectors': tgt_emb}, tgt_path)
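To make the casing behaviour described in the `load_embeddings` docstring concrete, here is a small sketch of what the new `select_subset` helper returns on a made-up word list (assuming `src.utils` is importable):

```python
from src.utils import select_subset  # assumed import path

words = ["the", "The", "London", "london", "paris"]
word2id, indexes = select_subset(words, max_vocab=3)
# word2id == {"the": 0, "london": 1, "paris": 2}
# indexes == LongTensor([0, 2, 4]): the first (most frequent) occurrence of each lowercased form
```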
