Skip to content

Commit e2f1d68

Browse files
authored
Merge pull request RasaHQ#4977 from RasaHQ/update-session-keep-predicting
Update when to predict another action
2 parents 56a96cd + f2ee1f5 commit e2f1d68

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

rasa/core/processor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,19 @@ def is_action_limit_reached():
480480

481481
# noinspection PyUnusedLocal
482482
@staticmethod
483-
def should_predict_another_action(action_name, events) -> bool:
484-
is_listen_action = action_name == ACTION_LISTEN_NAME
485-
return not is_listen_action
483+
def should_predict_another_action(action_name: Text, events: List[Event]) -> bool:
484+
"""Determine whether the processor should predict another action.
485+
486+
Args:
487+
action_name: Name of the latest executed action.
488+
events: List of events returned by the latest executed action.
489+
490+
Returns:
491+
`False` if `action_name` is `ACTION_LISTEN_NAME` or
492+
`ACTION_SESSION_START_NAME`, otherwise `True`.
493+
"""
494+
495+
return action_name not in (ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME)
486496

487497
@staticmethod
488498
async def _send_bot_messages(

tests/core/test_processor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from rasa.core import jobs
1515
from rasa.core.actions.action import ACTION_LISTEN_NAME, ACTION_SESSION_START_NAME
16+
1617
from rasa.core.agent import Agent
1718
from rasa.core.channels.channel import CollectingOutputChannel, UserMessage
1819
from rasa.core.domain import SessionConfig
@@ -472,3 +473,23 @@ async def test_handle_message_with_session_start(
472473
SlotSet(entity, slot_2[entity]),
473474
ActionExecuted(ACTION_LISTEN_NAME),
474475
]
476+
477+
478+
# noinspection PyProtectedMember
479+
@pytest.mark.parametrize(
480+
"action_name, should_predict_another_action",
481+
[
482+
(ACTION_LISTEN_NAME, False),
483+
(ACTION_SESSION_START_NAME, False),
484+
("utter_greet", True),
485+
],
486+
)
487+
async def test_should_predict_another_action(
488+
default_processor: MessageProcessor,
489+
action_name: Text,
490+
should_predict_another_action: bool,
491+
):
492+
assert (
493+
default_processor.should_predict_another_action(action_name, [])
494+
== should_predict_another_action
495+
)

0 commit comments

Comments
 (0)