Skip to content

Commit 002b892

Browse files
authored
Merge pull request RasaHQ#5611 from lluchini/metada-on-session-start-action
Add metada on session start action
2 parents c3bbc31 + b759106 commit 002b892

File tree

3 files changed

+49
-4
lines changed

3 files changed

+49
-4
lines changed

changelog/5574.bugfix.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Fixed an issue that happened when metadata is passed in a new session.
2+
3+
Now the metadata is correctly passed to the ActionSessionStart.

rasa/core/processor.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
ReminderScheduled,
3737
SlotSet,
3838
UserUttered,
39+
SessionStarted,
3940
)
4041
from rasa.core.interpreter import (
4142
INTENT_MESSAGE_PREFIX,
@@ -149,7 +150,10 @@ async def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]:
149150
}
150151

151152
async def _update_tracker_session(
152-
self, tracker: DialogueStateTracker, output_channel: OutputChannel
153+
self,
154+
tracker: DialogueStateTracker,
155+
output_channel: OutputChannel,
156+
metadata: Optional[Dict] = None,
153157
) -> None:
154158
"""Check the current session in `tracker` and update it if expired.
155159
@@ -158,6 +162,7 @@ async def _update_tracker_session(
158162
restart are considered).
159163
160164
Args:
165+
metadata: Data sent from client associated with the incoming user message.
161166
tracker: Tracker to inspect.
162167
output_channel: Output channel for potential utterances in a custom
163168
`ActionSessionStart`.
@@ -167,6 +172,9 @@ async def _update_tracker_session(
167172
f"Starting a new session for conversation ID '{tracker.sender_id}'."
168173
)
169174

175+
if metadata:
176+
tracker.events.append(SessionStarted(metadata=metadata))
177+
170178
await self._run_action(
171179
action=self._get_action(ACTION_SESSION_START_NAME),
172180
tracker=tracker,
@@ -175,13 +183,17 @@ async def _update_tracker_session(
175183
)
176184

177185
async def get_tracker_with_session_start(
178-
self, sender_id: Text, output_channel: Optional[OutputChannel] = None
186+
self,
187+
sender_id: Text,
188+
output_channel: Optional[OutputChannel] = None,
189+
metadata: Optional[Dict] = None,
179190
) -> Optional[DialogueStateTracker]:
180191
"""Get tracker for `sender_id` or create a new tracker for `sender_id`.
181192
182193
If a new tracker is created, `action_session_start` is run.
183194
184195
Args:
196+
metadata: Data sent from client associated with the incoming user message.
185197
output_channel: Output channel associated with the incoming user message.
186198
sender_id: Conversation ID for which to fetch the tracker.
187199
@@ -193,7 +205,7 @@ async def get_tracker_with_session_start(
193205
if not tracker:
194206
return None
195207

196-
await self._update_tracker_session(tracker, output_channel)
208+
await self._update_tracker_session(tracker, output_channel, metadata)
197209

198210
return tracker
199211

@@ -233,7 +245,7 @@ async def log_message(
233245
# we have a Tracker instance for each user
234246
# which maintains conversation state
235247
tracker = await self.get_tracker_with_session_start(
236-
message.sender_id, message.output_channel
248+
message.sender_id, message.output_channel, message.metadata
237249
)
238250

239251
if tracker:
@@ -291,10 +303,12 @@ def predict_next_action(
291303
action = self.domain.action_for_index(
292304
max_confidence_index, self.action_endpoint
293305
)
306+
294307
logger.debug(
295308
f"Predicted next action '{action.name()}' with confidence "
296309
f"{action_confidences[max_confidence_index]:.2f}."
297310
)
311+
298312
return action, policy, action_confidences[max_confidence_index]
299313

300314
@staticmethod

tests/core/test_processor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,34 @@ async def test_update_tracker_session(
485485
]
486486

487487

488+
# noinspection PyProtectedMember
489+
async def test_update_tracker_session_with_metadata(
490+
default_channel: CollectingOutputChannel,
491+
default_processor: MessageProcessor,
492+
monkeypatch: MonkeyPatch,
493+
):
494+
sender_id = uuid.uuid4().hex
495+
tracker = default_processor.tracker_store.get_or_create_tracker(sender_id)
496+
497+
# patch `_has_session_expired()` so the `_update_tracker_session()` call actually
498+
# does something
499+
monkeypatch.setattr(default_processor, "_has_session_expired", lambda _: True)
500+
501+
metadata = {"metadataTestKey": "metadataTestValue"}
502+
503+
await default_processor._update_tracker_session(tracker, default_channel, metadata)
504+
505+
# the save is not called in _update_tracker_session()
506+
default_processor._save_tracker(tracker)
507+
508+
# inspect tracker events and make sure SessionStarted event is present and has metadata.
509+
tracker = default_processor.tracker_store.retrieve(sender_id)
510+
session_event_idx = tracker.events.index(SessionStarted())
511+
session_event_metadata = tracker.events[session_event_idx].metadata
512+
513+
assert session_event_metadata == metadata
514+
515+
488516
# noinspection PyProtectedMember
489517
async def test_update_tracker_session_with_slots(
490518
default_channel: CollectingOutputChannel,

0 commit comments

Comments
 (0)