-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(tool/email): use unified interfaces
- Loading branch information
Showing
6 changed files
with
161 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,18 @@ | ||
import asyncio | ||
import json | ||
import os | ||
from typing import List | ||
from textwrap import dedent | ||
from typing import List, AsyncGenerator | ||
|
||
from googleapiclient.errors import HttpError | ||
from markdown import markdown | ||
from markdownify import markdownify as to_markdown | ||
from simplegmail.message import Message | ||
|
||
from npiai import FunctionTool, function, utils | ||
from npiai.error import UnauthorizedError | ||
from npiai.context import Context | ||
from npiai.constant import app | ||
from npiai.utils import html_to_markdown | ||
from npiai.tools.shared_types.base_email_tool import BaseEmailTool, EmailMessage | ||
|
||
from google.oauth2.credentials import Credentials as GoogleCredentials | ||
from oauth2client.client import OAuth2Credentials | ||
|
@@ -23,11 +23,6 @@ | |
def convert_google_cred_to_oauth2_cred( | ||
google_credentials: GoogleCredentials, | ||
) -> OAuth2Credentials: | ||
# Convert expiry datetime to string if necessary | ||
expiry = ( | ||
google_credentials.expiry.isoformat() if google_credentials.expiry else None | ||
) | ||
|
||
# Create an instance of OAuth2Credentials | ||
return OAuth2Credentials( | ||
access_token=google_credentials.token, | ||
|
@@ -41,7 +36,7 @@ def convert_google_cred_to_oauth2_cred( | |
) | ||
|
||
|
||
class Gmail(FunctionTool): | ||
class Gmail(FunctionTool, BaseEmailTool): | ||
name = "gmail" | ||
description = 'interact with Gmail using English, e.g., gmail("send an email to [email protected]")' | ||
system_prompt = "You are a Gmail Agent helping users to manage their emails" | ||
|
@@ -53,6 +48,17 @@ def __init__(self, creds: GoogleCredentials | None = None): | |
super().__init__() | ||
self._creds = creds | ||
|
||
def _fetch_messages_by_ids(self, message_ids: List[str]) -> List[Message]: | ||
emails: List[Message] = [] | ||
|
||
for message_id in message_ids: | ||
try: | ||
emails.append(self._gmail_client.get_message_by_id(message_id)) | ||
except HttpError: | ||
pass | ||
|
||
return emails | ||
|
||
@classmethod | ||
def from_context(cls, ctx: Context) -> "Gmail": | ||
if not utils.is_cloud_env(): | ||
|
@@ -76,32 +82,80 @@ async def start(self): | |
) | ||
await super().start() | ||
|
||
def _get_messages_from_ids(self, message_ids: List[str]) -> List[Message]: | ||
emails: List[Message] = [] | ||
def convert_message(self, message: Message) -> EmailMessage: | ||
return EmailMessage( | ||
id=message.id, | ||
thread_id=message.thread_id, | ||
sender=message.sender, | ||
recipient=message.recipient, | ||
cc=message.cc, | ||
bcc=message.bcc, | ||
subject=message.subject, | ||
body=message.plain or html_to_markdown(message.html), | ||
) | ||
|
||
for message_id in message_ids: | ||
try: | ||
emails.append(self._gmail_client.get_message_by_id(message_id)) | ||
except HttpError: | ||
pass | ||
async def get_message_by_id(self, message_id: str) -> EmailMessage | None: | ||
try: | ||
message = self._gmail_client.get_message_by_id(message_id) | ||
return self.convert_message(message) | ||
except HttpError: | ||
return None | ||
|
||
return emails | ||
async def list_inbox_stream( | ||
self, | ||
limit: int = -1, | ||
query: str = None, | ||
) -> AsyncGenerator[EmailMessage, None]: | ||
""" | ||
List emails in the inbox | ||
Args: | ||
limit: The number of emails to list, -1 for all. Default is -1. | ||
query: A query to filter the emails. Default is None. | ||
""" | ||
|
||
@staticmethod | ||
def _message_to_string(message: Message) -> str: | ||
return ( | ||
dedent( | ||
f""" | ||
Message ID: {message.id} | ||
Thread ID: {message.thread_id} | ||
Sender ID: {message.headers.get('Message-ID', message.id)} | ||
From: {message.sender} | ||
To: {message.recipient} | ||
Subject: {message.subject} | ||
""" | ||
page_size = 10 if limit == -1 else min(limit, 10) | ||
page_token = None | ||
count = 0 | ||
|
||
while limit == -1 or count < limit: | ||
response = ( | ||
self._gmail_client.service.users() | ||
.messages() | ||
.list( | ||
userId="me", | ||
q=query, | ||
maxResults=page_size, | ||
pageToken=page_token, | ||
) | ||
.execute() | ||
) | ||
+ f"Content: {message.plain or to_markdown(message.html)}" | ||
) | ||
|
||
messages = response.get("messages", []) | ||
|
||
if not messages: | ||
return | ||
|
||
for msg_ref in messages: | ||
# noinspection PyProtectedMember | ||
msg = self._gmail_client._build_message_from_ref( | ||
user_id="me", | ||
message_ref=msg_ref, | ||
) | ||
|
||
msg.attachments[0].download() | ||
|
||
yield self.convert_message(msg) | ||
|
||
count += 1 | ||
|
||
if limit != -1 and count >= limit: | ||
return | ||
|
||
page_token = response.get("nextPageToken", None) | ||
|
||
if not page_token: | ||
return | ||
|
||
@function | ||
def add_labels(self, message_ids: List[str], labels: List[str]) -> str: | ||
|
@@ -113,7 +167,7 @@ def add_labels(self, message_ids: List[str], labels: List[str]) -> str: | |
message_ids: A list of IDs of messages that should be labeled. You can find this in the "Message ID: ..." line of the email. | ||
labels: A list of labels to add. | ||
""" | ||
messages = self._get_messages_from_ids(message_ids) | ||
messages = self._fetch_messages_by_ids(message_ids) | ||
|
||
if len(messages) == 0: | ||
raise Exception("Error: No messages found for the given IDs") | ||
|
@@ -144,7 +198,7 @@ def remove_labels(self, message_ids: List[str], labels: List[str]) -> str: | |
message_ids: A list of IDs of messages that should be labeled. You can find this in the "Message ID: ..." line of the email. | ||
labels: A list of labels to remove. | ||
""" | ||
messages = self._get_messages_from_ids(message_ids) | ||
messages = self._fetch_messages_by_ids(message_ids) | ||
|
||
if len(messages) == 0: | ||
raise Exception("Error: No messages found for the given IDs") | ||
|
@@ -193,7 +247,9 @@ def create_draft( | |
msg_html=markdown(message), | ||
) | ||
|
||
return "The following draft is created:\n" + self._message_to_string(msg) | ||
return "The following draft is created:\n" + json.dumps( | ||
self.convert_message(msg), ensure_ascii=False | ||
) | ||
|
||
@function | ||
def create_reply_draft( | ||
|
@@ -242,7 +298,9 @@ def create_reply_draft( | |
reply_to=recipient_id, | ||
) | ||
|
||
return "The following reply draft is created:\n" + self._message_to_string(msg) | ||
return "The following reply draft is created:\n" + json.dumps( | ||
self.convert_message(msg), ensure_ascii=False | ||
) | ||
|
||
@function | ||
def reply( | ||
|
@@ -291,7 +349,9 @@ def reply( | |
reply_to=recipient_id, | ||
) | ||
|
||
return "The following reply is sent:\n" + self._message_to_string(msg) | ||
return "The following reply is sent:\n" + json.dumps( | ||
self.convert_message(msg), ensure_ascii=False | ||
) | ||
|
||
@function | ||
def search_emails(self, query: str = None, max_results: int = 100) -> str: | ||
|
@@ -307,9 +367,7 @@ def search_emails(self, query: str = None, max_results: int = 100) -> str: | |
max_results=max_results, | ||
) | ||
|
||
return json.dumps( | ||
[self._message_to_string(m) for m in msgs], ensure_ascii=False | ||
) | ||
return json.dumps([self.convert_message(m) for m in msgs], ensure_ascii=False) | ||
|
||
@function | ||
async def send_email( | ||
|
@@ -350,7 +408,9 @@ async def send_email( | |
msg_html=markdown(message), | ||
) | ||
|
||
return "Sending Success\n" + self._message_to_string(msg) | ||
return "Sending Success\n" + json.dumps( | ||
self.convert_message(msg), ensure_ascii=False | ||
) | ||
|
||
@function | ||
async def wait_for_reply(self, sender: str): | ||
|
@@ -369,6 +429,6 @@ async def wait_for_reply(self, sender: str): | |
if len(messages): | ||
msg = messages[0] | ||
msg.mark_as_read() | ||
return self._message_to_string(msg) | ||
return json.dumps(self.convert_message(msg), ensure_ascii=False) | ||
|
||
await asyncio.sleep(3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.