Skip to content

Commit 75d73d2

Browse files
committed
Include source in output of failed stories only
The file source was always included in the tracker's sender id. This causes that persisted story to have 'different' trackers because the tracker's sender id will be different (because of the source). To make sure the impact is minimal only failed stories will be exported with the source of the story file. For this reason the tracker has been extended with an optional sender_source paramenter. (cherry picked from commit 090e55134d9922d13ba626b1311eeecaa8e04915)
1 parent e972166 commit 75d73d2

File tree

4 files changed

+44
-9
lines changed

4 files changed

+44
-9
lines changed

rasa/core/test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,10 @@ def _predict_tracker_actions(
348348
events = list(tracker.events)
349349

350350
partial_tracker = DialogueStateTracker.from_events(
351-
tracker.sender_id, events[:1], agent.domain.slots
351+
tracker.sender_id,
352+
events[:1],
353+
agent.domain.slots,
354+
sender_source=tracker.sender_source,
352355
)
353356

354357
tracker_actions = []
@@ -477,7 +480,7 @@ def log_failed_stories(failed, out_directory):
477480
f.write("<!-- All stories passed -->")
478481
else:
479482
for failure in failed:
480-
f.write(failure.export_stories())
483+
f.write(failure.export_stories(include_source=True))
481484
f.write("\n\n")
482485

483486

rasa/core/trackers.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ def from_events(
9898
evts: List[Event],
9999
slots: Optional[List[Slot]] = None,
100100
max_event_history: Optional[int] = None,
101+
sender_source: Optional[Text] = None,
101102
):
102-
tracker = cls(sender_id, slots, max_event_history)
103+
tracker = cls(sender_id, slots, max_event_history, sender_source)
103104
for e in evts:
104105
tracker.update(e)
105106
return tracker
@@ -109,6 +110,7 @@ def __init__(
109110
sender_id: Text,
110111
slots: Optional[Iterable[Slot]],
111112
max_event_history: Optional[int] = None,
113+
sender_source: Optional[Text] = None,
112114
) -> None:
113115
"""Initialize the tracker.
114116
@@ -127,6 +129,8 @@ def __init__(
127129
self.slots = {slot.name: copy.deepcopy(slot) for slot in slots}
128130
else:
129131
self.slots = AnySlotDict()
132+
# file source of the messages
133+
self.sender_source = sender_source
130134

131135
###
132136
# current state of the tracker - MUST be re-creatable by processing
@@ -448,13 +452,18 @@ def update(self, event: Event, domain: Optional[Domain] = None) -> None:
448452
for e in domain.slots_for_entities(event.parse_data["entities"]):
449453
self.update(e)
450454

451-
def export_stories(self, e2e: bool = False) -> Text:
455+
def export_stories(self, e2e: bool = False, include_source: bool = False) -> Text:
452456
"""Dump the tracker as a story in the Rasa Core story format.
453457
454458
Returns the dumped tracker as a string."""
455459
from rasa.core.training.structures import Story
456460

457-
story = Story.from_events(self.applied_events(), self.sender_id)
461+
story_name = (
462+
f"{self.sender_id} ({self.sender_source})"
463+
if include_source
464+
else self.sender_id
465+
)
466+
story = Story.from_events(self.applied_events(), story_name)
458467
return story.as_story_string(flat=True, e2e=e2e)
459468

460469
def export_stories_to_file(self, export_path: Text = "debug.md") -> None:

rasa/core/training/generator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def init_copy(self) -> "TrackerWithCachedStates":
8181
self.is_augmented,
8282
)
8383

84-
def copy(self, sender_id: Text = "") -> "TrackerWithCachedStates":
84+
def copy(
85+
self, sender_id: Text = "", sender_source: Text = ""
86+
) -> "TrackerWithCachedStates":
8587
"""Creates a duplicate of this tracker.
8688
8789
A new tracker will be created and all events
@@ -92,6 +94,7 @@ def copy(self, sender_id: Text = "") -> "TrackerWithCachedStates":
9294

9395
tracker = self.init_copy()
9496
tracker.sender_id = sender_id
97+
tracker.sender_source = sender_source
9598

9699
for event in self.events:
97100
tracker.update(event, skip_states=True)
@@ -521,12 +524,12 @@ def _process_step(
521524
# contribute to the trackers events
522525
if tracker.sender_id:
523526
if step.block_name not in tracker.sender_id.split(" > "):
524-
new_sender = f"{tracker.sender_id} > {step.block_name} > {step.source_name}"
527+
new_sender = tracker.sender_id + " > " + step.block_name
525528
else:
526529
new_sender = tracker.sender_id
527530
else:
528-
new_sender = f"{step.block_name} > {step.source_name}"
529-
trackers.append(tracker.copy(new_sender))
531+
new_sender = step.block_name
532+
trackers.append(tracker.copy(new_sender, step.source_name))
530533

531534
end_trackers = []
532535
for event in events:

tests/core/test_evaluation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from pathlib import Path
33

4+
import rasa.utils.io
45
from rasa.core.test import _generate_trackers, collect_story_predictions, test
56

67
# we need this import to ignore the warning...
@@ -91,3 +92,22 @@ async def test_end_to_evaluation_with_forms(form_bot_agent: Agent):
9192
)
9293

9394
assert not story_evaluation.evaluation_store.has_prediction_target_mismatch()
95+
96+
97+
async def test_source_in_failed_stories(tmpdir: Path, default_agent: Agent):
98+
stories_path = str(tmpdir / "failed_stories.md")
99+
100+
await test(
101+
stories=E2E_STORY_FILE_UNKNOWN_ENTITY,
102+
agent=default_agent,
103+
out_directory=str(tmpdir),
104+
max_stories=None,
105+
e2e=False,
106+
)
107+
108+
failed_stories = rasa.utils.io.read_file(stories_path)
109+
110+
assert (
111+
"## simple_story_with_unknown_entity (data/test_evaluations/story_unknown_entity.md)"
112+
in failed_stories
113+
)

0 commit comments

Comments
 (0)