Skip to content

Commit a524b43

Browse files
authored
Merge pull request RasaHQ#5767 from lluchini/Restarting_Session_on_Every_Message
Restarting Session on Every Message
2 parents 05a637f + c6f1abd commit a524b43

File tree

5 files changed

+28
-11
lines changed

5 files changed

+28
-11
lines changed

changelog/5964.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed a bug when custom metadata passed with the utterance always restarted the session.

docs/api/rasa-sdk.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ slots, but only a user's name and their phone number. To do that, you'd override
158158
) -> List[EventType]:
159159

160160
# the session should begin with a `session_started` event
161-
events = [SessionStarted()]
161+
events = [SessionStarted(metadata=self.metadata)]
162162

163163
# any slots that should be carried over should come after the
164164
# `session_started` event

rasa/core/actions/action.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ async def run(
178178
``tracker.get_slot(slot_name)`` and the most recent user
179179
message is ``tracker.latest_message.text``.
180180
domain (Domain): the bot's domain
181-
181+
metadata: dictionary that can be sent to action server with custom
182+
data.
182183
Returns:
183184
List[Event]: A list of :class:`rasa.core.events.Event` instances
184185
"""
@@ -352,6 +353,9 @@ class ActionSessionStart(Action):
352353
session.
353354
"""
354355

356+
# Optional arbitrary metadata that can be passed to the SessionStarted event.
357+
metadata: Optional[Dict[Text, Any]] = None
358+
355359
def name(self) -> Text:
356360
return ACTION_SESSION_START_NAME
357361

@@ -378,7 +382,7 @@ async def run(
378382
) -> List[Event]:
379383
from rasa.core.events import SessionStarted
380384

381-
_events = [SessionStarted()]
385+
_events = [SessionStarted(metadata=self.metadata)]
382386

383387
if domain.session_config.carry_over_slots:
384388
_events.extend(self._slot_set_events_from_tracker(tracker))

rasa/core/processor.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,12 @@ async def _update_tracker_session(
172172
f"Starting a new session for conversation ID '{tracker.sender_id}'."
173173
)
174174

175-
if metadata:
176-
tracker.events.append(SessionStarted(metadata=metadata))
177-
178175
await self._run_action(
179176
action=self._get_action(ACTION_SESSION_START_NAME),
180177
tracker=tracker,
181178
output_channel=output_channel,
182179
nlg=self.nlg,
180+
metadata=metadata,
183181
)
184182

185183
async def get_tracker_with_session_start(
@@ -636,11 +634,22 @@ async def _cancel_reminders(
636634
scheduler.remove_job(scheduled_job.id)
637635

638636
async def _run_action(
639-
self, action, tracker, output_channel, nlg, policy=None, confidence=None
637+
self,
638+
action,
639+
tracker,
640+
output_channel,
641+
nlg,
642+
policy=None,
643+
confidence=None,
644+
metadata: Optional[Dict[Text, Any]] = None,
640645
) -> bool:
641646
# events and return values are used to update
642647
# the tracker state after an action has been taken
643648
try:
649+
# Here we set optional metadata to the ActionSessionStart, which will then
650+
# be passed to the SessionStart event. Otherwise the metadata will be lost.
651+
if action.name() == ACTION_SESSION_START_NAME:
652+
action.metadata = metadata
644653
events = await action.run(output_channel, nlg, tracker, self.domain)
645654
except ActionExecutionRejection:
646655
events = [ActionExecutionRejected(action.name(), policy, confidence)]

tests/core/test_processor.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,12 +505,15 @@ async def test_update_tracker_session_with_metadata(
505505
# the save is not called in _update_tracker_session()
506506
default_processor._save_tracker(tracker)
507507

508-
# inspect tracker events and make sure SessionStarted event is present and has metadata.
508+
# inspect tracker events and make sure SessionStarted event is present
509+
# and has metadata.
509510
tracker = default_processor.tracker_store.retrieve(sender_id)
510-
session_event_idx = tracker.events.index(SessionStarted())
511-
session_event_metadata = tracker.events[session_event_idx].metadata
511+
assert tracker.events.count(SessionStarted()) == 1
512512

513-
assert session_event_metadata == metadata
513+
session_started_event_idx = tracker.events.index(SessionStarted())
514+
session_started_event_metadata = tracker.events[session_started_event_idx].metadata
515+
516+
assert session_started_event_metadata == metadata
514517

515518

516519
# noinspection PyProtectedMember

0 commit comments

Comments
 (0)