
Commit a7616d4

kearnsw, tmbo, and rasabot authored

support for additional training metadata (RasaHQ#5743)

* support for additional training metadata
* added test for additional training data attributes

Co-authored-by: Tom Bocklisch <[email protected]>
Co-authored-by: Roberto <[email protected]>
1 parent a2eaf44 commit a7616d4

File tree

6 files changed: 87 additions & 68 deletions


changelog/5743.enhancement.rst

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+Support for additional training metadata.
+
+Training data messages now support kwargs, and the Rasa JSON data reader
+includes all fields when instantiating a training data instance.
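In practice the change means any extra key on a training example survives loading. A minimal sketch (assuming RasaReader.reads accepts a raw JSON string, as its JSON reader base class does in this version; the "sentiment" field is a hypothetical extra attribute):

from rasa.nlu.training_data.formats.rasa import RasaReader

# Extra keys beyond "text", "intent" and "entities" used to be dropped;
# after this change they are kept on the resulting Message.
td = RasaReader().reads(
    """
    {
      "rasa_nlu_data": {
        "common_examples": [
          {"text": "I'm happy.", "intent": "happy", "sentiment": 0.8}
        ]
      }
    }
    """
)
assert td.training_examples[0].get("sentiment") == 0.8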

rasa/nlu/training_data/formats/rasa.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def read_from_json(self, js: Dict[Text, Any], **_) -> "TrainingData":
         all_examples = common_examples + intent_examples + entity_examples
         training_examples = []
         for ex in all_examples:
-            msg = Message.build(ex["text"], ex.get("intent"), ex.get("entities"))
+            msg = Message.build(**ex)
             training_examples.append(msg)

         return TrainingData(
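The one-line change above swaps positional extraction for dict unpacking, so every field of the JSON example reaches Message.build. A hedged before/after illustration (the example dict is hypothetical):

from rasa.nlu.training_data.message import Message

ex = {"text": "I'm happy.", "intent": "happy", "sentiment": 0.8}

# Before: only text, intent and entities were forwarded; "sentiment" was lost.
# msg = Message.build(ex["text"], ex.get("intent"), ex.get("entities"))

# After: all keys are forwarded; unknown ones land in Message.data via **kwargs.
msg = Message.build(**ex)
assert msg.get("sentiment") == 0.8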

rasa/nlu/training_data/message.py

Lines changed: 4 additions & 3 deletions

@@ -13,11 +13,12 @@

 class Message:
     def __init__(
-        self, text: Text, data=None, output_properties=None, time=None
+        self, text: Text, data=None, output_properties=None, time=None, **kwargs
     ) -> None:
         self.text = text
         self.time = time
         self.data = data if data else {}
+        self.data.update(**kwargs)

         if output_properties:
             self.output_properties = output_properties

@@ -72,7 +73,7 @@ def __hash__(self) -> int:
         return hash((self.text, str(ordered(self.data))))

     @classmethod
-    def build(cls, text, intent=None, entities=None) -> "Message":
+    def build(cls, text, intent=None, entities=None, **kwargs) -> "Message":
         data = {}
         if intent:
             split_intent, response_key = cls.separate_intent_response_key(intent)

@@ -81,7 +82,7 @@ def build(cls, text, intent=None, entities=None) -> "Message":
             data[RESPONSE_KEY_ATTRIBUTE] = response_key
         if entities:
             data[ENTITIES] = entities
-        return cls(text, data)
+        return cls(text, data, **kwargs)

     def get_combined_intent_response_key(self) -> Text:
         """Get intent as it appears in training data"""

tests/core/test_tracker_stores.py

Lines changed: 4 additions & 6 deletions

@@ -343,7 +343,7 @@ def test_get_db_url_with_query():
     )


-def test_db_url_with_query_from_endpoint_config():
+def test_db_url_with_query_from_endpoint_config(tmp_path):
     endpoint_config = """
     tracker_store:
         dialect: postgresql

@@ -356,11 +356,9 @@ def test_db_url_with_query_from_endpoint_config():
             driver: my-driver
             another: query
     """
-
-    with tempfile.NamedTemporaryFile("w+", suffix="_tmp_config_file.yml") as f:
-        f.write(endpoint_config)
-        f.flush()
-        store_config = read_endpoint_config(f.name, "tracker_store")
+    f = tmp_path / "tmp_config_file.yml"
+    f.write_text(endpoint_config)
+    store_config = read_endpoint_config(str(f), "tracker_store")

     url = SQLTrackerStore.get_db_url(**store_config.kwargs)
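This diff and the test changes below share one refactor: pytest's built-in tmp_path fixture (a pathlib.Path pointing at a per-test temporary directory that pytest cleans up) replaces manual tempfile.NamedTemporaryFile handling. The pattern in isolation (test name and file contents are illustrative):

def test_reads_yaml_config(tmp_path):
    f = tmp_path / "config.yml"  # pathlib.Path inside a fresh temp dir
    f.write_text("tracker_store:\n  dialect: sqlite\n")
    # no flush/close bookkeeping needed; pytest removes the directory afterwards
    assert f.read_text().startswith("tracker_store")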

tests/nlu/test_config.py

Lines changed: 5 additions & 6 deletions

@@ -24,15 +24,14 @@ def test_blank_config(blank_config):
     assert final_config.as_dict() == blank_config.as_dict()


-def test_invalid_config_json():
+def test_invalid_config_json(tmp_path):
     file_config = """pipeline: [pretrained_embeddings_spacy""" # invalid yaml

-    with tempfile.NamedTemporaryFile("w+", suffix="_tmp_config_file.json") as f:
-        f.write(file_config)
-        f.flush()
+    f = tmp_path / "tmp_config_file.json"
+    f.write_text(file_config)

-        with pytest.raises(config.InvalidConfigError):
-            config.load(f.name)
+    with pytest.raises(config.InvalidConfigError):
+        config.load(str(f))


 def test_invalid_pipeline_template():

tests/nlu/training_data/test_training_data.py

Lines changed: 69 additions & 52 deletions

@@ -259,7 +259,7 @@ def test_markdown_single_sections():
     assert td_syn_only.entity_synonyms == {"Chines": "chinese", "Chinese": "chinese"}


-def test_repeated_entities():
+def test_repeated_entities(tmp_path):
     data = """
     {
       "rasa_nlu_data": {

@@ -279,21 +279,20 @@ def test_repeated_entities():
         ]
       }
     }"""
-    with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
-        f.write(data.encode(io_utils.DEFAULT_ENCODING))
-        f.flush()
-        td = training_data.load_data(f.name)
-        assert len(td.entity_examples) == 1
-        example = td.entity_examples[0]
-        entities = example.get("entities")
-        assert len(entities) == 1
-        tokens = WhitespaceTokenizer().tokenize(example, attribute=TEXT)
-        start, end = MitieEntityExtractor.find_entity(entities[0], example.text, tokens)
-        assert start == 9
-        assert end == 10
-
-
-def test_multiword_entities():
+    f = tmp_path / "tmp_training_data.json"
+    f.write_text(data, io_utils.DEFAULT_ENCODING)
+    td = training_data.load_data(str(f))
+    assert len(td.entity_examples) == 1
+    example = td.entity_examples[0]
+    entities = example.get("entities")
+    assert len(entities) == 1
+    tokens = WhitespaceTokenizer().tokenize(example, attribute=TEXT)
+    start, end = MitieEntityExtractor.find_entity(entities[0], example.text, tokens)
+    assert start == 9
+    assert end == 10
+
+
+def test_multiword_entities(tmp_path):
     data = """
     {
       "rasa_nlu_data": {

@@ -313,21 +312,20 @@ def test_multiword_entities():
         ]
       }
     }"""
-    with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
-        f.write(data.encode(io_utils.DEFAULT_ENCODING))
-        f.flush()
-        td = training_data.load_data(f.name)
-        assert len(td.entity_examples) == 1
-        example = td.entity_examples[0]
-        entities = example.get("entities")
-        assert len(entities) == 1
-        tokens = WhitespaceTokenizer().tokenize(example, attribute=TEXT)
-        start, end = MitieEntityExtractor.find_entity(entities[0], example.text, tokens)
-        assert start == 4
-        assert end == 7
-
-
-def test_nonascii_entities():
+    f = tmp_path / "tmp_training_data.json"
+    f.write_text(data, io_utils.DEFAULT_ENCODING)
+    td = training_data.load_data(str(f))
+    assert len(td.entity_examples) == 1
+    example = td.entity_examples[0]
+    entities = example.get("entities")
+    assert len(entities) == 1
+    tokens = WhitespaceTokenizer().tokenize(example, attribute=TEXT)
+    start, end = MitieEntityExtractor.find_entity(entities[0], example.text, tokens)
+    assert start == 4
+    assert end == 7
+
+
+def test_nonascii_entities(tmp_path):
     data = """
     {
       "luis_schema_version": "5.0",

@@ -345,22 +343,21 @@ def test_nonascii_entities():
         }
       ]
     }"""
-    with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
-        f.write(data.encode(io_utils.DEFAULT_ENCODING))
-        f.flush()
-        td = training_data.load_data(f.name)
-        assert len(td.entity_examples) == 1
-        example = td.entity_examples[0]
-        entities = example.get("entities")
-        assert len(entities) == 1
-        entity = entities[0]
-        assert entity["value"] == "ßäæ ?€ö)"
-        assert entity["start"] == 19
-        assert entity["end"] == 27
-        assert entity["entity"] == "description"
-
-
-def test_entities_synonyms():
+    f = tmp_path / "tmp_training_data.json"
+    f.write_text(data, io_utils.DEFAULT_ENCODING)
+    td = training_data.load_data(str(f))
+    assert len(td.entity_examples) == 1
+    example = td.entity_examples[0]
+    entities = example.get("entities")
+    assert len(entities) == 1
+    entity = entities[0]
+    assert entity["value"] == "ßäæ ?€ö)"
+    assert entity["start"] == 19
+    assert entity["end"] == 27
+    assert entity["entity"] == "description"
+
+
+def test_entities_synonyms(tmp_path):
     data = """
     {
       "rasa_nlu_data": {

@@ -398,11 +395,10 @@ def test_entities_synonyms():
         ]
       }
     }"""
-    with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
-        f.write(data.encode(io_utils.DEFAULT_ENCODING))
-        f.flush()
-        td = training_data.load_data(f.name)
-        assert td.entity_synonyms["New York City"] == "nyc"
+    f = tmp_path / "tmp_training_data.json"
+    f.write_text(data, io_utils.DEFAULT_ENCODING)
+    td = training_data.load_data(str(f))
+    assert td.entity_synonyms["New York City"] == "nyc"


 def cmp_message_list(firsts, seconds):

@@ -531,3 +527,24 @@ def test_load_data_from_non_existing_file():

 def test_is_empty():
     assert TrainingData().is_empty()
+
+
+def test_custom_attributes(tmp_path):
+    data = """
+    {
+      "rasa_nlu_data": {
+        "common_examples" : [
+          {
+            "intent": "happy",
+            "text": "I'm happy.",
+            "sentiment": 0.8
+          }
+        ]
+      }
+    }"""
+    f = tmp_path / "tmp_training_data.json"
+    f.write_text(data, io_utils.DEFAULT_ENCODING)
+    td = training_data.load_data(str(f))
+    assert len(td.training_examples) == 1
+    example = td.training_examples[0]
+    assert example.get("sentiment") == 0.8
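The new test_custom_attributes pins down the end-to-end behaviour: an unknown field in the training JSON becomes readable on the loaded example. A hypothetical downstream consumer, sketched with a simplified interface (SentimentTagger and its process() signature are illustrative, not Rasa's full component API):

class SentimentTagger:
    def process(self, message, **kwargs):
        score = message.get("sentiment")  # custom attribute from the training JSON
        if score is not None:
            # Message.set stores a value under the given key in message.data
            message.set("sentiment_label", "pos" if score >= 0.5 else "neg")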
