@@ -1,7 +1,8 @@
-from typing import Any, Dict, List, Text, Tuple, Optional
+from typing import Any, Dict, List, Text, Tuple, Optional, Union
 
+from rasa.nlu.tokenizers.tokenizer import Token
 from rasa.nlu.components import Component
-from rasa.nlu.constants import EXTRACTOR, ENTITIES
+from rasa.nlu.constants import EXTRACTOR, ENTITIES, TOKENS_NAMES, TEXT
 from rasa.nlu.training_data import Message
 
 
@@ -22,62 +23,60 @@ def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
         return entity
 
     def clean_up_entities(
-        self, entities: List[Dict[Text, Any]], keep: bool = True
+        self, message: Message, entities: List[Dict[Text, Any]], keep: bool = True
     ) -> List[Dict[Text, Any]]:
         """
-        Checks if multiple entity labels are assigned to one word.
+        Check if multiple entity labels are assigned to one word or if an entity label
+        is assigned to just a part of a word or if an entity label covers multiple
+        words, but one word just partly.
 
         This might happen if you are using a tokenizer that splits up words into
         sub-words and different entity labels are assigned to the individual sub-words.
-        In such a case keep the entity label with the highest confidence as entity
-        label for that word. If you set 'keep' to 'False', all entity labels for
-        that word will be removed.
+        If multiple entity labels are assigned to one word, we keep the entity label
+        with the highest confidence as entity label for that word. If just a part
+        of the word is annotated, that entity label is taken for the complete word.
+        If you set 'keep' to 'False', all entity labels for the word will be removed.
 
         Args:
+            message: message object
             entities: list of entities
             keep:
                 If set to 'True', the entity label with the highest confidence is kept
                 if multiple entity labels are assigned to one word. If set to 'False'
                 all entity labels for that word will be removed.
 
-        Returns: updated list of entities
+        Returns:
+            Updated entities.
         """
-        if len(entities) <= 1:
-            return entities
-
-        entity_indices: List[List[int]] = []
-
-        # get indices of entity labels that belong to one word
-        for idx in range(1, len(entities)):
-            if entities[idx]["start"] == entities[idx - 1]["end"]:
-                if entity_indices and entity_indices[-1][-1] == idx - 1:
-                    entity_indices[-1].append(idx)
-                else:
-                    entity_indices.append([idx - 1, idx])
+        misaligned_entities = self._get_misaligned_entities(
+            message.get(TOKENS_NAMES[TEXT]), entities
+        )
 
         entity_indices_to_remove = set()
 
-        for indices in entity_indices:
+        for misaligned_entity in misaligned_entities:
+            # entity indices involved in the misalignment
+            entity_indices = misaligned_entity["entity_indices"]
+
             if not keep:
-                entity_indices_to_remove.update(indices)
+                entity_indices_to_remove.update(entity_indices)
                 continue
 
-            # get start, end, and value of entity matching the complete word
-            start = entities[indices[0]]["start"]
-            end = entities[indices[-1]]["end"]
-            value = "".join(entities[idx]["value"] for idx in indices)
-            idx = self._get_highest_confidence_idx(entities, indices)
+            idx = self._entity_index_to_keep(entities, entity_indices)
 
             if idx is None:
-                entity_indices_to_remove.update(indices)
+                entity_indices_to_remove.update(entity_indices)
             else:
-                # We just want to keep the entity with the highest confidence value
-                indices.remove(idx)
-                entity_indices_to_remove.update(indices)
-                # update that entity to cover the complete word
-                entities[idx]["start"] = start
-                entities[idx]["end"] = end
-                entities[idx]["value"] = value
+                # keep just one entity
+                entity_indices.remove(idx)
+                entity_indices_to_remove.update(entity_indices)
+
+                # update that entity to cover the complete word(s)
+                entities[idx]["start"] = misaligned_entity["start"]
+                entities[idx]["end"] = misaligned_entity["end"]
+                entities[idx]["value"] = message.text[
+                    misaligned_entity["start"] : misaligned_entity["end"]
+                ]
 
         # sort indices to remove entries at the end of the list first
         # to avoid index out of range errors
@@ -86,24 +85,183 @@ def clean_up_entities(
 
         return entities
 
+    def _get_misaligned_entities(
+        self, tokens: List[Token], entities: List[Dict[Text, Any]]
+    ) -> List[Dict[Text, Any]]:
+        """Identify entities and tokens that are misaligned.
+
+        Misaligned entities are those that apply only to a part of a word, i.e.
+        sub-word.
+
+        Args:
+            tokens: list of tokens
+            entities: list of detected entities by the entity extractor
+
+        Returns:
+            Misaligned entities including the start and end position
+            of the final entity in the text and entity indices that are part of this
+            misalignment.
+        """
+        if not tokens:
+            return []
+
+        # group tokens: one token cluster corresponds to one word
+        token_clusters = self._token_clusters(tokens)
+
+        # added for tests, should only happen if tokens are not set or len(tokens) == 1
+        if not token_clusters:
+            return []
+
+        misaligned_entities = []
+        for entity_idx, entity in enumerate(entities):
+            # get all tokens that are covered/touched by the entity
+            entity_tokens = self._tokens_of_entity(entity, token_clusters)
+
+            if len(entity_tokens) == 1:
+                # entity covers exactly one word
+                continue
+
+            # get start and end position of complete word
+            # needed to update the final entity later
+            start_position = entity_tokens[0].start
+            end_position = entity_tokens[-1].end
+
+            # check if an entity was already found that covers the exact same word(s)
+            _idx = self._misaligned_entity_index(
+                misaligned_entities, start_position, end_position
+            )
+
+            if _idx is None:
+                misaligned_entities.append(
+                    {
+                        "start": start_position,
+                        "end": end_position,
+                        "entity_indices": [entity_idx],
+                    }
+                )
+            else:
+                misaligned_entities[_idx]["entity_indices"].append(entity_idx)
+
+        return misaligned_entities
+
     @staticmethod
-    def _get_highest_confidence_idx(
-        entities: List[Dict[Text, Any]], indices: List[int]
+    def _misaligned_entity_index(
+        word_entity_cluster: List[Dict[Text, Union[int, List[int]]]],
+        start_position: int,
+        end_position: int,
     ) -> Optional[int]:
+        """Get index of matching misaligned entity.
+
+        Args:
+            word_entity_cluster: word entity cluster
+            start_position: start position
+            end_position: end position
+
+        Returns:
+            Index of the misaligned entity that matches the provided start and end
+            position.
         """
+        for idx, cluster in enumerate(word_entity_cluster):
+            if cluster["start"] == start_position and cluster["end"] == end_position:
+                return idx
+        return None
+
+    @staticmethod
+    def _tokens_of_entity(
+        entity: Dict[Text, Any], token_clusters: List[List[Token]]
+    ) -> List[Token]:
+        """Get all tokens of token clusters that are covered by the entity.
+
+        The entity can cover them completely or just partly.
+
+        Args:
+            entity: the entity
+            token_clusters: list of token clusters
+
+        Returns:
+            Token clusters that belong to the provided entity.
+
+        """
+        entity_tokens = []
+        for token_cluster in token_clusters:
+            entity_starts_inside_cluster = (
+                token_cluster[0].start <= entity["start"] <= token_cluster[-1].end
+            )
+            entity_ends_inside_cluster = (
+                token_cluster[0].start <= entity["end"] <= token_cluster[-1].end
+            )
+
+            if entity_starts_inside_cluster or entity_ends_inside_cluster:
+                entity_tokens += token_cluster
+        return entity_tokens
+
+    @staticmethod
+    def _token_clusters(tokens: List[Token]) -> List[List[Token]]:
+        """Build clusters of tokens that belong to one word.
+
+        Args:
+            tokens: list of tokens
+
+        Returns:
+            Token clusters.
+
+        """
+        # token cluster = list of token indices that belong to one word
+        token_index_clusters = []
+
+        # start at 1 in order to check if current token and previous token belong
+        # to the same word
+        for token_idx in range(1, len(tokens)):
+            previous_token_idx = token_idx - 1
+            # two tokens belong to the same word if there is no other character
+            # between them
+            if tokens[token_idx].start == tokens[previous_token_idx].end:
+                # a word was split into multiple tokens
+                token_cluster_already_exists = (
+                    token_index_clusters
+                    and token_index_clusters[-1][-1] == previous_token_idx
+                )
+                if token_cluster_already_exists:
+                    token_index_clusters[-1].append(token_idx)
+                else:
+                    token_index_clusters.append([previous_token_idx, token_idx])
+            else:
+                # the token corresponds to a single word
+                if token_idx == 1:
+                    token_index_clusters.append([previous_token_idx])
+                token_index_clusters.append([token_idx])
+
+        return [[tokens[idx] for idx in cluster] for cluster in token_index_clusters]
+
+    @staticmethod
+    def _entity_index_to_keep(
+        entities: List[Dict[Text, Any]], entity_indices: List[int]
+    ) -> Optional[int]:
+        """
+        Determine the entity index to keep.
+
+        If we just have one entity index, i.e. candidate, we return the index of that
+        candidate. If we have multiple candidates, we return the index of the entity
+        value with the highest confidence score. If no confidence score is present,
+        no entity label will be kept.
+
         Args:
             entities: the full list of entities
-            indices: the indices to consider
+            entity_indices: the entity indices to consider
 
-        Returns: the idx of the entity label with the highest confidence.
+        Returns: the idx of the entity to keep
         """
+        if len(entity_indices) == 1:
+            return entity_indices[0]
+
         confidences = [
             entities[idx]["confidence"]
-            for idx in indices
+            for idx in entity_indices
             if "confidence" in entities[idx]
         ]
 
-        if len(confidences) != len(indices):
+        # we don't have confidence values for all entity labels
+        if len(confidences) != len(entity_indices):
             return None
 
         return confidences.index(max(confidences))
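
The core of this change is the word-grouping rule in `_token_clusters`: two tokens belong to the same word when the previous token's end offset equals the next token's start offset. The following minimal, self-contained sketch applies that same rule to plain `(start, end)` offsets instead of Rasa `Token` objects; the `cluster_offsets` helper and the sample offsets are illustrative only, not part of the change.

from typing import List, Tuple

Offset = Tuple[int, int]  # (start, end) character positions of a token


def cluster_offsets(offsets: List[Offset]) -> List[List[Offset]]:
    """Group token offsets into words.

    Tokens belong to the same word if there is no character between them,
    i.e. the previous token's end equals the current token's start.
    """
    clusters: List[List[Offset]] = []
    for idx in range(1, len(offsets)):
        prev, curr = offsets[idx - 1], offsets[idx]
        if curr[0] == prev[1]:
            # sub-word token: extend or start the cluster for this word
            if clusters and clusters[-1][-1] == prev:
                clusters[-1].append(curr)
            else:
                clusters.append([prev, curr])
        else:
            # a gap (e.g. whitespace) separates two words
            if idx == 1:
                clusters.append([prev])
            clusters.append([curr])
    # as in `_token_clusters`, zero or one token yields no cluster at all
    return clusters


# "Mark Zuckerberg" tokenized into the sub-words "Mark", "Zucker", "berg"
print(cluster_offsets([(0, 4), (5, 11), (11, 15)]))
# [[(0, 4)], [(5, 11), (11, 15)]] -> "Zucker" and "berg" form one word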
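For the label-selection step, `_entity_index_to_keep` keeps a lone candidate as is, otherwise prefers the candidate with the highest confidence, and keeps nothing if any candidate lacks a confidence value. Below is a standalone sketch of that rule; the free function `entity_index_to_keep` and the sample entities for the sub-words "Zucker" and "berg" are made up for illustration.

from typing import Any, Dict, List, Optional


def entity_index_to_keep(
    entities: List[Dict[str, Any]], entity_indices: List[int]
) -> Optional[int]:
    """Sketch of the selection rule from `_entity_index_to_keep`."""
    if len(entity_indices) == 1:
        return entity_indices[0]

    confidences = [
        entities[idx]["confidence"]
        for idx in entity_indices
        if "confidence" in entities[idx]
    ]

    # drop everything if not every candidate carries a confidence value
    if len(confidences) != len(entity_indices):
        return None

    # position of the best candidate; it coincides with the entity index in
    # the example below because the candidates considered are [0, 1]
    return confidences.index(max(confidences))


entities = [
    {"entity": "PERSON", "value": "Zucker", "confidence": 0.87, "start": 5, "end": 11},
    {"entity": "LOC", "value": "berg", "confidence": 0.35, "start": 11, "end": 15},
]
print(entity_index_to_keep(entities, [0, 1]))  # 0 -> the PERSON label is kept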