Skip to content

Commit 84a262e

Browse files
authored
Merge pull request RasaHQ#5578 from alfredfrancis/feature/get_output_channel_for_socketio
Implemented get_output_channel() fn for SocketIO Channel
2 parents 89a7699 + 1ca0823 commit 84a262e

File tree

3 files changed

+28
-9
lines changed

3 files changed

+28
-9
lines changed

changelog/5578.improvement.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added ``socketio`` to the compatible channels for :ref:`reminders-and-external-events`.

docs/_static/spec/rasa.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ components:
834834
- telegram
835835
- twilio
836836
- webexteams
837+
- socketio
837838

838839
responses:
839840

rasa/core/channels/socketio.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ class SocketIOOutput(OutputChannel):
2828
def name(cls) -> Text:
2929
return "socketio"
3030

31-
def __init__(self, sio, sid, bot_message_evt) -> None:
31+
def __init__(self, sio: AsyncServer, bot_message_evt: Text) -> None:
3232
self.sio = sio
33-
self.sid = sid
3433
self.bot_message_evt = bot_message_evt
3534

3635
async def _send_message(self, socket_id: Text, response: Any) -> None:
@@ -44,15 +43,15 @@ async def send_text_message(
4443
"""Send a message through this channel."""
4544

4645
for message_part in text.strip().split("\n\n"):
47-
await self._send_message(self.sid, {"text": message_part})
46+
await self._send_message(recipient_id, {"text": message_part})
4847

4948
async def send_image_url(
5049
self, recipient_id: Text, image: Text, **kwargs: Any
5150
) -> None:
5251
"""Sends an image to the output"""
5352

5453
message = {"attachment": {"type": "image", "payload": {"src": image}}}
55-
await self._send_message(self.sid, message)
54+
await self._send_message(recipient_id, message)
5655

5756
async def send_text_with_buttons(
5857
self,
@@ -80,7 +79,7 @@ async def send_text_with_buttons(
8079
)
8180

8281
for message in messages:
83-
await self._send_message(self.sid, message)
82+
await self._send_message(recipient_id, message)
8483

8584
async def send_elements(
8685
self, recipient_id: Text, elements: Iterable[Dict[Text, Any]], **kwargs: Any
@@ -95,22 +94,22 @@ async def send_elements(
9594
}
9695
}
9796

98-
await self._send_message(self.sid, message)
97+
await self._send_message(recipient_id, message)
9998

10099
async def send_custom_json(
101100
self, recipient_id: Text, json_message: Dict[Text, Any], **kwargs: Any
102101
) -> None:
103102
"""Sends custom json to the output"""
104103

105-
json_message.setdefault("room", self.sid)
104+
json_message.setdefault("room", recipient_id)
106105

107106
await self.sio.emit(self.bot_message_evt, **json_message)
108107

109108
async def send_attachment(
110109
self, recipient_id: Text, attachment: Dict[Text, Any], **kwargs: Any
111110
) -> None:
112111
"""Sends an attachment to the user."""
113-
await self._send_message(self.sid, {"attachment": attachment})
112+
await self._send_message(recipient_id, {"attachment": attachment})
114113

115114

116115
class SocketIOInput(InputChannel):
@@ -144,6 +143,19 @@ def __init__(
144143
self.user_message_evt = user_message_evt
145144
self.namespace = namespace
146145
self.socketio_path = socketio_path
146+
self.sio = None
147+
148+
def get_output_channel(self) -> Optional["OutputChannel"]:
149+
if self.sio is None:
150+
raise_warning(
151+
"SocketIO output channel cannot be recreated. "
152+
"This is expected behavior when using multiple Sanic "
153+
"workers or multiple Rasa Open Source instances. "
154+
"Please use a different channel for external events in these "
155+
"scenarios."
156+
)
157+
return
158+
return SocketIOOutput(self.sio, self.bot_message_evt)
147159

148160
def blueprint(
149161
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
@@ -155,6 +167,9 @@ def blueprint(
155167
sio, self.socketio_path, "socketio_webhook", __name__
156168
)
157169

170+
# make sio object static to use in get_output_channel
171+
self.sio = sio
172+
158173
@socketio_webhook.route("/", methods=["GET"])
159174
async def health(_: Request) -> HTTPResponse:
160175
return response.json({"status": "ok"})
@@ -173,12 +188,14 @@ async def session_request(sid: Text, data: Optional[Dict]):
173188
data = {}
174189
if "session_id" not in data or data["session_id"] is None:
175190
data["session_id"] = uuid.uuid4().hex
191+
if self.session_persistence:
192+
sio.enter_room(sid, data["session_id"])
176193
await sio.emit("session_confirm", data["session_id"], room=sid)
177194
logger.debug(f"User {sid} connected to socketIO endpoint.")
178195

179196
@sio.on(self.user_message_evt, namespace=self.namespace)
180197
async def handle_message(sid: Text, data: Dict) -> Any:
181-
output_channel = SocketIOOutput(sio, sid, self.bot_message_evt)
198+
output_channel = SocketIOOutput(sio, self.bot_message_evt)
182199

183200
if self.session_persistence:
184201
if not data.get("session_id"):

0 commit comments

Comments
 (0)