forked from microsoft/autogen
Proxy PR for Long Context Capability 1513 (microsoft#1591)
* Add new capability to handle long context
* Make print conditional
* Remove superfluous comment
* Fix msg order
* Allow user to specify max_tokens
* Add ability to specify max_tokens per message; improve name
* Improve doc and readability
* Add tests
* Improve documentation and add tests per Erik and Chi's feedback
* Update notebook
* Update doc string of add to agents
* Improve doc string
* improve notebook
* Update github workflows for context handling
* Update docstring
* update notebook to use raw config list.
* Update contrib-openai.yml remove _target
* Fix code formatting
* Fix workflow file
* Update .github/workflows/contrib-openai.yml

---------

Co-authored-by: Eric Zhu <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
1 parent a3c3317 · commit 47d6c75
Showing 7 changed files with 920 additions and 6 deletions.
@@ -1,5 +0,0 @@
-from .teachability import Teachability
-from .agent_capability import AgentCapability
-
-
-__all__ = ["Teachability", "AgentCapability"]
autogen/agentchat/contrib/capabilities/context_handling.py
108 additions & 0 deletions

@@ -0,0 +1,108 @@
import sys
from termcolor import colored
from typing import Dict, Optional, List
from autogen import ConversableAgent
from autogen import token_count_utils


class TransformChatHistory:
    """
    An agent's chat history with other agents is a common context that it uses to generate a reply.
    This capability allows the agent to transform its chat history prior to using it to generate a reply.
    It does not permanently modify the chat history, but rather processes it on every invocation.

    This capability class enables various strategies to transform the chat history, such as:
    - Truncate messages: Truncate each message to a maximum number of tokens.
    - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
    - Limit number of tokens: Truncate the chat history to the most recent messages that fit
      within a maximum number of tokens.
    Note that the system message, because of its special significance, is always kept as is.

    The three strategies can be combined. When more than one of these parameters is specified,
    they are applied in the following order:
    1. First, truncate each message to a maximum number of tokens.
    2. Second, limit the number of messages to keep.
    3. Third, limit the total number of tokens in the chat history.

    Args:
        max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
        max_messages (Optional[int]): Maximum number of messages to keep in the context.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the context.
    """

    def __init__(
        self,
        *,
        max_tokens_per_message: Optional[int] = None,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
    ):
        # sys.maxsize acts as an "effectively unlimited" default for any limit left unspecified.
        self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
        self.max_messages = max_messages if max_messages else sys.maxsize
        self.max_tokens = max_tokens if max_tokens else sys.maxsize

    def add_to_agent(self, agent: ConversableAgent):
        """
        Adds TransformChatHistory capability to the given agent.
        """
        agent.register_hook(hookable_method=agent.process_all_messages, hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        Args:
            messages: List of messages to process.

        Returns:
            List of messages with the first system message and the last max_messages messages,
            truncated to fit within the configured token limits.
        """
        processed_messages = []
        # Copy each message dict so the truncation below does not mutate the caller's
        # chat history (a plain list.copy() would still share the message dicts).
        messages = [msg.copy() for msg in messages]
        rest_messages = messages

        # check if the first message is a system message and append it to the processed messages
        if len(messages) > 0:
            if messages[0]["role"] == "system":
                msg = messages[0]
                processed_messages.append(msg)
                rest_messages = messages[1:]

        # Truncate each non-system message to the per-message token limit; the system
        # message is deliberately excluded, matching the docstring.
        for msg in rest_messages:
            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)

        # Walk the last max_messages messages from newest to oldest so the most recent
        # messages are the ones kept within the total token budget, then restore
        # chronological order.
        processed_messages_tokens = 0
        kept_messages = []
        for msg in reversed(rest_messages[-self.max_messages :]):
            msg_tokens = token_count_utils.count_token(msg["content"])
            if processed_messages_tokens + msg_tokens > self.max_tokens:
                break
            kept_messages.append(msg)
            processed_messages_tokens += msg_tokens
        processed_messages.extend(reversed(kept_messages))

        total_tokens = 0
        for msg in messages:
            total_tokens += token_count_utils.count_token(msg["content"])

        num_truncated = len(messages) - len(processed_messages)
        if num_truncated > 0 or total_tokens > processed_messages_tokens:
            print(colored(f"Truncated {num_truncated} messages.", "yellow"))
            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
        return processed_messages
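    # For example (illustrative shapes only, not from the commit): given a system message
    # followed by eight user/assistant turns, max_messages=3 with both token limits left
    # unset would return [system message] plus the last three turns, in original order.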


def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that the number of tokens is at most max_tokens.

    Args:
        text: String to process.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        Truncated string.
    """
    truncated_string = ""
    for char in text:
        truncated_string += char
        # Use >= rather than == so the loop still terminates when appending a character
        # makes the token count jump past max_tokens instead of landing exactly on it.
        if token_count_utils.count_token(truncated_string) >= max_tokens:
            break
    return truncated_string
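For reference, a minimal usage sketch of the capability added in this commit. The agent name, model, and limit values below are illustrative assumptions, and config_list must hold a valid model configuration; only TransformChatHistory and add_to_agent come from the diff above:

from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory

# Hypothetical LLM configuration; substitute a real config_list for your provider.
llm_config = {"config_list": [{"model": "gpt-4", "api_key": "YOUR_API_KEY"}]}
assistant = ConversableAgent("assistant", llm_config=llm_config)

# Clip each message to 1000 tokens, keep at most 10 recent messages,
# and cap the whole history at 3000 tokens.
context_handling = TransformChatHistory(
    max_tokens_per_message=1000,
    max_messages=10,
    max_tokens=3000,
)
context_handling.add_to_agent(assistant)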