Skip to content

Commit f7a66e9

Browse files
authored
Merge pull request RasaHQ#4812 from kearnsw/master
Improve flexibility of Slack connector
2 parents b0c060f + f2d501f commit f7a66e9

File tree

4 files changed

+160
-16
lines changed

4 files changed

+160
-16
lines changed

changelog/4811.improvement.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Support invoking a ``SlackBot`` by direct messaging or ``@<app name>`` mentions.

docs/user-guide/connectors/slack.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ e.g. using:
5353
5454
You need to supply a ``credentials.yml`` with the following content:
5555

56-
- The ``slack_channel`` is the target your bot posts to.
57-
This can be a channel or an individual person. You can leave out
58-
the argument to post DMs to the bot.
56+
- The ``slack_channel`` can be a channel or an individual person that the bot should listen to for communications, in
57+
addition to the default behavior of listening for direct messages and app mentions, i.e. "@app_name".
58+
5959

6060
- Use the entry for ``Bot User OAuth Access Token`` in the
6161
"OAuth & Permissions" tab as your ``slack_token``. It should start
@@ -75,4 +75,4 @@ The endpoint for receiving slack messages is
7575
``http://localhost:5005/webhooks/slack/webhook``, replacing
7676
the host and port with the appropriate values. This is the URL
7777
you should add in the "OAuth & Permissions" section as well as
78-
the "Event Subscriptions".
78+
the "Event Subscriptions".

