Skip to content

Commit

Permalink
Proxy PR for Long Context Capability 1513 (microsoft#1591)
Browse files Browse the repository at this point in the history
* Add new capability to handle long context

* Make print conditional

* Remove superfluous comment

* Fix msg order

* Allow user to specify max_tokens

* Add ability to specify max_tokens per message; improve name

* Improve doc and readability

* Add tests

* Improve documentation and add tests per Erik and Chi's feedback

* Update notebook

* Update doc string of add to agents

* Improve doc string

* improve notebook

* Update github workflows for context handling

* Update docstring

* update notebook to use raw config list.

* Update contrib-openai.yml remove _target

* Fix code formatting

* Fix workflow file

* Update .github/workflows/contrib-openai.yml

---------

Co-authored-by: Eric Zhu <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
3 people authored Feb 8, 2024
1 parent a3c3317 commit 47d6c75
Show file tree
Hide file tree
Showing 7 changed files with 920 additions and 6 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/contrib-openai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,42 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
# Job: runs the context-handling capability tests under coverage and uploads
# the resulting report to Codecov.
ContextHandling:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.11"]
runs-on: ${{ matrix.os }}
# Deployment environment this job runs in.
# NOTE(review): presumably the environment that holds the secrets read in the
# "Coverage" step below - confirm in the repository settings.
environment: openai1
steps:
# Check out the exact head commit of the pull request rather than the default ref.
- name: Checkout
uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
# Installs the package in editable mode, checks that `autogen` imports, then
# installs the test tooling (coverage, pytest).
- name: Install packages and dependencies
run: |
docker --version
python -m pip install --upgrade pip wheel
pip install -e .
python -c "import autogen"
pip install coverage pytest
# Runs test_context_handling.py under coverage with API credentials exposed as
# environment variables, then writes coverage.xml.
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_OPENAI_API_KEY: ${{ secrets.AZURE_OPENAI_API_KEY }}
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
BING_API_KEY: ${{ secrets.BING_API_KEY }}
run: |
coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py
coverage xml
# Uploads the coverage.xml produced by the previous step.
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
37 changes: 37 additions & 0 deletions .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,40 @@ jobs:
with:
file: ./coverage.xml
flags: unittests

# Job: runs the context-handling capability tests (with --skip-openai) on
# Linux, macOS and Windows, then uploads the coverage report to Codecov.
ContextHandling:
  runs-on: ${{ matrix.os }}
  strategy:
    # Let the remaining OS legs finish even if one of them fails.
    fail-fast: false
    matrix:
      os: [ubuntu-latest, macos-latest, windows-2019]
      python-version: ["3.11"]
  steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install packages and dependencies for all tests
      run: |
        python -m pip install --upgrade pip wheel
        pip install pytest
    - name: Install packages and dependencies for Context Handling
      run: |
        pip install -e .
    # AUTOGEN_USE_DOCKER is set to False on every runner except ubuntu-latest.
    - name: Set AUTOGEN_USE_DOCKER based on OS
      shell: bash
      run: |
        if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
          echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
        fi
    - name: Coverage
      run: |
        # The specifier must be quoted: unquoted, the shell parses `>=5.3` as an
        # output redirection, so pip installs an unconstrained `coverage` and
        # writes its output to a file named `=5.3`.
        pip install "coverage>=5.3"
        coverage run -a -m pytest test/agentchat/contrib/capabilities/test_context_handling.py --skip-openai
        coverage xml
    - name: Upload coverage to Codecov
      uses: codecov/codecov-action@v3
      with:
        file: ./coverage.xml
        flags: unittests
5 changes: 0 additions & 5 deletions autogen/agentchat/contrib/capabilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
from .teachability import Teachability
from .agent_capability import AgentCapability


__all__ = ["Teachability", "AgentCapability"]
108 changes: 108 additions & 0 deletions autogen/agentchat/contrib/capabilities/context_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import sys
from termcolor import colored
from typing import Dict, Optional, List
from autogen import ConversableAgent
from autogen import token_count_utils


class TransformChatHistory:
    """
    An agent's chat history with other agents is a common context that it uses to generate a reply.
    This capability allows the agent to transform its chat history prior to using it to generate a reply.
    It does not permanently modify the chat history, but rather processes a copy of it on every invocation.

    This capability class enables various strategies to transform chat history, such as:
        - Truncate messages: Truncate each message to its first `max_tokens_per_message` tokens.
        - Limit number of messages: Truncate the chat history to a maximum number of (recent) messages.
        - Limit number of tokens: Truncate the chat history to the most recent messages that fit in
          a maximum number of tokens.

    Note that a leading system message, because of its special significance, is always kept as is:
    it is never truncated, never dropped, and its tokens are not counted against `max_tokens`.

    The three strategies can be combined. When several parameters are specified they are applied
    in the following order:
        1. First, each message is truncated to a maximum number of tokens.
        2. Second, the number of messages to keep is limited.
        3. Third, the total number of tokens in the chat history is limited.

    Args:
        max_tokens_per_message (Optional[int]): Maximum number of tokens to keep in each message.
        max_messages (Optional[int]): Maximum number of messages to keep in the context.
        max_tokens (Optional[int]): Maximum number of tokens to keep in the context.

    A limit that is None (or 0) is treated as "no limit".
    """

    def __init__(
        self,
        *,
        max_tokens_per_message: Optional[int] = None,
        max_messages: Optional[int] = None,
        max_tokens: Optional[int] = None,
    ):
        # Any falsy value (None or 0) means "unlimited" and is stored as sys.maxsize.
        self.max_tokens_per_message = max_tokens_per_message if max_tokens_per_message else sys.maxsize
        self.max_messages = max_messages if max_messages else sys.maxsize
        self.max_tokens = max_tokens if max_tokens else sys.maxsize

    def add_to_agent(self, agent: "ConversableAgent"):
        """
        Adds TransformChatHistory capability to the given agent by registering
        `_transform_messages` as a hook on the agent's `process_all_messages` method.
        """
        agent.register_hook(hookable_method=agent.process_all_messages, hook=self._transform_messages)

    @staticmethod
    def _count_tokens(message: Dict) -> int:
        """Return the token count of a message's content; a missing or None content counts as 0."""
        content = message.get("content")
        if content is None:
            # NOTE(review): assumes token_count_utils.count_token cannot handle None - confirm.
            return 0
        return token_count_utils.count_token(content)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        """
        Apply the configured limits to a copy of the chat history.

        Args:
            messages: List of messages to process. Neither the list nor the message
                dictionaries it contains are modified.

        Returns:
            A new list holding the leading system message (if any, unchanged) followed by
            the most recent messages that satisfy max_tokens_per_message, max_messages
            and max_tokens, in their original order.
        """
        if not messages:
            return []

        # Copy every message dict: truncation below rewrites "content", and that
        # must not leak into the caller's (i.e. the agent's stored) history.
        messages = [msg.copy() for msg in messages]

        unlimited = sys.maxsize
        if (
            self.max_tokens_per_message == unlimited
            and self.max_messages == unlimited
            and self.max_tokens == unlimited
        ):
            # Nothing to enforce; skip all tokenisation work.
            return messages

        # A leading system message is set aside and passed through untouched.
        system_message = None
        rest_messages = messages
        if messages[0].get("role") == "system":
            system_message = messages[0]
            rest_messages = messages[1:]

        # Measured before any truncation so the report below covers everything removed.
        tokens_before = sum(self._count_tokens(msg) for msg in rest_messages)

        # 1. Truncate each (non-system) message. Only plain-string content is truncated;
        #    other content types are left untouched.
        if self.max_tokens_per_message != unlimited:
            for msg in rest_messages:
                content = msg.get("content")
                if isinstance(content, str):
                    msg["content"] = truncate_str_to_tokens(content, self.max_tokens_per_message)

        # 2. Keep at most the last max_messages messages.
        candidates = rest_messages[-self.max_messages :]

        # 3. Walk from newest to oldest so that the *recent* messages are the ones
        #    that survive the token budget, then restore chronological order.
        kept_messages = []
        kept_tokens = 0
        for msg in reversed(candidates):
            msg_tokens = self._count_tokens(msg)
            if kept_tokens + msg_tokens > self.max_tokens:
                break
            kept_messages.append(msg)
            kept_tokens += msg_tokens
        kept_messages.reverse()

        processed_messages = [] if system_message is None else [system_message]
        processed_messages.extend(kept_messages)

        num_truncated = len(messages) - len(processed_messages)
        tokens_truncated = tokens_before - kept_tokens
        if num_truncated > 0 or tokens_truncated > 0:
            print(colored(f"Truncated {num_truncated} messages.", "yellow"))
            print(colored(f"Truncated {tokens_truncated} tokens.", "yellow"))
        return processed_messages


def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate a string so that its number of tokens does not exceed max_tokens.

    Args:
        text: String to process.
        max_tokens: Maximum number of tokens to keep. A value <= 0 yields an empty string.

    Returns:
        `text` itself when it already fits; otherwise the longest leading part of `text`
        reached before the token count first exceeds max_tokens.
    """
    if not text or max_tokens <= 0:
        return ""

    # Fast path: a single tokenisation is enough when nothing has to be cut.
    if token_count_utils.count_token(text) <= max_tokens:
        return text

    # Grow the prefix one character at a time and stop as soon as the budget is
    # exceeded. Comparing with ">" (rather than "==") still terminates when one
    # character raises the count by more than one token. The full text is already
    # known to be too long, so prefixes stop one character short of it.
    # NOTE(review): this re-tokenises every prefix; slicing the encoded token list
    # would be cheaper but needs direct access to the tokenizer.
    keep = 0
    for end in range(1, len(text)):
        if token_count_utils.count_token(text[:end]) > max_tokens:
            break
        keep = end
    return text[:keep]
25 changes: 24 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def __init__(

# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.
self.hook_lists = {self.process_last_message: [], self.process_all_messages: []}

def register_reply(
self,
Expand Down Expand Up @@ -1528,6 +1528,10 @@ def generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
Expand Down Expand Up @@ -1584,6 +1588,10 @@ async def a_generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages(messages)

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
Expand Down Expand Up @@ -2211,6 +2219,21 @@ def register_hook(self, hookable_method: Callable, hook: Callable):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)

def process_all_messages(self, messages: List[Dict]) -> List[Dict]:
    """
    Runs every hook registered for this method over the full message list and
    returns whatever the final hook produced.
    """
    registered_hooks = self.hook_lists[self.process_all_messages]

    # Nothing to do without messages or without hooks: pass the input straight through.
    if messages is None or not registered_hooks:
        return messages

    # Hooks are chained in registration order; each one receives the previous hook's output.
    current = messages
    for registered_hook in registered_hooks:
        current = registered_hook(current)
    return current

def process_last_message(self, messages):
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
Expand Down
Loading

0 comments on commit 47d6c75

Please sign in to comment.