
Commit 328b49b

Merge pull request RasaHQ#5511 from RasaHQ/fix-entity-recognition-prediction
Entity applies to complete word, not just parts of it
2 parents f4eaee4 + 1dfe96c commit 328b49b

8 files changed, +357 −67 lines changed

changelog/5509.bugfix.rst

Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+An entity label should always cover a complete word.
+
+If you are using, for example, the ``ConveRTTokenizer``, words can be split into multiple tokens.
+Our entity extractors assign entity labels per token, so it could happen that only part of a word
+received an entity label. This is now fixed: an entity label always covers a complete word.
```
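
To make the fix concrete, here is a minimal before/after sketch. The token offsets and the `weather_item` label are invented for illustration; they are not actual `ConveRTTokenizer` output.

```python
# Minimal sketch of the bug this changelog entry describes.
# Offsets and the "weather_item" label are illustrative only.
text = "play forecast"

# A sub-word tokenizer may split "forecast" into two tokens:
tokens = [
    {"text": "play", "start": 0, "end": 4},
    {"text": "fore", "start": 5, "end": 9},
    {"text": "cast", "start": 9, "end": 13},
]

# Before the fix, an extractor could label just one sub-token:
entities_before = [{"entity": "weather_item", "start": 5, "end": 9, "value": "fore"}]

# After the fix, the label is expanded to cover the complete word:
entities_after = [{"entity": "weather_item", "start": 5, "end": 13, "value": "forecast"}]
```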

rasa/nlu/classifiers/diet_classifier.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -739,13 +739,15 @@ def _predict_entities(
             message.text, message.get(TOKENS_NAMES[TEXT], []), tags
         )
 
-        extracted = self.add_extractor_name(entities)
-        entities = message.get(ENTITIES, []) + extracted
+        entities = self.add_extractor_name(entities)
+        entities = self.clean_up_entities(message, entities)
+        entities = message.get(ENTITIES, []) + entities
 
         return entities
 
+    @staticmethod
     def _convert_tags_to_entities(
-        self, text: Text, tokens: List[Token], tags: List[Text]
+        text: Text, tokens: List[Token], tags: List[Text]
     ) -> List[Dict[Text, Any]]:
         entities = []
         last_tag = NO_ENTITY_TAG
@@ -773,7 +775,7 @@ def _convert_tags_to_entities(
         for entity in entities:
             entity["value"] = text[entity["start"] : entity["end"]]
 
-        return self.clean_up_entities(entities)
+        return entities
 
     def process(self, message: Message, **kwargs: Any) -> None:
         """Return the most likely label and its similarity to the input."""
```

rasa/nlu/extractors/crf_entity_extractor.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -156,8 +156,9 @@ def _create_dataset(self, examples: List[Message]) -> List[List[CRFToken]]:
         return dataset
 
     def process(self, message: Message, **kwargs: Any) -> None:
-        extracted = self.add_extractor_name(self.extract_entities(message))
-        message.set(ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True)
+        entities = self.add_extractor_name(self.extract_entities(message))
+        entities = self.clean_up_entities(message, entities)
+        message.set(ENTITIES, message.get(ENTITIES, []) + entities, add_to_output=True)
 
     def extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
         """Take a sentence and return entities in json format"""
@@ -166,8 +167,7 @@ def extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
             text_data = self._from_text_to_crf(message)
             features = self._sentence_to_features(text_data)
             entities = self.ent_tagger.predict_marginals_single(features)
-            entities = self._from_crf_to_json(message, entities)
-            return self.clean_up_entities(entities)
+            return self._from_crf_to_json(message, entities)
         else:
             return []
```

rasa/nlu/extractors/duckling_http_extractor.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -186,9 +186,8 @@ def process(self, message: Message, **kwargs: Any) -> None:
         )
 
         extracted = self.add_extractor_name(extracted)
-        message.set(
-            ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True,
-        )
+        extracted = self.clean_up_entities(message, extracted)
+        message.set(ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True)
 
     @classmethod
     def load(
```

rasa/nlu/extractors/extractor.py

Lines changed: 199 additions & 41 deletions
```diff
@@ -1,7 +1,8 @@
-from typing import Any, Dict, List, Text, Tuple, Optional
+from typing import Any, Dict, List, Text, Tuple, Optional, Union
 
+from rasa.nlu.tokenizers.tokenizer import Token
 from rasa.nlu.components import Component
-from rasa.nlu.constants import EXTRACTOR, ENTITIES
+from rasa.nlu.constants import EXTRACTOR, ENTITIES, TOKENS_NAMES, TEXT
 from rasa.nlu.training_data import Message
 
 
@@ -22,62 +23,60 @@ def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
         return entity
 
     def clean_up_entities(
-        self, entities: List[Dict[Text, Any]], keep: bool = True
+        self, message: Message, entities: List[Dict[Text, Any]], keep: bool = True
     ) -> List[Dict[Text, Any]]:
         """
-        Checks if multiple entity labels are assigned to one word.
+        Check if multiple entity labels are assigned to one word, if an entity label
+        is assigned to just a part of a word, or if an entity label covers multiple
+        words but one of them only partly.
 
         This might happen if you are using a tokenizer that splits up words into
         sub-words and different entity labels are assigned to the individual sub-words.
-        In such a case keep the entity label with the highest confidence as entity
-        label for that word. If you set 'keep' to 'False', all entity labels for
-        that word will be removed.
+        If multiple entity labels are assigned to one word, we keep the entity label
+        with the highest confidence as the entity label for that word. If just a part
+        of the word is annotated, that entity label is taken for the complete word.
+        If you set 'keep' to 'False', all entity labels for the word will be removed.
 
         Args:
+            message: message object
             entities: list of entities
             keep:
                 If set to 'True', the entity label with the highest confidence is kept
                 if multiple entity labels are assigned to one word. If set to 'False'
                 all entity labels for that word will be removed.
 
-        Returns: updated list of entities
+        Returns:
+            Updated entities.
         """
-        if len(entities) <= 1:
-            return entities
-
-        entity_indices: List[List[int]] = []
-
-        # get indices of entity labels that belong to one word
-        for idx in range(1, len(entities)):
-            if entities[idx]["start"] == entities[idx - 1]["end"]:
-                if entity_indices and entity_indices[-1][-1] == idx - 1:
-                    entity_indices[-1].append(idx)
-                else:
-                    entity_indices.append([idx - 1, idx])
+        misaligned_entities = self._get_misaligned_entities(
+            message.get(TOKENS_NAMES[TEXT]), entities
+        )
 
         entity_indices_to_remove = set()
 
-        for indices in entity_indices:
+        for misaligned_entity in misaligned_entities:
+            # entity indices involved in the misalignment
+            entity_indices = misaligned_entity["entity_indices"]
+
             if not keep:
-                entity_indices_to_remove.update(indices)
+                entity_indices_to_remove.update(entity_indices)
                 continue
 
-            # get start, end, and value of entity matching the complete word
-            start = entities[indices[0]]["start"]
-            end = entities[indices[-1]]["end"]
-            value = "".join(entities[idx]["value"] for idx in indices)
-            idx = self._get_highest_confidence_idx(entities, indices)
+            idx = self._entity_index_to_keep(entities, entity_indices)
 
             if idx is None:
-                entity_indices_to_remove.update(indices)
+                entity_indices_to_remove.update(entity_indices)
             else:
-                # We just want to keep the entity with the highest confidence value
-                indices.remove(idx)
-                entity_indices_to_remove.update(indices)
-                # update that entity to cover the complete word
-                entities[idx]["start"] = start
-                entities[idx]["end"] = end
-                entities[idx]["value"] = value
+                # keep just one entity
+                entity_indices.remove(idx)
+                entity_indices_to_remove.update(entity_indices)
+
+                # update that entity to cover the complete word(s)
+                entities[idx]["start"] = misaligned_entity["start"]
+                entities[idx]["end"] = misaligned_entity["end"]
+                entities[idx]["value"] = message.text[
+                    misaligned_entity["start"] : misaligned_entity["end"]
+                ]
 
         # sort indices to remove entries at the end of the list first
         # to avoid index out of range errors
@@ -86,24 +85,183 @@ def clean_up_entities(
 
         return entities
 
+    def _get_misaligned_entities(
+        self, tokens: List[Token], entities: List[Dict[Text, Any]]
+    ) -> List[Dict[Text, Any]]:
+        """Identify entities and tokens that are misaligned.
+
+        Misaligned entities are those that apply only to a part of a word, i.e.
+        a sub-word.
+
+        Args:
+            tokens: list of tokens
+            entities: list of entities detected by the entity extractor
+
+        Returns:
+            Misaligned entities, including the start and end position
+            of the final entity in the text and the entity indices that are part of
+            this misalignment.
+        """
+        if not tokens:
+            return []
+
+        # group tokens: one token cluster corresponds to one word
+        token_clusters = self._token_clusters(tokens)
+
+        # added for tests, should only happen if tokens are not set or len(tokens) == 1
+        if not token_clusters:
+            return []
+
+        misaligned_entities = []
+        for entity_idx, entity in enumerate(entities):
+            # get all tokens that are covered/touched by the entity
+            entity_tokens = self._tokens_of_entity(entity, token_clusters)
+
+            if len(entity_tokens) == 1:
+                # entity covers exactly one word
+                continue
+
+            # get start and end position of the complete word
+            # needed to update the final entity later
+            start_position = entity_tokens[0].start
+            end_position = entity_tokens[-1].end
+
+            # check if an entity was already found that covers the exact same word(s)
+            _idx = self._misaligned_entity_index(
+                misaligned_entities, start_position, end_position
+            )
+
+            if _idx is None:
+                misaligned_entities.append(
+                    {
+                        "start": start_position,
+                        "end": end_position,
+                        "entity_indices": [entity_idx],
+                    }
+                )
+            else:
+                misaligned_entities[_idx]["entity_indices"].append(entity_idx)
+
+        return misaligned_entities
+
     @staticmethod
-    def _get_highest_confidence_idx(
-        entities: List[Dict[Text, Any]], indices: List[int]
+    def _misaligned_entity_index(
+        word_entity_cluster: List[Dict[Text, Union[int, List[int]]]],
+        start_position: int,
+        end_position: int,
     ) -> Optional[int]:
+        """Get the index of the matching misaligned entity.
+
+        Args:
+            word_entity_cluster: word entity cluster
+            start_position: start position
+            end_position: end position
+
+        Returns:
+            Index of the misaligned entity that matches the provided start and end
+            position.
+        """
+        for idx, cluster in enumerate(word_entity_cluster):
+            if cluster["start"] == start_position and cluster["end"] == end_position:
+                return idx
+        return None
+
+    @staticmethod
+    def _tokens_of_entity(
+        entity: Dict[Text, Any], token_clusters: List[List[Token]]
+    ) -> List[Token]:
+        """Get all tokens of the token clusters that are covered by the entity.
+
+        The entity can cover them completely or just partly.
+
+        Args:
+            entity: the entity
+            token_clusters: list of token clusters
+
+        Returns:
+            Tokens of the token clusters that belong to the provided entity.
+        """
+        entity_tokens = []
+        for token_cluster in token_clusters:
+            entity_starts_inside_cluster = (
+                token_cluster[0].start <= entity["start"] <= token_cluster[-1].end
+            )
+            entity_ends_inside_cluster = (
+                token_cluster[0].start <= entity["end"] <= token_cluster[-1].end
+            )
+
+            if entity_starts_inside_cluster or entity_ends_inside_cluster:
+                entity_tokens += token_cluster
+        return entity_tokens
+
+    @staticmethod
+    def _token_clusters(tokens: List[Token]) -> List[List[Token]]:
+        """Build clusters of tokens that belong to one word.
+
+        Args:
+            tokens: list of tokens
+
+        Returns:
+            Token clusters.
+        """
+        # token cluster = list of token indices that belong to one word
+        token_index_clusters = []
+
+        # start at 1 in order to check if the current token and the previous token
+        # belong to the same word
+        for token_idx in range(1, len(tokens)):
+            previous_token_idx = token_idx - 1
+            # two tokens belong to the same word if there is no other character
+            # between them
+            if tokens[token_idx].start == tokens[previous_token_idx].end:
+                # a word was split into multiple tokens
+                token_cluster_already_exists = (
+                    token_index_clusters
+                    and token_index_clusters[-1][-1] == previous_token_idx
+                )
+                if token_cluster_already_exists:
+                    token_index_clusters[-1].append(token_idx)
+                else:
+                    token_index_clusters.append([previous_token_idx, token_idx])
+            else:
+                # the token corresponds to a single word
+                if token_idx == 1:
+                    token_index_clusters.append([previous_token_idx])
+                token_index_clusters.append([token_idx])
+
+        return [[tokens[idx] for idx in cluster] for cluster in token_index_clusters]
+
+    @staticmethod
+    def _entity_index_to_keep(
+        entities: List[Dict[Text, Any]], entity_indices: List[int]
+    ) -> Optional[int]:
+        """
+        Determine the entity index to keep.
+
+        If we just have one entity index, i.e. one candidate, we return the index of
+        that candidate. If we have multiple candidates, we return the index of the
+        entity value with the highest confidence score. If no confidence score is
+        present, no entity label will be kept.
+
         Args:
             entities: the full list of entities
-            indices: the indices to consider
+            entity_indices: the entity indices to consider
 
-        Returns: the idx of the entity label with the highest confidence.
+        Returns: the index of the entity to keep
         """
+        if len(entity_indices) == 1:
+            return entity_indices[0]
+
         confidences = [
             entities[idx]["confidence"]
-            for idx in indices
+            for idx in entity_indices
             if "confidence" in entities[idx]
         ]
 
-        if len(confidences) != len(indices):
+        # we don't have confidence values for all entity labels
+        if len(confidences) != len(entity_indices):
             return None
 
         return confidences.index(max(confidences))
```
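
An end-to-end illustration of what the new `clean_up_entities` is meant to produce. The tokens, labels, and confidences are invented; `(start, end)` pairs stand in for `Token` objects:

```python
# "play forecast", with "forecast" split into the sub-tokens "fore" + "cast";
# (start, end) pairs stand in for Token objects.
tokens = [(0, 4), (5, 9), (9, 13)]
text = "play forecast"

# Two competing sub-word labels on the same word:
entities = [
    {"entity": "item", "start": 5, "end": 9, "value": "fore", "confidence": 0.4},
    {"entity": "time", "start": 9, "end": 13, "value": "cast", "confidence": 0.9},
]

# Both labels touch the token cluster spanning (5, 13), so they form one
# misaligned entity. With keep=True, only the higher-confidence label
# survives, expanded to cover the complete word:
expected = [
    {"entity": "time", "start": 5, "end": 13, "value": "forecast", "confidence": 0.9}
]

# With keep=False, both labels would be removed instead.
```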

rasa/nlu/extractors/mitie_entity_extractor.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -142,9 +142,8 @@ def process(self, message: Message, **kwargs: Any) -> None:
             message.text, self._tokens_without_cls(message), mitie_feature_extractor
         )
         extracted = self.add_extractor_name(ents)
-        message.set(
-            ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True,
-        )
+        extracted = self.clean_up_entities(message, extracted)
+        message.set(ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True)
 
     @classmethod
     def load(
```

rasa/nlu/extractors/spacy_entity_extractor.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -32,13 +32,12 @@ def process(self, message: Message, **kwargs: Any) -> None:
         spacy_nlp = kwargs.get("spacy_nlp", None)
         doc = spacy_nlp(message.text)
         all_extracted = self.add_extractor_name(self.extract_entities(doc))
+        all_extracted = self.clean_up_entities(message, all_extracted)
         dimensions = self.component_config["dimensions"]
         extracted = SpacyEntityExtractor.filter_irrelevant_entities(
             all_extracted, dimensions
         )
-        message.set(
-            ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True,
-        )
+        message.set(ENTITIES, message.get(ENTITIES, []) + extracted, add_to_output=True)
 
     @staticmethod
     def extract_entities(doc: "Doc") -> List[Dict[Text, Any]]:
```
