Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add global initiate chat #538

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Expand Down Expand Up @@ -66,8 +67,13 @@

F = TypeVar("F", bound=Callable[..., Any])

if TYPE_CHECKING:
# checks if ConversableAgent is implementing LLMAgent protocol
def create_conversible_agent(name: str) -> LLMAgent:
return ConversableAgent(name)

class ConversableAgent(LLMAgent):

class ConversableAgent:
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.

After receiving each message, the agent will send a reply to the sender unless the msg is a termination msg.
Expand Down Expand Up @@ -284,6 +290,10 @@ def __init__(
"update_agent_state": [],
}

# check if the agent is implementing LLMAgent protocol
if not isinstance(self, LLMAgent):
raise TypeError("ConversableAgent must implement LLMAgent protocol")

def _validate_name(self, name: str) -> None:
# Validation for name using regex to detect any whitespace
if re.search(r"\s", name):
Expand Down
3 changes: 3 additions & 0 deletions autogen/agentchat/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .initiate_chat import AsyncResponseProtocol, ResponseProtocol, a_initiate_chat, initiate_chat

__all__ = ["AsyncResponseProtocol", "ResponseProtocol", "a_initiate_chat", "initiate_chat"]
107 changes: 107 additions & 0 deletions autogen/agentchat/experimental/chat_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from contextvars import ContextVar
from dataclasses import dataclass
from logging import Logger, getLogger
from typing import TYPE_CHECKING, Any, Optional, Protocol, Union, runtime_checkable
from uuid import UUID

from pydantic import UUID4

if TYPE_CHECKING:
from ..agent import Agent


@runtime_checkable
class ChatContextProtocol(Protocol):
@property
def initial_agent(self) -> "Agent":
"""The agent that initiated the chat."""
...

@property
def agents(self) -> list["Agent"]:
"""The agents participating in the chat."""
...

@property
def initial_message(self) -> Union[str, dict[str, Any]]:
"""The messages received by the agent."""
...

@property
def messages(self) -> list[dict[str, Any]]: ...

@property
def logger(self) -> Logger: ...

@classmethod
def get_registered_chat(cls) -> "ChatContextProtocol": ...

def __entry__(self) -> "ChatContextProtocol": ...

def __exit__(
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]
) -> None: ...


@dataclass
class ChatContext:
initial_agent: "Agent"
agents: list["Agent"]
initial_message: Union[str, dict[str, Any]]
messages: list[dict[str, Any]]
logger: Logger
uuid: UUID

_registered_chats: ContextVar[list["ChatContext"]] = ContextVar("registered_chats", default=[])

def __init__(
self,
*,
initial_agent: "Agent",
agents: Optional[list["Agent"]] = None,
initial_message: Union[str, dict[str, Any]],
logger: Optional[Logger] = None,
uuid: Optional[UUID] = None,
):
self.initial_agent = initial_agent
self.agents = agents or []
self.initial_message = initial_message
self.messages = []
self.logger = logger or getLogger(__name__)
self.uuid = uuid or UUID4()

@classmethod
def get_registered_chat(cls) -> "ChatContext":
registered_chats: list[ChatContext] = cls._registered_chats.get()
if registered_chats:
return registered_chats[-1]
raise ValueError("No registered chats found.")

def __entry__(self) -> "ChatContext":
registered_chats = ChatContext._registered_chats.get()
registered_chats.append(self)
return self

def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]) -> None:
registered_chats = ChatContext._registered_chats.get()
registered_chats.pop()

# check if the InitiateChatIOStream implements the IOStream protocol
if TYPE_CHECKING:

@staticmethod
def _type_check(
*,
initial_agent: "Agent",
agents: Optional[list["Agent"]] = None,
initial_message: Union[str, dict[str, Any]],
logger: Optional[Logger] = None,
uuid: Optional[UUID] = None,
) -> ChatContextProtocol:
return ChatContext(
initial_agent=initial_agent, agents=agents, initial_message=initial_message, logger=logger, uuid=uuid
)
132 changes: 132 additions & 0 deletions autogen/agentchat/experimental/initiate_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncIterable, Iterable, Optional, Protocol, Union, runtime_checkable
from uuid import UUID

