1
- from typing import Any , Dict , List , Text , Tuple
1
+ from typing import Any , Dict , List , Text , Tuple , Optional
2
2
3
3
from rasa .nlu .components import Component
4
4
from rasa .nlu .constants import EXTRACTOR , ENTITIES
@@ -21,6 +21,93 @@ def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
21
21
22
22
return entity
23
23
24
+ def clean_up_entities (
25
+ self , entities : List [Dict [Text , Any ]], keep : bool = True
26
+ ) -> List [Dict [Text , Any ]]:
27
+ """
28
+ Checks if multiple entity labels are assigned to one word.
29
+
30
+ This might happen if you are using a tokenizer that splits up words into
31
+ sub-words and different entity labels are assigned to the individual sub-words.
32
+ In such a case keep the entity label with the highest confidence as entity
33
+ label for that word. If you set 'keep' to 'False', all entity labels for
34
+ that word will be removed.
35
+
36
+ Args:
37
+ entities: list of entities
38
+ keep:
39
+ If set to 'True', the entity label with the highest confidence is kept
40
+ if multiple entity labels are assigned to one word. If set to 'False'
41
+ all entity labels for that word will be removed.
42
+
43
+ Returns: updated list of entities
44
+ """
45
+ if len (entities ) <= 1 :
46
+ return entities
47
+
48
+ entity_indices : List [List [int ]] = []
49
+
50
+ # get indices of entity labels that belong to one word
51
+ for idx in range (1 , len (entities )):
52
+ if entities [idx ]["start" ] == entities [idx - 1 ]["end" ]:
53
+ if entity_indices and entity_indices [- 1 ][- 1 ] == idx - 1 :
54
+ entity_indices [- 1 ].append (idx )
55
+ else :
56
+ entity_indices .append ([idx - 1 , idx ])
57
+
58
+ entity_indices_to_remove = set ()
59
+
60
+ for indices in entity_indices :
61
+ if not keep :
62
+ entity_indices_to_remove .update (indices )
63
+ continue
64
+
65
+ # get start, end, and value of entity matching the complete word
66
+ start = entities [indices [0 ]]["start" ]
67
+ end = entities [indices [- 1 ]]["end" ]
68
+ value = "" .join (entities [idx ]["value" ] for idx in indices )
69
+ idx = self ._get_highest_confidence_idx (entities , indices )
70
+
71
+ if idx is None :
72
+ entity_indices_to_remove .update (indices )
73
+ else :
74
+ # We just want to keep the entity with the highest confidence value
75
+ indices .remove (idx )
76
+ entity_indices_to_remove .update (indices )
77
+ # update that entity to cover the complete word
78
+ entities [idx ]["start" ] = start
79
+ entities [idx ]["end" ] = end
80
+ entities [idx ]["value" ] = value
81
+
82
+ # sort indices to remove entries at the end of the list first
83
+ # to avoid index out of range errors
84
+ for idx in sorted (entity_indices_to_remove , reverse = True ):
85
+ entities .remove (entities [idx ])
86
+
87
+ return entities
88
+
89
+ @staticmethod
90
+ def _get_highest_confidence_idx (
91
+ entities : List [Dict [Text , Any ]], indices : List [int ]
92
+ ) -> Optional [int ]:
93
+ """
94
+ Args:
95
+ entities: the full list of entities
96
+ indices: the indices to consider
97
+
98
+ Returns: the idx of the entity label with the highest confidence.
99
+ """
100
+ confidences = [
101
+ entities [idx ]["confidence" ]
102
+ for idx in indices
103
+ if "confidence" in entities [idx ]
104
+ ]
105
+
106
+ if len (confidences ) != len (indices ):
107
+ return None
108
+
109
+ return confidences .index (max (confidences ))
110
+
24
111
@staticmethod
25
112
def filter_irrelevant_entities (extracted : list , requested_dimensions : set ) -> list :
26
113
"""Only return dimensions the user configured"""
0 commit comments