Skip to content

Commit ef3ae7c

Browse files
authored
Merge pull request RasaHQ#5595 from RasaHQ/train-retrieval-intents
Update POST /model/train endpoint to accept responses
2 parents 7d31c62 + 1d3343a commit ef3ae7c

File tree

8 files changed

+197
-0
lines changed

8 files changed

+197
-0
lines changed

changelog/5595.improvement.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Update ``POST /model/train`` endpoint to accept retrieval action responses
2+
at the ``responses`` key of the JSON payload.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
intents:
2+
- greet
3+
- goodbye
4+
- affirm
5+
- deny
6+
- mood_great
7+
- mood_unhappy
8+
- bot_challenge
9+
- chitchat
10+
- chitchat/ask_name
11+
- chitchat/ask_weather
12+
13+
responses:
14+
utter_greet:
15+
- text: Hey! How are you?
16+
utter_cheer_up:
17+
- text: 'Here is something to cheer you up:'
18+
image: https://i.imgur.com/nGF1K8f.jpg
19+
utter_did_that_help:
20+
- text: Did that help you?
21+
utter_happy:
22+
- text: Great, carry on!
23+
utter_goodbye:
24+
- text: Bye
25+
utter_iamabot:
26+
- text: I am a bot, powered by Rasa.
27+
28+
actions:
29+
- respond_chitchat
30+
- utter_greet
31+
- utter_cheer_up
32+
- utter_did_that_help
33+
- utter_happy
34+
- utter_goodbye
35+
- utter_iamabot
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
## intent:greet
2+
- hey
3+
- hello
4+
- hi
5+
- good morning
6+
- good evening
7+
- hey there
8+
9+
## intent:goodbye
10+
- bye
11+
- goodbye
12+
- see you around
13+
- see you later
14+
15+
## intent:affirm
16+
- yes
17+
- indeed
18+
- of course
19+
- that sounds good
20+
- correct
21+
22+
## intent:deny
23+
- no
24+
- never
25+
- I don't think so
26+
- don't like that
27+
- no way
28+
- not really
29+
30+
## intent:mood_great
31+
- perfect
32+
- very good
33+
- great
34+
- amazing
35+
- wonderful
36+
- I am feeling very good
37+
- I am great
38+
- I'm good
39+
40+
## intent:mood_unhappy
41+
- sad
42+
- very sad
43+
- unhappy
44+
- bad
45+
- very bad
46+
- awful
47+
- terrible
48+
- not very good
49+
- extremely sad
50+
- so sad
51+
52+
## intent:bot_challenge
53+
- are you a bot?
54+
- are you a human?
55+
- am I talking to a bot?
56+
- am I talking to a human?
57+
58+
## intent:chitchat/ask_name
59+
- what's your name
60+
- who are you?
61+
- what are you called?
62+
63+
## intent:chitchat/ask_weather
64+
- how's weather?
65+
- is it sunny where you are?

data/test_responses/default.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
## ask name
2+
* chitchat/ask_name
3+
- my name is Sara, Rasa's documentation bot!
4+
5+
## ask weather
6+
* chitchat/ask_weather
7+
- it's always sunny where I live
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
## happy path
2+
* greet
3+
- utter_greet
4+
* mood_great
5+
- utter_happy
6+
7+
## sad path 1
8+
* greet
9+
- utter_greet
10+
* mood_unhappy
11+
- utter_cheer_up
12+
- utter_did_that_help
13+
* affirm
14+
- utter_happy
15+
16+
## sad path 2
17+
* greet
18+
- utter_greet
19+
* mood_unhappy
20+
- utter_cheer_up
21+
- utter_did_that_help
22+
* deny
23+
- utter_goodbye
24+
25+
## say goodbye
26+
* goodbye
27+
- utter_goodbye
28+
29+
## bot challenge
30+
* bot_challenge
31+
- utter_iamabot
32+
33+
## chitchat
34+
* chitchat
35+
- respond_chitchat

docs/_static/spec/rasa.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,8 @@ components:
12911291
$ref: '#/components/schemas/ConfigFile'
12921292
nlu:
12931293
$ref: '#/components/schemas/NLUTrainingData'
1294+
responses:
1295+
$ref: '#/components/schemas/RetrievalIntentsTrainingData'
12941296
stories:
12951297
$ref: '#/components/schemas/StoriesTrainingData'
12961298
force:
@@ -1355,6 +1357,17 @@ components:
13551357
13561358
- unhappy
13571359
1360+
RetrievalIntentsTrainingData:
1361+
type: string
1362+
description: Rasa response texts for retrieval intents in markdown format
1363+
example: >-
1364+
## ask name
1365+
* chitchat/ask_name
1366+
- my name is Sara, Rasa's documentation bot!
1367+
1368+
## ask weather
1369+
* chitchat/ask_weather
1370+
- it's always sunny where I live
13581371
13591372
StoriesTrainingData:
13601373
type: string

rasa/server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,6 +771,10 @@ async def train(request: Request) -> HTTPResponse:
771771
stories_path = os.path.join(temp_dir, "stories.md")
772772
rasa.utils.io.write_text_file(rjs["stories"], stories_path)
773773

774+
if "responses" in rjs:
775+
responses_path = os.path.join(temp_dir, "responses.md")
776+
rasa.utils.io.write_text_file(rjs["responses"], responses_path)
777+
774778
domain_path = DEFAULT_DOMAIN_PATH
775779
if "domain" in rjs:
776780
domain_path = os.path.join(temp_dir, "domain.yml")

tests/test_server.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,42 @@ def test_train_core_success(
432432
assert os.path.exists(os.path.join(model_path, "fingerprint.json"))
433433

434434

435+
def test_train_with_retrieval_events_success(rasa_app, default_stack_config):
436+
with ExitStack() as stack:
437+
domain_file = stack.enter_context(
438+
open("data/test_domains/default_retrieval_intents.yml")
439+
)
440+
config_file = stack.enter_context(open(default_stack_config))
441+
core_file = stack.enter_context(
442+
open("data/test_stories/stories_retrieval_intents.md")
443+
)
444+
responses_file = stack.enter_context(open("data/test_responses/default.md"))
445+
nlu_file = stack.enter_context(
446+
open("data/test_nlu/default_retrieval_intents.md")
447+
)
448+
449+
payload = dict(
450+
domain=domain_file.read(),
451+
config=config_file.read(),
452+
stories=core_file.read(),
453+
responses=responses_file.read(),
454+
nlu=nlu_file.read(),
455+
)
456+
457+
_, response = rasa_app.post("/model/train", json=payload)
458+
assert response.status == 200
459+
460+
# save model to temporary file
461+
tempdir = tempfile.mkdtemp()
462+
model_path = os.path.join(tempdir, "model.tar.gz")
463+
with open(model_path, "wb") as f:
464+
f.write(response.body)
465+
466+
# unpack model and ensure fingerprint is present
467+
model_path = unpack_model(model_path)
468+
assert os.path.exists(os.path.join(model_path, "fingerprint.json"))
469+
470+
435471
def test_train_missing_config(rasa_app: SanicTestClient):
436472
payload = dict(domain="domain data", config=None)
437473

0 commit comments

Comments
 (0)