rasa/core/channels/slack.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,20 @@ def __init__(
169169
self.retry_reason_header = slack_retry_reason_header
170170
self.retry_num_header = slack_retry_number_header
171171

172+
@staticmethod
173+
def _is_app_mention(slack_event: Dict) -> bool:
174+
try:
175+
return slack_event["event"]["type"] == "app_mention"
176+
except KeyError:
177+
return False
178+
179+
@staticmethod
180+
def _is_direct_message(slack_event: Dict) -> bool:
181+
try:
182+
return slack_event["event"]["channel_type"] == "im"
183+
except KeyError:
184+
return False
185+
172186
@staticmethod
173187
def _is_user_message(slack_event: Dict) -> bool:
174188
return (
@@ -293,11 +307,15 @@ async def process_message(
293307

294308
return response.text(None, status=201, headers={"X-Slack-No-Retry": 1})
295309

310+
if metadata is not None:
311+
output_channel = metadata.get("out_channel")
312+
else:
313+
output_channel = None
314+
296315
try:
297-
out_channel = self.get_output_channel()
298316
user_msg = UserMessage(
299317
text,
300-
out_channel,
318+
self.get_output_channel(output_channel),
301319
sender_id,
302320
input_channel=self.name(),
303321
metadata=metadata,
@@ -310,6 +328,24 @@ async def process_message(
310328

311329
return response.text("")
312330

331+
def get_metadata(self, request: Request) -> Dict[Text, Any]:
332+
"""Extracts the metadata from a slack API event (https://api.slack.com/types/event).
333+
334+
Args:
335+
request: A `Request` object that contains a slack API event in the body.
336+
337+
Returns:
338+
Metadata extracted from the sent event payload. This includes the output channel for the response,
339+
and users that have installed the bot.
340+
"""
341+
slack_event = request.json
342+
event = slack_event.get("event", {})
343+
344+
return {
345+
"out_channel": event.get("channel"),
346+
"users": slack_event.get("authed_users"),
347+
}
348+
313349
def blueprint(
314350
self, on_new_message: Callable[[UserMessage], Awaitable[Any]]
315351
) -> Blueprint:
@@ -342,24 +378,45 @@ async def webhook(request: Request) -> HTTPResponse:
342378

343379
elif request.json:
344380
output = request.json
381+
event = output.get("event", {})
382+
user_message = event.get("text", "")
383+
sender_id = event.get("user", "")
384+
metadata = self.get_metadata(request)
385+
345386
if "challenge" in output:
346387
return response.json(output.get("challenge"))
347388

348-
elif self._is_user_message(output):
349-
metadata = self.get_metadata(request)
389+
elif self._is_user_message(output) and self._is_supported_channel(
390+
output, metadata
391+
):
350392
return await self.process_message(
351393
request,
352394
on_new_message,
353-
self._sanitize_user_message(
354-
output["event"]["text"], output["authed_users"]
395+
text=self._sanitize_user_message(
396+
user_message, metadata["users"]
355397
),
356-
output.get("event").get("user"),
357-
metadata,
398+
sender_id=sender_id,
399+
metadata=metadata,
400+
)
401+
else:
402+
logger.warning(
403+
f"Received message on unsupported channel: {metadata['out_channel']}"
358404
)
359405

360-
return response.text("Bot message delivered")
406+
return response.text("Bot message delivered.")
361407

362408
return slack_webhook
363409

364-
def get_output_channel(self) -> OutputChannel:
365-
return SlackBot(self.slack_token, self.slack_channel)
410+
def _is_supported_channel(self, slack_event: Dict, metadata: Dict) -> bool:
411+
return (
412+
self._is_direct_message(slack_event)
413+
or self._is_app_mention(slack_event)
414+
or metadata["out_channel"] == self.slack_channel
415+
)
416+
417+
def get_output_channel(self, channel: Optional[Text] = None) -> OutputChannel:
418+
channel = channel or self.slack_channel
419+
return SlackBot(self.slack_token, channel)
420+
421+
def set_output_channel(self, channel: Text) -> None:
422+
self.slack_channel = channel

tests/core/test_channels.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import urllib.parse
44
from typing import Dict
5-
from unittest.mock import patch, MagicMock
5+
from unittest.mock import patch, MagicMock, Mock
66

77
import pytest
88
import responses
@@ -461,6 +461,92 @@ def test_botframework_attachments():
461461
assert ch.add_attachments_to_metadata(payload, metadata) == updated_metadata
462462

463463

464+
def test_slack_metadata():
465+
from rasa.core.channels.slack import SlackInput
466+
from sanic.request import Request
467+
468+
user = "user1"
469+
channel = "channel1"
470+
authed_users = ["XXXXXXX", "YYYYYYY", "ZZZZZZZ"]
471+
direct_message_event = {
472+
"authed_users": authed_users,
473+
"event": {
474+
"client_msg_id": "XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX",
475+
"type": "message",
476+
"text": "hello world",
477+
"user": user,
478+
"ts": "1579802617.000800",
479+
"team": "XXXXXXXXX",
480+
"blocks": [
481+
{
482+
"type": "rich_text",
483+
"block_id": "XXXXX",
484+
"elements": [
485+
{
486+
"type": "rich_text_section",
487+
"elements": [{"type": "text", "text": "hi"}],
488+
}
489+
],
490+
}
491+
],
492+
"channel": channel,
493+
"event_ts": "1579802617.000800",
494+
"channel_type": "im",
495+
},
496+
}
497+
498+
input_channel = SlackInput(
499+
slack_token="YOUR_SLACK_TOKEN", slack_channel="YOUR_SLACK_CHANNEL"
500+
)
501+
502+
r = Mock()
503+
r.json = direct_message_event
504+
metadata = input_channel.get_metadata(request=r)
505+
assert metadata["out_channel"] == channel
506+
assert metadata["users"] == authed_users
507+
508+
509+
def test_slack_metadata_missing_keys():
510+
from rasa.core.channels.slack import SlackInput
511+
from sanic.request import Request
512+
513+
channel = "channel1"
514+
direct_message_event = {
515+
"event": {
516+
"client_msg_id": "XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX",
517+
"type": "message",
518+
"text": "hello world",
519+
"ts": "1579802617.000800",
520+
"team": "XXXXXXXXX",
521+
"blocks": [
522+
{
523+
"type": "rich_text",
524+
"block_id": "XXXXX",
525+
"elements": [
526+
{
527+
"type": "rich_text_section",
528+
"elements": [{"type": "text", "text": "hi"}],
529+
}
530+
],
531+
}
532+
],
533+
"channel": channel,
534+
"event_ts": "1579802617.000800",
535+
"channel_type": "im",
536+
},
537+
}
538+
539+
input_channel = SlackInput(
540+
slack_token="YOUR_SLACK_TOKEN", slack_channel="YOUR_SLACK_CHANNEL"
541+
)
542+
543+
r = Mock()
544+
r.json = direct_message_event
545+
metadata = input_channel.get_metadata(request=r)
546+
assert metadata["users"] is None
547+
assert metadata["out_channel"] == channel
548+
549+
464550
def test_slack_message_sanitization():
465551
from rasa.core.channels.slack import SlackInput
466552

0 commit comments

Comments
 (0)