Skip to content

Commit

Permalink
refactor(tool/email): use unified interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
idiotWu committed Dec 26, 2024
1 parent 4c4bb0d commit afef096
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 79 deletions.
142 changes: 101 additions & 41 deletions npiai/tools/google/gmail/app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import asyncio
import json
import os
from typing import List
from textwrap import dedent
from typing import List, AsyncGenerator

from googleapiclient.errors import HttpError
from markdown import markdown
from markdownify import markdownify as to_markdown
from simplegmail.message import Message

from npiai import FunctionTool, function, utils
from npiai.error import UnauthorizedError
from npiai.context import Context
from npiai.constant import app
from npiai.utils import html_to_markdown
from npiai.tools.shared_types.base_email_tool import BaseEmailTool, EmailMessage

from google.oauth2.credentials import Credentials as GoogleCredentials
from oauth2client.client import OAuth2Credentials
Expand All @@ -23,11 +23,6 @@
def convert_google_cred_to_oauth2_cred(
google_credentials: GoogleCredentials,
) -> OAuth2Credentials:
# Convert expiry datetime to string if necessary
expiry = (
google_credentials.expiry.isoformat() if google_credentials.expiry else None
)

# Create an instance of OAuth2Credentials
return OAuth2Credentials(
access_token=google_credentials.token,
Expand All @@ -41,7 +36,7 @@ def convert_google_cred_to_oauth2_cred(
)


class Gmail(FunctionTool):
class Gmail(FunctionTool, BaseEmailTool):
name = "gmail"
description = 'interact with Gmail using English, e.g., gmail("send an email to [email protected]")'
system_prompt = "You are a Gmail Agent helping users to manage their emails"
Expand All @@ -53,6 +48,17 @@ def __init__(self, creds: GoogleCredentials | None = None):
super().__init__()
self._creds = creds

def _fetch_messages_by_ids(self, message_ids: List[str]) -> List[Message]:
emails: List[Message] = []

for message_id in message_ids:
try:
emails.append(self._gmail_client.get_message_by_id(message_id))
except HttpError:
pass

return emails

@classmethod
def from_context(cls, ctx: Context) -> "Gmail":
if not utils.is_cloud_env():
Expand All @@ -76,32 +82,80 @@ async def start(self):
)
await super().start()

def _get_messages_from_ids(self, message_ids: List[str]) -> List[Message]:
emails: List[Message] = []
def convert_message(self, message: Message) -> EmailMessage:
return EmailMessage(
id=message.id,
thread_id=message.thread_id,
sender=message.sender,
recipient=message.recipient,
cc=message.cc,
bcc=message.bcc,
subject=message.subject,
body=message.plain or html_to_markdown(message.html),
)

for message_id in message_ids:
try:
emails.append(self._gmail_client.get_message_by_id(message_id))
except HttpError:
pass
async def get_message_by_id(self, message_id: str) -> EmailMessage | None:
try:
message = self._gmail_client.get_message_by_id(message_id)
return self.convert_message(message)
except HttpError:
return None

return emails
async def list_inbox_stream(
self,
limit: int = -1,
query: str = None,
) -> AsyncGenerator[EmailMessage, None]:
"""
List emails in the inbox
Args:
limit: The number of emails to list, -1 for all. Default is -1.
query: A query to filter the emails. Default is None.
"""

