Commit 30422a2

handle lowercased words

Guillaume Lample committed Mar 22, 2018
1 parent 98f0c34
Showing 2 changed files with 15 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/trainer.py
@@ -223,15 +223,15 @@ def save_best(self, to_log, metric):
         logger.info('* Best value for "%s": %.5f' % (metric, to_log[metric]))
         # save the mapping
         W = self.mapping.weight.data.cpu().numpy()
-        path = os.path.join(self.params.exp_path, 'best_mapping.t7')
+        path = os.path.join(self.params.exp_path, 'best_mapping.pth')
         logger.info('* Saving the mapping to %s ...' % path)
         torch.save(W, path)
 
     def reload_best(self):
         """
         Reload the best mapping.
         """
-        path = os.path.join(self.params.exp_path, 'best_mapping.t7')
+        path = os.path.join(self.params.exp_path, 'best_mapping.pth')
         logger.info('* Reloading the best model from %s ...' % path)
         # reload the model
         assert os.path.isfile(path)
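The only change in this file is the extension of the saved mapping, from the Torch7-style '.t7' to the conventional PyTorch '.pth'. Below is a minimal sketch of the save/reload round trip these two hunks touch; the file name best_mapping.pth comes from the diff, while the experiment directory and the random matrix are purely illustrative stand-ins.

import os
import torch

# Hypothetical experiment directory; stands in for params.exp_path in the diff.
exp_path = '/tmp/muse_exp'
os.makedirs(exp_path, exist_ok=True)

# Stand-in for the mapping weights. The trainer itself saves a NumPy array
# (torch.save pickles arbitrary objects); a tensor keeps this sketch loadable
# with torch.load's default settings across PyTorch versions.
W = torch.randn(300, 300)

path = os.path.join(exp_path, 'best_mapping.pth')  # was 'best_mapping.t7' before this commit
torch.save(W, path)

# Reload and check the round trip, as reload_best does.
assert os.path.isfile(path)
W_reloaded = torch.load(path)
assert torch.equal(W, W_reloaded)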
14 changes: 13 additions & 1 deletion src/utils.py
@@ -248,6 +248,10 @@ def clip_parameters(model, clip):
 def load_external_embeddings(params, source, full_vocab=False):
     """
     Reload pretrained embeddings from a text file.
+    - `full_vocab == False` means that we load the `params.max_vocab` most frequent words.
+      It is used at the beginning of the experiment.
+    - `full_vocab == True` means that we load the entire embedding text file,
+      before we export the embeddings at the end of the experiment.
     """
     assert type(source) is bool
     word2id = {}
@@ -265,12 +269,20 @@ def load_external_embeddings(params, source, full_vocab=False):
                 assert _emb_dim_file == int(split[1])
             else:
                 word, vect = line.rstrip().split(' ', 1)
+                if not full_vocab:
+                    word = word.lower()
                 vect = np.fromstring(vect, sep=' ')
                 if np.linalg.norm(vect) == 0:  # avoid to have null embeddings
                     vect[0] = 0.01
                 if word in word2id:
-                    logger.warning('Word %s found twice in %s embedding file' % (word, 'source' if source else 'target'))
+                    if full_vocab:
+                        logger.warning("Word '%s' found twice in %s embedding file"
+                                       % (word, 'source' if source else 'target'))
                 else:
+                    if not vect.shape == (_emb_dim_file,):
+                        logger.warning("Invalid dimension (%i) for %s word '%s' in line %i."
+                                       % (vect.shape[0], 'source' if source else 'target', word, i))
+                        continue
                     assert vect.shape == (_emb_dim_file,), i
                     word2id[word] = len(word2id)
                     vectors.append(vect[None])
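The lowercasing only happens when full_vocab is False (the initial, frequency-capped load), so case variants of a word collide in word2id and only the first occurrence is kept, silently; duplicate warnings are reserved for full_vocab == True, where casing is preserved and a repeat really is a repeat. A minimal, self-contained sketch of that logic follows; it is not MUSE's API: toy_load and the in-memory lines are made up for illustration, and the vector parsing is simplified.

import numpy as np

def toy_load(lines, emb_dim, full_vocab=False):
    # Mimics the loading loop this hunk changes, on in-memory lines of the
    # form "word v1 v2 ... vd" instead of an embedding text file.
    word2id, vectors = {}, []
    for line in lines:
        word, vect = line.rstrip().split(' ', 1)
        if not full_vocab:
            word = word.lower()                     # new in this commit
        vect = np.array(vect.split(), dtype=float)  # np.fromstring(vect, sep=' ') in the original
        if word in word2id:
            if full_vocab:                          # only warn on genuine duplicates
                print("Word '%s' found twice" % word)
        else:
            if vect.shape != (emb_dim,):            # skip malformed lines instead of asserting
                continue
            word2id[word] = len(word2id)
            vectors.append(vect[None])
    return word2id, np.concatenate(vectors, 0)

lines = ["Apple 1.0 0.0", "apple 0.5 0.5", "banana 0.0 1.0"]
word2id, embeddings = toy_load(lines, emb_dim=2, full_vocab=False)
print(sorted(word2id))  # ['apple', 'banana'] -- 'Apple' and 'apple' were merged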
