Skip to content

Commit d55e868

Browse files
authored
Merge pull request RasaHQ#5614 from RasaHQ/filter-operation
Improved filtering for NLU training data examples
2 parents a095714 + c7464a6 commit d55e868

File tree

4 files changed

+74
-25
lines changed

4 files changed

+74
-25
lines changed

changelog/5614.misc.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Deprecate ``TrainingData.filter_by_intent`` in favor of the more general ``TrainingData.filter_training_examples`` function, which filters training
2+
examples using a filtering function.

rasa/nlu/selectors/response_selector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,9 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
272272
"""
273273

274274
if self.retrieval_intent:
275-
training_data = training_data.filter_by_intent(self.retrieval_intent)
275+
training_data = training_data.filter_training_examples(
276+
lambda ex: self.retrieval_intent == ex.get(INTENT)
277+
)
276278
else:
277279
# retrieval intent was left to its default value
278280
logger.info(

rasa/nlu/training_data/training_data.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from collections import Counter, OrderedDict
55
from copy import deepcopy
66
from os.path import relpath
7-
from typing import Any, Dict, List, Optional, Set, Text, Tuple
7+
from typing import Any, Dict, List, Optional, Set, Text, Tuple, Callable
88

99
import rasa.nlu.utils
1010
from rasa.utils.common import raise_warning, lazy_property
11-
from rasa.nlu.constants import RESPONSE, RESPONSE_KEY_ATTRIBUTE
11+
from rasa.nlu.constants import ENTITIES, INTENT, RESPONSE, RESPONSE_KEY_ATTRIBUTE
1212
from rasa.nlu.training_data.message import Message
1313
from rasa.nlu.training_data.util import check_duplicate_synonym
1414
from rasa.nlu.utils import list_to_str
@@ -75,21 +75,35 @@ def merge(self, *others: "TrainingData") -> "TrainingData":
7575
nlg_stories,
7676
)
7777

78-
def filter_by_intent(self, intent: Text):
79-
"""Filter training examples """
78+
def filter_training_examples(
79+
self, condition: Callable[[Message], bool]
80+
) -> "TrainingData":
81+
"""Filter training examples.
8082
81-
training_examples = []
82-
for ex in self.training_examples:
83-
if ex.get("intent") == intent:
84-
training_examples.append(ex)
83+
Args:
84+
condition: A function that will be applied to filter training examples.
85+
86+
Returns:
87+
TrainingData: A TrainingData with filtered training examples.
88+
"""
8589

8690
return TrainingData(
87-
training_examples,
91+
list(filter(condition, self.training_examples)),
8892
self.entity_synonyms,
8993
self.regex_features,
9094
self.lookup_tables,
9195
)
9296

97+
def filter_by_intent(self, intent: Text) -> "TrainingData":
98+
"""Filter training examples."""
99+
raise_warning(
100+
"The `filter_by_intent` function is deprecated. "
101+
"Please use `filter_training_examples` instead.",
102+
DeprecationWarning,
103+
stacklevel=2,
104+
)
105+
return self.filter_training_examples(lambda ex: intent == ex.get(INTENT))
106+
93107
def __hash__(self) -> int:
94108
from rasa.core import utils as core_utils
95109

@@ -105,49 +119,49 @@ def sanitize_examples(examples: List[Message]) -> List[Message]:
105119
Remove trailing whitespaces from intent and response annotations and drop duplicate examples."""
106120

107121
for ex in examples:
108-
if ex.get("intent"):
109-
ex.set("intent", ex.get("intent").strip())
122+
if ex.get(INTENT):
123+
ex.set(INTENT, ex.get(INTENT).strip())
110124

111-
if ex.get("response"):
112-
ex.set("response", ex.get("response").strip())
125+
if ex.get(RESPONSE):
126+
ex.set(RESPONSE, ex.get(RESPONSE).strip())
113127

114128
return list(OrderedDict.fromkeys(examples))
115129

116130
@lazy_property
117131
def intent_examples(self) -> List[Message]:
118-
return [ex for ex in self.training_examples if ex.get("intent")]
132+
return [ex for ex in self.training_examples if ex.get(INTENT)]
119133

120134
@lazy_property
121135
def response_examples(self) -> List[Message]:
122-
return [ex for ex in self.training_examples if ex.get("response")]
136+
return [ex for ex in self.training_examples if ex.get(RESPONSE)]
123137