@staticmethod
def _message_to_string(message: Message) -> str:
return (
dedent(
f"""
Message ID: {message.id}
Thread ID: {message.thread_id}
Sender ID: {message.headers.get('Message-ID', message.id)}
From: {message.sender}
To: {message.recipient}
Subject: {message.subject}
"""
page_size = 10 if limit == -1 else min(limit, 10)
page_token = None
count = 0

while limit == -1 or count < limit:
response = (
self._gmail_client.service.users()
.messages()
.list(
userId="me",
q=query,
maxResults=page_size,
pageToken=page_token,
)
.execute()
)
+ f"Content: {message.plain or to_markdown(message.html)}"
)

messages = response.get("messages", [])

if not messages:
return

for msg_ref in messages:
# noinspection PyProtectedMember
msg = self._gmail_client._build_message_from_ref(
user_id="me",
message_ref=msg_ref,
)

msg.attachments[0].download()

yield self.convert_message(msg)

count += 1

if limit != -1 and count >= limit:
return

page_token = response.get("nextPageToken", None)

if not page_token:
return

@function
def add_labels(self, message_ids: List[str], labels: List[str]) -> str:
Expand All @@ -113,7 +167,7 @@ def add_labels(self, message_ids: List[str], labels: List[str]) -> str:
message_ids: A list of IDs of messages that should be labeled. You can find this in the "Message ID: ..." line of the email.
labels: A list of labels to add.
"""
messages = self._get_messages_from_ids(message_ids)
messages = self._fetch_messages_by_ids(message_ids)

if len(messages) == 0:
raise Exception("Error: No messages found for the given IDs")
Expand Down Expand Up @@ -144,7 +198,7 @@ def remove_labels(self, message_ids: List[str], labels: List[str]) -> str:
message_ids: A list of IDs of messages that should be labeled. You can find this in the "Message ID: ..." line of the email.
labels: A list of labels to remove.
"""
messages = self._get_messages_from_ids(message_ids)
messages = self._fetch_messages_by_ids(message_ids)

if len(messages) == 0:
raise Exception("Error: No messages found for the given IDs")
Expand Down Expand Up @@ -193,7 +247,9 @@ def create_draft(
msg_html=markdown(message),
)

return "The following draft is created:\n" + self._message_to_string(msg)
return "The following draft is created:\n" + json.dumps(
self.convert_message(msg), ensure_ascii=False
)

@function
def create_reply_draft(
Expand Down Expand Up @@ -242,7 +298,9 @@ def create_reply_draft(
reply_to=recipient_id,
)

return "The following reply draft is created:\n" + self._message_to_string(msg)
return "The following reply draft is created:\n" + json.dumps(
self.convert_message(msg), ensure_ascii=False
)

@function
def reply(
Expand Down Expand Up @@ -291,7 +349,9 @@ def reply(
reply_to=recipient_id,
)

return "The following reply is sent:\n" + self._message_to_string(msg)
return "The following reply is sent:\n" + json.dumps(
self.convert_message(msg), ensure_ascii=False
)

@function
def search_emails(self, query: str = None, max_results: int = 100) -> str:
Expand All @@ -307,9 +367,7 @@ def search_emails(self, query: str = None, max_results: int = 100) -> str:
max_results=max_results,
)

return json.dumps(
[self._message_to_string(m) for m in msgs], ensure_ascii=False
)
return json.dumps([self.convert_message(m) for m in msgs], ensure_ascii=False)

@function
async def send_email(
Expand Down Expand Up @@ -350,7 +408,9 @@ async def send_email(
msg_html=markdown(message),
)

return "Sending Success\n" + self._message_to_string(msg)
return "Sending Success\n" + json.dumps(
self.convert_message(msg), ensure_ascii=False
)

@function
async def wait_for_reply(self, sender: str):
Expand All @@ -369,6 +429,6 @@ async def wait_for_reply(self, sender: str):
if len(messages):
msg = messages[0]
msg.mark_as_read()
return self._message_to_string(msg)
return json.dumps(self.convert_message(msg), ensure_ascii=False)

await asyncio.sleep(3)
2 changes: 1 addition & 1 deletion npiai/tools/outlook/__test__/list_emails.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def main():
async with Outlook(creds) as outlook:
async for email in outlook.list_inbox_stream(limit=10):
msg_with_body = await outlook.get_message_by_id(email.id)
print(outlook.message_to_dict(msg_with_body))
print(outlook.convert_message(msg_with_body))


if __name__ == "__main__":
Expand Down
53 changes: 23 additions & 30 deletions npiai/tools/outlook/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

from npiai import FunctionTool, function, Context
from npiai.utils import html_to_markdown
from npiai.tools.shared_types.base_email_tool import BaseEmailTool, EmailMessage


class Outlook(FunctionTool):
class Outlook(FunctionTool, BaseEmailTool):
name = "outlook"
description = "Manage Outlook emails"
system_prompt = "You are an agent helping user manage Outlook emails."
Expand All @@ -42,23 +43,19 @@ def _get_email_address(self, recipient: Recipient) -> str:

return f"{recipient.email_address.name} <{recipient.email_address.address}>"

def message_to_dict(self, message: Message):
def convert_message(self, message: Message):
recipients = ", ".join(
self._get_email_address(recipient) for recipient in message.to_recipients
)

email = {
"message_id": message.id,
"conversation_id": message.conversation_id,
"from": self._get_email_address(message.sender),
"to": recipients,
"subject": message.subject,
}

if message.body:
email["body"] = html_to_markdown(message.body.content)

return email
return EmailMessage(
id=message.id,
thread_id=message.conversation_id,
sender=self._get_email_address(message.sender),
recipient=recipients,
subject=message.subject,
body=html_to_markdown(message.body.content) if message.body else None,
)

async def get_message_by_id(self, message_id: str):
"""
Expand All @@ -67,27 +64,24 @@ async def get_message_by_id(self, message_id: str):
Args:
message_id: the ID of the message
"""
return await self._client.me.messages.by_message_id(message_id).get()
msg = await self._client.me.messages.by_message_id(message_id).get()
return self.convert_message(msg)

async def list_inbox_stream(
self, limit: int = -1, query: str = None, include_body: bool = False
) -> AsyncGenerator[Message, None]:
self, limit: int = -1, query: str = None
) -> AsyncGenerator[EmailMessage, None]:
"""
List emails in the inbox
Args:
limit: The number of emails to list, -1 for all. Default is -1.
query: A query to filter the emails. Default is None.
include_body: Whether to include the email body. Default is False.
"""
page_size = 10 if limit == -1 else min(limit, 10)
select = ["id", "subject", "from", "receivedDateTime"]
select = ["id", "subject", "from", "receivedDateTime", "body"]
count = 0

if include_body:
select.append("body")

while limit != -1 and count < limit:
while limit == -1 or count < limit:
query_params = (
MessagesRequestBuilder.MessagesRequestBuilderGetQueryParameters(
skip=count,
Expand All @@ -112,33 +106,32 @@ async def list_inbox_stream(
return

for message in messages.value:
yield message
yield self.convert_message(message)
count += 1

if count >= limit:
if limit != -1 and count >= limit:
return

@function
async def search_emails(
self,
limit: int = 100,
query: str = None,
include_body: bool = False,
) -> str:
"""
Search for emails with a query.
Args:
limit: The number of emails to return, -1 for all. Default is -1.
query: A query to filter the emails. Default is None.
include_body: Whether to include the email body. Default is False.
"""
messages = []

async for message in self.list_inbox_stream(
limit=limit, query=query, include_body=include_body
async for msg in self.list_inbox_stream(
limit=limit,
query=query,
):
messages.append(self.message_to_dict(message))
messages.append(msg)

return json.dumps(messages, ensure_ascii=False)

Expand Down
Empty file.
Loading

0 comments on commit afef096

Please sign in to comment.