
Commit e85f442

Merge branch 'master' into patch-release-1.9.3
2 parents: 805de56 + 46c00a3

File tree: 6 files changed, +225 additions, -10 deletions

README.md

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,8 @@
 Rasa is an open source machine learning framework to automate text- and voice-based conversations. With Rasa, you can build contextual assistants on:
 - Facebook Messenger
 - Slack
+- Google Hangouts
+- Webex Teams
 - Microsoft Bot Framework
 - Rocket.Chat
 - Mattermost

changelog/5475.bugfix.rst

Lines changed: 5 additions & 0 deletions (new file)

One word can have only one entity label.

If you are using, for example, the ``ConveRTTokenizer``, words can be split into multiple tokens.
Our entity extractors assign entity labels per token, so a word that was split into two tokens
could end up with two different entity labels. This is now fixed: one word gets only one entity
label at a time.
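As a quick illustration of the new behaviour, here is a minimal sketch using the ``clean_up_entities`` helper added in this commit; the entity dictionaries are hypothetical extractor output mirroring the new test cases further below.

```python
from rasa.nlu.extractors.extractor import EntityExtractor

# Hypothetical extractor output: "Aarhus" was split into the sub-word tokens
# "Aar" and "hus", and each token received its own entity label.
entities = [
    {"entity": "city", "start": 0, "end": 3, "confidence": 0.87, "value": "Aar"},
    {"entity": "iata", "start": 3, "end": 6, "confidence": 0.43, "value": "hus"},
]

extractor = EntityExtractor()

# With keep=True (the default), the label with the highest confidence wins
# and is stretched to cover the whole word.
print(extractor.clean_up_entities(entities))
# [{'entity': 'city', 'start': 0, 'end': 6, 'confidence': 0.87, 'value': 'Aarhus'}]
```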

rasa/nlu/classifiers/diet_classifier.py

Lines changed: 8 additions & 7 deletions
@@ -329,11 +329,11 @@ def _tag_id_index_mapping(self, training_data: TrainingData) -> Dict[Text, int]:
         if self.component_config[BILOU_FLAG]:
             return bilou_utils.build_tag_id_dict(training_data)
 
-        distinct_tag_ids = set(
+        distinct_tag_ids = {
             e["entity"]
             for example in training_data.entity_examples
             for e in example.get(ENTITIES)
-        ) - {None}
+        } - {None}
 
         tag_id_dict = {
             tag_id: idx for idx, tag_id in enumerate(sorted(distinct_tag_ids), 1)
@@ -662,7 +662,7 @@ def _predict(self, message: Message) -> Optional[Dict[Text, tf.Tensor]]:
                 "There is no trained model: component is either not trained or "
                 "didn't receive enough training data."
             )
-            return
+            return None
 
         # create session data from message and convert it into a batch of 1
         model_data = self._create_model_data([message])
@@ -744,9 +744,8 @@ def _predict_entities(
 
         return entities
 
-    @staticmethod
     def _convert_tags_to_entities(
-        text: Text, tokens: List[Token], tags: List[Text]
+        self, text: Text, tokens: List[Token], tags: List[Text]
     ) -> List[Dict[Text, Any]]:
         entities = []
         last_tag = NO_ENTITY_TAG
@@ -774,7 +773,7 @@ def _convert_tags_to_entities(
         for entity in entities:
            entity["value"] = text[entity["start"] : entity["end"]]
 
-        return entities
+        return self.clean_up_entities(entities)
 
     def process(self, message: Message, **kwargs: Any) -> None:
         """Return the most likely label and its similarity to the input."""
@@ -1191,7 +1190,7 @@ def _combine_sparse_dense_features(
 
     def _features_as_seq_ids(
         self, features: List[Union[np.ndarray, tf.Tensor, tf.SparseTensor]], name: Text
-    ) -> tf.Tensor:
+    ) -> Optional[tf.Tensor]:
         """Creates dense labels for negative sampling."""
 
         # if there are dense features - we can use them
@@ -1206,6 +1205,8 @@ def _features_as_seq_ids(
                     self._tf_layers[f"sparse_to_dense_ids.{name}"](f)
                 )
 
+        return None
+
     def _create_bow(
         self,
         features: List[Union[tf.Tensor, tf.SparseTensor]],

rasa/nlu/extractors/crf_entity_extractor.py

Lines changed: 3 additions & 2 deletions
@@ -165,8 +165,9 @@ def extract_entities(self, message: Message) -> List[Dict[Text, Any]]:
         if self.ent_tagger is not None:
             text_data = self._from_text_to_crf(message)
             features = self._sentence_to_features(text_data)
-            ents = self.ent_tagger.predict_marginals_single(features)
-            return self._from_crf_to_json(message, ents)
+            entities = self.ent_tagger.predict_marginals_single(features)
+            entities = self._from_crf_to_json(message, entities)
+            return self.clean_up_entities(entities)
         else:
             return []

rasa/nlu/extractors/extractor.py

Lines changed: 88 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Text, Tuple
+from typing import Any, Dict, List, Text, Tuple, Optional
 
 from rasa.nlu.components import Component
 from rasa.nlu.constants import EXTRACTOR, ENTITIES
@@ -21,6 +21,93 @@ def add_processor_name(self, entity: Dict[Text, Any]) -> Dict[Text, Any]:
 
         return entity
 
+    def clean_up_entities(
+        self, entities: List[Dict[Text, Any]], keep: bool = True
+    ) -> List[Dict[Text, Any]]:
+        """
+        Checks if multiple entity labels are assigned to one word.
+
+        This might happen if you are using a tokenizer that splits up words into
+        sub-words and different entity labels are assigned to the individual sub-words.
+        In such a case keep the entity label with the highest confidence as entity
+        label for that word. If you set 'keep' to 'False', all entity labels for
+        that word will be removed.
+
+        Args:
+            entities: list of entities
+            keep:
+                If set to 'True', the entity label with the highest confidence is kept
+                if multiple entity labels are assigned to one word. If set to 'False'
+                all entity labels for that word will be removed.
+
+        Returns: updated list of entities
+        """
+        if len(entities) <= 1:
+            return entities
+
+        entity_indices: List[List[int]] = []
+
+        # get indices of entity labels that belong to one word
+        for idx in range(1, len(entities)):
+            if entities[idx]["start"] == entities[idx - 1]["end"]:
+                if entity_indices and entity_indices[-1][-1] == idx - 1:
+                    entity_indices[-1].append(idx)
+                else:
+                    entity_indices.append([idx - 1, idx])
+
+        entity_indices_to_remove = set()
+
+        for indices in entity_indices:
+            if not keep:
+                entity_indices_to_remove.update(indices)
+                continue
+
+            # get start, end, and value of entity matching the complete word
+            start = entities[indices[0]]["start"]
+            end = entities[indices[-1]]["end"]
+            value = "".join(entities[idx]["value"] for idx in indices)
+            idx = self._get_highest_confidence_idx(entities, indices)
+
+            if idx is None:
+                entity_indices_to_remove.update(indices)
+            else:
+                # We just want to keep the entity with the highest confidence value
+                indices.remove(idx)
+                entity_indices_to_remove.update(indices)
+                # update that entity to cover the complete word
+                entities[idx]["start"] = start
+                entities[idx]["end"] = end
+                entities[idx]["value"] = value
+
+        # sort indices to remove entries at the end of the list first
+        # to avoid index out of range errors
+        for idx in sorted(entity_indices_to_remove, reverse=True):
+            entities.remove(entities[idx])
+
+        return entities
+
+    @staticmethod
+    def _get_highest_confidence_idx(
+        entities: List[Dict[Text, Any]], indices: List[int]
+    ) -> Optional[int]:
+        """
+        Args:
+            entities: the full list of entities
+            indices: the indices to consider
+
+        Returns: the idx of the entity label with the highest confidence.
+        """
+        confidences = [
+            entities[idx]["confidence"]
+            for idx in indices
+            if "confidence" in entities[idx]
+        ]
+
+        if len(confidences) != len(indices):
+            return None
+
+        return confidences.index(max(confidences))
+
     @staticmethod
     def filter_irrelevant_entities(extracted: list, requested_dimensions: set) -> list:
         """Only return dimensions the user configured"""

New test file (path not shown in this view)

Lines changed: 119 additions & 0 deletions (new file)

from typing import Any, Text, Dict, List

import pytest

from rasa.nlu.extractors.extractor import EntityExtractor


@pytest.mark.parametrize(
    "entities, keep, expected_entities",
    [
        (
            [
                {"entity": "iata", "start": 0, "end": 3, "value": "Aar"},
                {"entity": "city", "start": 3, "end": 6, "value": "hus"},
            ],
            False,
            [],
        ),
        (
            [
                {"entity": "iata", "start": 0, "end": 3, "value": "Aar"},
                {"entity": "city", "start": 3, "end": 6, "value": "hus"},
            ],
            True,
            [],
        ),
        (
            [
                {"entity": "city", "start": 0, "end": 3, "value": "Aarhus"},
                {"entity": "type", "start": 4, "end": 9, "value": "city"},
            ],
            False,
            [
                {"entity": "city", "start": 0, "end": 3, "value": "Aarhus"},
                {"entity": "type", "start": 4, "end": 9, "value": "city"},
            ],
        ),
        (
            [
                {
                    "entity": "city",
                    "start": 0,
                    "end": 3,
                    "confidence": 0.87,
                    "value": "Aar",
                },
                {
                    "entity": "iata",
                    "start": 3,
                    "end": 6,
                    "confidence": 0.43,
                    "value": "hus",
                },
            ],
            True,
            [
                {
                    "entity": "city",
                    "start": 0,
                    "end": 6,
                    "confidence": 0.87,
                    "value": "Aarhus",
                }
            ],
        ),
        (
            [
                {
                    "entity": "iata",
                    "start": 0,
                    "end": 2,
                    "confidence": 0.32,
                    "value": "Aa",
                },
                {
                    "entity": "city",
                    "start": 2,
                    "end": 3,
                    "confidence": 0.87,
                    "value": "r",
                },
                {
                    "entity": "iata",
                    "start": 3,
                    "end": 5,
                    "confidence": 0.21,
                    "value": "hu",
                },
                {
                    "entity": "city",
                    "start": 5,
                    "end": 6,
                    "confidence": 0.43,
                    "value": "s",
                },
            ],
            True,
            [
                {
                    "entity": "city",
                    "start": 0,
                    "end": 6,
                    "confidence": 0.87,
                    "value": "Aarhus",
                }
            ],
        ),
    ],
)
def test_convert_tags_to_entities(
    entities: List[Dict[Text, Any]],
    keep: bool,
    expected_entities: List[Dict[Text, Any]],
):
    extractor = EntityExtractor()

    updated_entities = extractor.clean_up_entities(entities, keep)

    assert updated_entities == expected_entities