124138
@lazy_property
125139
def entity_examples(self) -> List[Message]:
126-
return [ex for ex in self.training_examples if ex.get("entities")]
140+
return [ex for ex in self.training_examples if ex.get(ENTITIES)]
127141

128142
@lazy_property
129143
def intents(self) -> Set[Text]:
130144
"""Returns the set of intents in the training data."""
131-
return {ex.get("intent") for ex in self.training_examples} - {None}
145+
return {ex.get(INTENT) for ex in self.training_examples} - {None}
132146

133147
@lazy_property
134148
def responses(self) -> Set[Text]:
135149
"""Returns the set of responses in the training data."""
136-
return {ex.get("response") for ex in self.training_examples} - {None}
150+
return {ex.get(RESPONSE) for ex in self.training_examples} - {None}
137151

138152
@lazy_property
139153
def retrieval_intents(self) -> Set[Text]:
140154
"""Returns the total number of response types in the training data"""
141155
return {
142-
ex.get("intent")
156+
ex.get(INTENT)
143157
for ex in self.training_examples
144-
if ex.get("response") is not None
158+
if ex.get(RESPONSE) is not None
145159
}
146160

147161
@lazy_property
148162
def examples_per_intent(self) -> Dict[Text, int]:
149163
"""Calculates the number of examples per intent."""
150-
intents = [ex.get("intent") for ex in self.training_examples]
164+
intents = [ex.get(INTENT) for ex in self.training_examples]
151165
return dict(Counter(intents))
152166

153167
@lazy_property
@@ -299,7 +313,7 @@ def sorted_intent_examples(self) -> List[Message]:
299313
"""Sorts the intent examples by the name of the intent and then response"""
300314

301315
return sorted(
302-
self.intent_examples, key=lambda e: (e.get("intent"), e.get("response"))
316+
self.intent_examples, key=lambda e: (e.get(INTENT), e.get(RESPONSE))
303317
)
304318

305319
def validate(self) -> None:
@@ -393,7 +407,7 @@ def split_nlu_examples(
393407
) -> Tuple[list, list]:
394408
train, test = [], []
395409
for intent, count in self.examples_per_intent.items():
396-
ex = [e for e in self.intent_examples if e.data["intent"] == intent]
410+
ex = [e for e in self.intent_examples if e.data[INTENT] == intent]
397411
if random_seed is not None:
398412
random.Random(random_seed).shuffle(ex)
399413
else:

tests/nlu/training_data/test_training_data.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tempfile
55
from jsonschema import ValidationError
66

7-
from rasa.nlu.constants import TEXT
7+
from rasa.nlu.constants import TEXT, RESPONSE_KEY_ATTRIBUTE
88
from rasa.nlu import training_data
99
from rasa.nlu.convert import convert_training_data
1010
from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
@@ -174,6 +174,37 @@ def test_demo_data(files):
174174
]
175175

176176

177+
@pytest.mark.parametrize(
178+
"files",
179+
[
180+
[
181+
"data/examples/rasa/demo-rasa.json",
182+
"data/examples/rasa/demo-rasa-responses.md",
183+
],
184+
[
185+
"data/examples/rasa/demo-rasa.md",
186+
"data/examples/rasa/demo-rasa-responses.md",
187+
],
188+
],
189+
)
190+
def test_demo_data_filter_out_retrieval_intents(files):
191+
from rasa.importers.utils import training_data_from_paths
192+
193+
td = training_data_from_paths(files, language="en")
194+
assert len(td.training_examples) == 46
195+
196+
td1 = td.filter_training_examples(lambda ex: ex.get(RESPONSE_KEY_ATTRIBUTE) is None)
197+
assert len(td1.training_examples) == 42
198+
199+
td2 = td.filter_training_examples(
200+
lambda ex: ex.get(RESPONSE_KEY_ATTRIBUTE) is not None
201+
)
202+
assert len(td2.training_examples) == 4
203+
204+
# make sure filtering operation doesn't mutate the source training data
205+
assert len(td.training_examples) == 46
206+
207+
177208
@pytest.mark.parametrize(
178209
"filepaths",
179210
[["data/examples/rasa/demo-rasa.md", "data/examples/rasa/demo-rasa-responses.md"]],

0 commit comments

Comments
 (0)