from ...io import IOStream
from .chat_context import ChatContext

if TYPE_CHECKING:
from ...messages import BaseMessage
from ..agent import Agent

__all__ = ["AsyncResponseProtocol", "ChatContext", "ResponseProtocol", "a_initiate_chat", "initiate_chat"]


@runtime_checkable
class ResponseProtocol(Protocol):
@property
def messages(self) -> Iterable["BaseMessage"]:
"""The messages received by the agent."""
...

# todo: replace request_uuid with InputResponseMessage
def send(self, request_uuid: UUID, response: str) -> None:
"""Send a response to a request."""
...


@runtime_checkable
class AsyncResponseProtocol(Protocol):
@property
def messages(self) -> AsyncIterable["BaseMessage"]:
"""The messages received by the agent."""
...


@dataclass
class Response:
iostream: IOStream
chat_context: ChatContext

@property
def messages(self) -> Iterable["BaseMessage"]:
"""The messages received by the agent."""
raise NotImplementedError("This function is not implemented yet.")

# todo: replace request_uuid with InputResponseMessage
def send(self, request_uuid: UUID, response: str) -> None:
"""Send a response to a request."""
raise NotImplementedError("This function is not implemented yet.")

# check if the Response implements the ResponseProtocol protocol
if TYPE_CHECKING:

@staticmethod
def _type_check(iostream: IOStream, chat_context: ChatContext) -> ResponseProtocol:
return Response(iostream=iostream, chat_context=chat_context)


class InitiateChatIOStream:
def print(self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False) -> None:
"""Print data to the output stream.

Args:
objects (any): The data to print.
sep (str, optional): The separator between objects. Defaults to " ".
end (str, optional): The end of the output. Defaults to "\n".
flush (bool, optional): Whether to flush the output. Defaults to False.
"""
raise NotImplementedError("This function is not implemented yet.")

def send(self, message: BaseMessage) -> None:
"""Send a message to the output stream.

Args:
message (Any): The message to send.
"""
raise NotImplementedError("This function is not implemented yet.")

def input(self, prompt: str = "", *, password: bool = False) -> str:
"""Read a line from the input stream.

Args:
prompt (str, optional): The prompt to display. Defaults to "".
password (bool, optional): Whether to read a password. Defaults to False.

Returns:
str: The line read from the input stream.

"""
raise NotImplementedError("This function is not implemented yet.")

# check if the InitiateChatIOStream implements the IOStream protocol
if TYPE_CHECKING:

@staticmethod
def _type_check(agent: Agent) -> IOStream:
return InitiateChatIOStream()


def initiate_chat(
agent: "Agent",
*,
message: Union[str, dict[str, Any]],
recipient: Optional["Agent"] = None,
) -> ResponseProtocol:
# start initiate chat in a background thread
iostream = InitiateChatIOStream()
chat_context = ChatContext(
initial_agent=agent,
agents=[recipient] if recipient else [],
initial_message=message,
)
response = Response(iostream=iostream, chat_context=chat_context)

with ThreadPoolExecutor() as executor:
executor.submit(agent.initiate_chat, agent, message=message, recipient=recipient)

return response


async def a_initiate_chat(
agent: "Agent",
*,
message: Union[str, dict[str, Any]],
recipient: Optional["Agent"] = None,
) -> ResponseProtocol:
raise NotImplementedError("This function is not implemented yet.")
5 changes: 2 additions & 3 deletions autogen/io/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
import getpass
from typing import Any

from autogen.messages.base_message import BaseMessage
from autogen.messages.print_message import PrintMessage

from ..messages.base_message import BaseMessage
from ..messages.print_message import PrintMessage
from .base import IOStream

__all__ = ("IOConsole",)
Expand Down
Loading
Loading