Skip to content

Commit

Permalink
Merge pull request nltk#817 from dimazest/nltk.parse-test-fix
Browse files Browse the repository at this point in the history
nltk.parse tests are fixed for Python 2.6.
  • Loading branch information
stevenbird committed Dec 19, 2014
2 parents 5a61c30 + 09958ad commit 39d9f29
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 108 deletions.
156 changes: 79 additions & 77 deletions nltk/parse/depeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,116 +7,118 @@
# For license information, see LICENSE.TXT
import unicodedata


class DependencyEvaluator(object):
    """
    Measure attachment scores of dependency parses against gold parses.

    Tokens that consist solely of punctuation are ignored by the evaluation.

    :meth:`eval` returns a pair ``(head_score, head_and_label_score)``:
    the first value is the fraction of scored tokens whose head is correct
    (unlabelled attachment), the second is the fraction whose head *and*
    relation are both correct (labelled attachment).  The first value is
    therefore never smaller than the second.

    >>> from nltk.parse.dependencygraph import DependencyGraph
    >>> from nltk.parse.depeval import DependencyEvaluator

    >>> gold_sent = DependencyGraph(\"""
    ... Pierre NNP 2 NMOD
    ... Vinken NNP 8 SUB
    ... , , 2 P
    ... 61 CD 5 NMOD
    ... years NNS 6 AMOD
    ... old JJ 2 NMOD
    ... , , 2 P
    ... will MD 0 ROOT
    ... join VB 8 VC
    ... the DT 11 NMOD
    ... board NN 9 OBJ
    ... as IN 9 VMOD
    ... a DT 15 NMOD
    ... nonexecutive JJ 15 NMOD
    ... director NN 12 PMOD
    ... Nov. NNP 9 VMOD
    ... 29 CD 16 NMOD
    ... . . 9 VMOD
    ... \""")

    >>> parsed_sent = DependencyGraph(\"""
    ... Pierre NNP 8 NMOD
    ... Vinken NNP 1 SUB
    ... , , 3 P
    ... 61 CD 6 NMOD
    ... years NNS 6 AMOD
    ... old JJ 2 NMOD
    ... , , 3 AMOD
    ... will MD 0 ROOT
    ... join VB 8 VC
    ... the DT 11 AMOD
    ... board NN 9 OBJECT
    ... as IN 9 NMOD
    ... a DT 15 NMOD
    ... nonexecutive JJ 15 NMOD
    ... director NN 12 PMOD
    ... Nov. NNP 9 VMOD
    ... 29 CD 16 NMOD
    ... . . 9 VMOD
    ... \""")

    >>> de = DependencyEvaluator([parsed_sent],[gold_sent])
    >>> uas, las = de.eval()
    >>> abs(uas - 0.8) < 0.00001
    True
    >>> abs(las - 0.6) < 0.00001
    True
    """

    # Unicode general categories that count as punctuation.  Built once
    # instead of on every _remove_punct call; set([...]) rather than a set
    # literal keeps the module importable on Python 2.6.
    _PUNCT_CATEGORIES = frozenset(["Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"])

    def __init__(self, parsed_sents, gold_sents):
        """
        :param parsed_sents: the parses produced by the parser
        :type parsed_sents: list(DependencyGraph)
        :param gold_sents: the reference parses, aligned one-to-one with
            ``parsed_sents``
        :type gold_sents: list(DependencyGraph)
        """
        self._parsed_sents = parsed_sents
        self._gold_sents = gold_sents

    def _remove_punct(self, inStr):
        """
        Strip every punctuation character from a Unicode string.

        :param inStr: the input string
        :return: the input with all characters whose Unicode category is
            a punctuation category removed (empty if it was all punctuation)
        """
        return "".join(
            ch for ch in inStr
            if unicodedata.category(ch) not in self._PUNCT_CATEGORIES
        )

    def eval(self):
        """
        Score the parsed sentences against the gold sentences.

        Nodes whose word is ``None`` and words made up only of punctuation
        are not scored.

        :return: ``(head_score, head_and_label_score)`` -- the fraction of
            scored tokens with the correct head, and the fraction with both
            the correct head and the correct relation
        :rtype: tuple(float, float)
        :raises ValueError: if the number of sentences, the length of a
            sentence pair, or the words of a sentence pair do not match
        :raises ZeroDivisionError: if no token was scored at all
        """
        if len(self._parsed_sents) != len(self._gold_sents):
            raise ValueError(" Number of parsed sentence is different with number of gold sentence.")

        corr = 0    # tokens with the correct head
        corrL = 0   # tokens with the correct head and relation
        total = 0   # tokens that were scored

        for parsed_graph, gold_graph in zip(self._parsed_sents, self._gold_sents):
            parsed_sent = parsed_graph.nodelist
            gold_sent = gold_graph.nodelist
            if len(parsed_sent) != len(gold_sent):
                raise ValueError(" Sentence length is not matched. ")

            for parsed_node, gold_node in zip(parsed_sent, gold_sent):
                word = parsed_node["word"]
                if word is None:
                    continue
                if word != gold_node["word"]:
                    raise ValueError(" Sentence sequence is not matched. ")

                # by default, ignore if word is punctuation
                if self._remove_punct(word) == "":
                    continue

                total += 1
                if parsed_node["head"] == gold_node["head"]:
                    corr += 1
                    # A relation only counts when the head is right as well.
                    if parsed_node["rel"] == gold_node["rel"]:
                        corrL += 1

        # Force float division so the result is the same on Python 2 and 3.
        total = float(total)
        return (corr / total, corrL / total)

