Skip to content

Commit

Permalink
added support for expanding tags to noun phrases
Browse files Browse the repository at this point in the history
  • Loading branch information
mattboggess committed May 9, 2020
1 parent 041a12c commit fb262a1
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 3 deletions.
25 changes: 24 additions & 1 deletion snorkel/preprocessing/data_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,27 @@ def match_uncommon_plurals(text_span, term_span):

return s_plural_match or es_plural_match

def tag_terms(text, terms, nlp=None, invalid_pos=[], invalid_dep=[]):
def expand_noun_phrase(text, start_ix, match_length):

np_dep = ['amod', 'npadvmod', 'compound', 'poss']
end_ix = start_ix + match_length - 1

start_ix -= 1
while start_ix >= 0 and text[start_ix].dep_ in np_dep and text[start_ix].pos_ not in ['DET', 'ADV', 'ADJ', 'PRON'] and not text[start_ix].is_stop:
start_ix -= 1
start_ix += 1

while end_ix < len(text) and text[end_ix].dep_ in np_dep:
end_ix += 1
if end_ix == len(text):
end_ix -= 1

match_length = end_ix - start_ix + 1
term_lemma = ' '.join([tok.lemma_ for tok in text[start_ix:start_ix + match_length]]).replace('-', ' ').strip()

return (term_lemma, start_ix, match_length)

def tag_terms(text, terms, nlp=None, invalid_pos=[], invalid_dep=[], expand_np=False):
""" Identifies and tags any terms in a given input text.
TODO:
Expand Down Expand Up @@ -180,6 +200,9 @@ def tag_terms(text, terms, nlp=None, invalid_pos=[], invalid_dep=[]):

# only tag term if not part of larger term
if valid_match and tags[ix:ix + match_length] == ['O'] * match_length:

if expand_np:
term_lemma, ix, match_length = expand_noun_phrase(text, ix, match_length)

# collect term information
term_text = ''.join([token.text_with_ws for token in text[ix:ix + match_length]]).strip()
Expand Down
62 changes: 60 additions & 2 deletions snorkel/preprocessing/test_data_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,70 @@
'text': "Metabolic pathways are regulated systems Figure 9.13 Relationships among the Major Metabolic Pathways of the Cell",
'found_terms' : {},
'bioes_tags': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
},
{
'test_id': 'expand_np1',
'terms': ['protein', 'molecule'],
'text': "The best understood negative regulatory molecules are retinoblastoma protein (Rb), p53, and p21.",
'found_terms': {
'retinoblastoma protein': {
'text': ['retinoblastoma protein'],
'tokens': [['retinoblastoma', 'protein']],
'pos': [['NN', 'NN']],
'dep': [['compound', 'attr']],
'indices': [(7, 9)]
},
'negative regulatory molecule': {
'text': ['negative regulatory molecules'],
'tokens': [['negative', 'regulatory', 'molecules']],
'pos': [['JJ', 'JJ', 'NNS']],
'dep': [['amod', 'amod', 'nsubj']],
'indices': [(3, 6)]
}
},
'bioes_tags': ['O', 'O', 'O', 'B', 'I', 'E', 'O', 'B', 'E', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
},
{
'test_id': 'expand_np2',
'terms': ['spindle'],
'text': "The spindle apparatus (also called the mitotic spindle or simply the spindle) is awesome.",
'found_terms': {
'spindle apparatus': {
'text': ['spindle apparatus'],
'tokens': [['spindle', 'apparatus']],
'pos': [['NN', 'NN']],
'dep': [['amod', 'nsubj']],
'indices': [(1, 3)]
},
'mitotic spindle': {
'text': ['mitotic spindle'],
'tokens': [['mitotic', 'spindle']],
'pos': [['JJ', 'NN']],
'dep': [['amod', 'oprd']],
'indices': [(7, 9)]
},
'spindle': {
'text': ['spindle'],
'tokens': [['spindle']],
'pos': [['NN']],
'dep': [['conj']],
'indices': [(12, 13)]
}
},
'bioes_tags': ['O', 'B', 'E', 'O', 'O', 'O', 'O', 'B', 'E', 'O', 'O', 'O', 'S', 'O', 'O', 'O', 'O']
}


]

class TestDataProcessingUtils(unittest.TestCase):

def setUp(self):
self.test_data = pd.DataFrame(tag_terms_test_data)

def _test_tag_terms_generic(self, test_id, invalid_dep=[], invalid_pos=[]):
def _test_tag_terms_generic(self, test_id, invalid_dep=[], invalid_pos=[], expand_np=False):
row = self.test_data.loc[self.test_data.test_id == test_id, :].squeeze()
output = tag_terms(row.text, row.terms, invalid_dep=invalid_dep, invalid_pos=invalid_pos)
output = tag_terms(row.text, row.terms, invalid_dep=invalid_dep, invalid_pos=invalid_pos, expand_np=expand_np)
self.assertEqual(output['tags'], row.bioes_tags)
self.assertDictEqual(output['found_terms'], row.found_terms)

Expand Down Expand Up @@ -227,6 +279,12 @@ def test_pos_filter(self):

def test_pos_filter2(self):
self._test_tag_terms_generic('pos_filter', invalid_dep=['poss'], invalid_pos=['VBZ'])

def test_expandnp1(self):
self._test_tag_terms_generic('expand_np1', expand_np=True)

def test_expandnp2(self):
self._test_tag_terms_generic('expand_np2', expand_np=True)


if __name__ == '__main__':
Expand Down

0 comments on commit fb262a1

Please sign in to comment.