Skip to content

Commit

Permalink
Merge pull request piskvorky#607 from alantian/develop
Browse files Browse the repository at this point in the history
Control whether to use lowercase for computing word2vec accuracy.
  • Loading branch information
tmylk committed Apr 11, 2016
2 parents bc695ad + 31eadad commit 8c2a11c
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions gensim/models/word2vec.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,7 @@ def log_accuracy(section):
(section['section'], 100.0 * correct / (correct + incorrect),
correct, correct + incorrect))

def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar):
def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar, use_lowercase=True):
"""
Compute accuracy of the model. `questions` is a filename where lines are
4-tuples of words, split into sections by ": SECTION NAME" lines.
Expand All @@ -1426,6 +1426,9 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar):
Use `restrict_vocab` to ignore all questions containing a word whose frequency
is not in the top-N most frequent words (default top 30,000).
Use `use_lowercase` to convert all words in questions to their lowercase form before evaluating
the accuracy. This is useful when the text preprocessing also converted the vocabulary to lowercase (default True).
This method corresponds to the `compute-accuracy` script of the original C word2vec.
"""
Expand All @@ -1447,7 +1450,10 @@ def accuracy(self, questions, restrict_vocab=30000, most_similar=most_similar):
if not section:
raise ValueError("missing section header before line #%i in %s" % (line_no, questions))
try:
a, b, c, expected = [word.lower() for word in line.split()] # TODO assumes vocabulary preprocessing uses lowercase, too...
if use_lowercase:
a, b, c, expected = [word.lower() for word in line.split()] # assumes vocabulary preprocessing uses lowercase, too...
else:
a, b, c, expected = [word for word in line.split()]
except:
logger.info("skipping invalid line #%i in %s" % (line_no, questions))
if a not in ok_vocab or b not in ok_vocab or c not in ok_vocab or expected not in ok_vocab:
Expand Down

0 comments on commit 8c2a11c

Please sign in to comment.