if __name__ == '__main__':
    import doctest
    # ELLIPSIS lets expected output such as ``0.8...`` match a printed float;
    # without it that form of doctest output can never match.
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS)

66 changes: 35 additions & 31 deletions nltk/parse/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

import unicodedata


class DependencyEvaluator(object):
    """
    Class for measuring labelled and unlabelled attachment score for
    dependency parsing. Note that the evaluation ignores punctuation.

    :meth:`eval` returns a pair whose first value is the fraction of scored
    tokens with the correct head (unlabelled attachment) and whose second
    value is the fraction with both the correct head and the correct
    relation (labelled attachment).

    >>> from nltk.parse import DependencyGraph, DependencyEvaluator

    >>> gold_sent = DependencyGraph(\"""
    ... Pierre NNP 2 NMOD
    ... Vinken NNP 8 SUB
    ... , , 2 P
    ... 61 CD 5 NMOD
    ... years NNS 6 AMOD
    ... old JJ 2 NMOD
    ... , , 2 P
    ... will MD 0 ROOT
    ... join VB 8 VC
    ... the DT 11 NMOD
    ... board NN 9 OBJ
    ... as IN 9 VMOD
    ... a DT 15 NMOD
    ... nonexecutive JJ 15 NMOD
    ... director NN 12 PMOD
    ... Nov. NNP 9 VMOD
    ... 29 CD 16 NMOD
    ... . . 9 VMOD
    ... \""")

    >>> parsed_sent = DependencyGraph(\"""
    ... Pierre NNP 8 NMOD
    ... Vinken NNP 1 SUB
    ... , , 3 P
    ... 61 CD 6 NMOD
    ... years NNS 6 AMOD
    ... old JJ 2 NMOD
    ... , , 3 AMOD
    ... will MD 0 ROOT
    ... join VB 8 VC
    ... the DT 11 AMOD
    ... board NN 9 OBJECT
    ... as IN 9 NMOD
    ... a DT 15 NMOD
    ... nonexecutive JJ 15 NMOD
    ... director NN 12 PMOD
    ... Nov. NNP 9 VMOD
    ... 29 CD 16 NMOD
    ... . . 9 VMOD
    ... \""")

    >>> de = DependencyEvaluator([parsed_sent],[gold_sent])
    >>> uas, las = de.eval()
    >>> abs(uas - 0.8) < 0.00001
    True
    >>> abs(las - 0.6) < 0.00001
    True
    """

    def __init__(self, parsed_sents, gold_sents):
        """
        :param parsed_sents: the list of parsed_sents as the output of parser
        :type parsed_sents: list(DependencyGraph)
        :param gold_sents: the list of reference parses, in the same order
            as ``parsed_sents``
        :type gold_sents: list(DependencyGraph)
        """
        self._parsed_sents = parsed_sents
        self._gold_sents = gold_sents

    def _remove_punct(self, inStr):
        """
        Function to remove punctuation from Unicode string.

        :param inStr: the input string
        :return: Unicode string after removing all punctuation
        """
        # set([...]) rather than a set literal: keeps Python 2.6 working.
        punc_cat = set(["Pc", "Pd", "Ps", "Pe", "Pi", "Pf", "Po"])
        return "".join(x for x in inStr if unicodedata.category(x) not in punc_cat)

    def eval(self):
        """
        Compare every parsed sentence with its gold counterpart.

        Nodes whose word is ``None`` and punctuation-only words are skipped.

        :return: the fraction of scored tokens with the correct head, and
            the fraction with the correct head and the correct relation
        :rtype: tuple(float, float)
        :raises ValueError: if the inputs differ in sentence count, in the
            length of a sentence, or in the words of a sentence
        :raises ZeroDivisionError: if there is no token to score
        """
        if len(self._parsed_sents) != len(self._gold_sents):
            raise ValueError(" Number of parsed sentence is different with number of gold sentence.")

        corr = 0    # correct head
        corrL = 0   # correct head and relation
        total = 0   # scored tokens

        for parsed_graph, gold_graph in zip(self._parsed_sents, self._gold_sents):
            parsed_sent = parsed_graph.nodelist
            gold_sent = gold_graph.nodelist
            if len(parsed_sent) != len(gold_sent):
                raise ValueError("Sentences must have equal length.")

            for parsed_node, gold_node in zip(parsed_sent, gold_sent):
                word = parsed_node["word"]
                if word is None:
                    continue
                if word != gold_node["word"]:
                    raise ValueError("Sentence sequence is not matched.")

                # Ignore if word is punctuation by default
                if self._remove_punct(word) == "":
                    continue

                total += 1
                if parsed_node["head"] == gold_node["head"]:
                    corr += 1
                    # The relation is only credited when the head matches too.
                    if parsed_node["rel"] == gold_node["rel"]:
                        corrL += 1

        # Divide by a float so the scores do not rely on a
        # ``from __future__ import division`` being in effect on Python 2.
        total = float(total)
        return (corr / total, corrL / total)


if __name__ == '__main__':
    import doctest
    # Run the module doctests once.  NORMALIZE_WHITESPACE tolerates spacing
    # differences and ELLIPSIS allows ``...`` in expected output.
    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS)

0 comments on commit 39d9f29

Please sign in to comment.