Skip to content

Commit

Permalink
Merge pull request geekan#70 from better629/feat_ltmem
Browse files Browse the repository at this point in the history
Add memory_storage using ANN (approximate nearest neighbor) search to avoid repeatedly executing similar ideas
  • Loading branch information
geekan authored Jul 24, 2023
2 parents 00f4ebe + cddb3aa commit c37613c
Show file tree
Hide file tree
Showing 16 changed files with 526 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ allure-results
docs/scripts/set_env.sh
key.yaml
output.json
data
data/output_add.json
data.ms
examples/nb/
Expand Down
3 changes: 3 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ RPM: 10
## Use SD service, based on https://github.com/AUTOMATIC1111/stable-diffusion-webui
SD_URL: "YOUR_SD_URL"
SD_T2I_API: "/sdapi/v1/txt2img"

#### for Execution
#LONG_TERM_MEMORY: false
5 changes: 5 additions & 0 deletions metagpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def __init__(self, yaml_file=default_yaml_file):
self.google_api_key = self._get('GOOGLE_API_KEY')
self.google_cse_id = self._get('GOOGLE_CSE_ID')
self.search_engine = self._get('SEARCH_ENGINE', SearchEngineType.SERPAPI_GOOGLE)

self.long_term_memory = self._get('LONG_TERM_MEMORY', False)
if self.long_term_memory:
logger.warning("LONG_TERM_MEMORY is True")

self.max_budget = self._get('MAX_BUDGET', 10.0)
self.total_cost = 0.0

Expand Down
2 changes: 2 additions & 0 deletions metagpt/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@ def get_project_root():
API_QUESTIONS_PATH = UT_PATH / "files/question/"
YAPI_URL = "http://yapi.deepwisdomai.com/"
TMP = PROJECT_ROOT / 'tmp'

MEM_TTL = 24 * 30 * 3600
2 changes: 1 addition & 1 deletion metagpt/document_store/faiss_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, raw_data: Path, cache_dir=None, meta_col='source', content_co
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname()
if not (index_file.exists() and store_file.exists()):
logger.warning("Download data from http://pan.deepwisdomai.com/library/13ff7974-fbc7-40ab-bc10-041fdc97adbd/LLM/00_QCS-%E5%90%91%E9%87%8F%E6%95%B0%E6%8D%AE/qcs")
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
index = faiss.read_index(str(index_file))
with open(str(store_file), "rb") as f:
Expand Down
2 changes: 2 additions & 0 deletions metagpt/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
"""

from metagpt.memory.memory import Memory
from metagpt.memory.longterm_memory import LongTermMemory

71 changes: 71 additions & 0 deletions metagpt/memory/longterm_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the implement of Long-term memory

from typing import Iterable, Type

from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.memory import Memory
from metagpt.memory.memory_storage import MemoryStorage


class LongTermMemory(Memory):
    """Long-term memory for Roles.

    - recovers persisted memory when the role starts up
    - keeps the persistent MemoryStorage in sync as new watched messages arrive
    """

    def __init__(self):
        self.memory_storage: MemoryStorage = MemoryStorage()
        super().__init__()
        self.rc = None  # RoleContext, assigned in recover_memory
        self.msg_from_recover = False  # True only while replaying persisted messages

    def recover_memory(self, role_id: str, rc: "RoleContext"):
        """Reload persisted messages for *role_id* and remember the RoleContext."""
        recovered = self.memory_storage.recover_memory(role_id)
        self.rc = rc
        if self.memory_storage.is_initialized:
            logger.warning(f'Agent {role_id} has existed memory storage with {len(recovered)} messages '
                           f'and has recovered them.')
            # flag the replay so `add` does not persist these messages again
            self.msg_from_recover = True
            self.add_batch(recovered)
            self.msg_from_recover = False
        else:
            logger.warning(f'It may the first time to run Agent {role_id}, the long-term memory is empty')

    def add(self, message: Message):
        """Add to short-term memory; persist only the role's watched messages."""
        super().add(message)
        if self.msg_from_recover:
            # messages replayed from the store must not be re-persisted
            return
        if any(message.cause_by == action for action in self.rc.watch):
            # currently, only add role's watching messages to its memory_storage
            self.memory_storage.add(message)

    def remember(self, observed: list[Message], k=10) -> list[Message]:
        """
        remember the most similar k memories from observed Messages, return all when k=0
        1. remember the short-term memory(stm) news
        2. integrate the stm news with ltm(long-term memory) news
        """
        stm_news = super().remember(observed)  # short-term memory news
        if not self.memory_storage.is_initialized:
            # memory_storage hasn't initialized, use default `remember` to get stm_news
            return stm_news

        # keep only the stm news for which the ANN store returns matches
        ltm_news: list[Message] = [mem for mem in stm_news if self.memory_storage.search(mem)]
        return ltm_news[-k:]

    def delete(self, message: Message):
        super().delete(message)
        # TODO delete message in memory_storage

    def clear(self):
        super().clear()
        self.memory_storage.clean()
10 changes: 10 additions & 0 deletions metagpt/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ def get(self, k=0) -> list[Message]:
"""Return the most recent k memories, return all when k=0"""
return self.storage[-k:]

def remember(self, observed: list[Message], k=10) -> list[Message]:
"""remember the most recent k memories from observed Messages, return all when k=0"""
already_observed = self.get(k)
news: list[Message] = []
for i in observed:
if i in already_observed:
continue
news.append(i)
return news

def get_by_action(self, action: Type[Action]) -> list[Message]:
"""Return all messages triggered by a specified Action"""
return self.index[action]
Expand Down
106 changes: 106 additions & 0 deletions metagpt/memory/memory_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the implement of memory storage

from typing import List
from pathlib import Path

from langchain.vectorstores.faiss import FAISS

from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.serialize import serialize_message, deserialize_message
from metagpt.document_store.faiss_store import FaissStore


class MemoryStorage(FaissStore):
    """
    The memory storage with Faiss as ANN search engine.

    Persists one index per role under DATA_PATH/role_mem/<role_id>/ so that
    previously seen (similar) messages can be detected across runs.
    """

    def __init__(self, mem_ttl: int = MEM_TTL):
        # NOTE: FaissStore.__init__ is deliberately not called; the index is
        # loaded lazily in `recover_memory` or created on first `add`.
        self.role_id: str = None
        self.role_mem_path: str = None
        self.mem_ttl: int = mem_ttl  # later use
        self.threshold: float = 0.1  # experience value. TODO The threshold to filter similar memories
        self._initialized: bool = False

        self.store: FAISS = None  # Faiss engine

    @property
    def is_initialized(self) -> bool:
        """Whether a backing store has been loaded or created."""
        return self._initialized

    def recover_memory(self, role_id: str) -> List[Message]:
        """Load the persisted memory of *role_id* and return its messages.

        Returns an empty list (leaving the store uninitialized) when no
        persisted index exists yet.
        """
        self.role_id = role_id
        self.role_mem_path = Path(DATA_PATH / f'role_mem/{self.role_id}/')
        self.role_mem_path.mkdir(parents=True, exist_ok=True)

        self.store = self._load()
        messages = []
        if not self.store:
            # TODO init `self.store` under here with raw faiss api instead under `add`
            pass
        else:
            # NOTE(review): reads langchain's private docstore._dict — verify on langchain upgrades
            for _id, document in self.store.docstore._dict.items():
                messages.append(deserialize_message(document.metadata.get("message_ser")))
            self._initialized = True

        return messages

    def _get_index_and_store_fname(self):
        """Return (index_path, pickle_path) for this role, or (None, None)
        when `recover_memory` has not been called yet."""
        if not self.role_mem_path:
            logger.error(f'You should call {self.__class__.__name__}.recover_memory first when using LongTermMemory')
            return None, None
        index_fpath = Path(self.role_mem_path / f'{self.role_id}.index')
        storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl')
        return index_fpath, storage_fpath

    def persist(self):
        """Flush the in-memory index to the role's local files."""
        super(MemoryStorage, self).persist()
        logger.debug(f'Agent {self.role_id} persist memory into local')

    def add(self, message: Message) -> bool:
        """Add *message* into memory storage and persist it; returns True."""
        docs = [message.content]
        metadatas = [{"message_ser": serialize_message(message)}]
        if not self.store:
            # first message: init Faiss
            self.store = self._write(docs, metadatas)
            self._initialized = True
        else:
            self.store.add_texts(texts=docs, metadatas=metadatas)
        self.persist()
        logger.info(f"Agent {self.role_id}'s memory_storage add a message")
        return True  # honor the declared `-> bool` (previously returned None implicitly)

    def search(self, message: Message, k=4) -> List[Message]:
        """Search for dissimilar stored messages.

        Nearest-neighbour matches whose distance score is below
        `self.threshold` (i.e. the most similar ones) are dropped.
        """
        if not self.store:
            return []

        resp = self.store.similarity_search_with_score(
            query=message.content,
            k=k
        )
        # filter the result which score is smaller than the threshold
        filtered_resp = []
        for item, score in resp:
            # the smaller score means more similar relation
            if score < self.threshold:
                continue
            # convert search result into Memory
            metadata = item.metadata
            new_mem = deserialize_message(metadata.get("message_ser"))
            filtered_resp.append(new_mem)
        return filtered_resp

    def clean(self):
        """Delete the persisted files and reset the in-memory state."""
        index_fpath, storage_fpath = self._get_index_and_store_fname()
        if index_fpath and index_fpath.exists():
            index_fpath.unlink(missing_ok=True)
        if storage_fpath and storage_fpath.exists():
            storage_fpath.unlink(missing_ok=True)

        self.store = None
        self._initialized = False
20 changes: 12 additions & 8 deletions metagpt/roles/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from pydantic import BaseModel, Field

# from metagpt.environment import Environment
from metagpt.config import CONFIG
from metagpt.actions import Action, ActionOutput
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.memory import Memory, LongTermMemory
from metagpt.schema import Message

PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
Expand Down Expand Up @@ -65,13 +66,19 @@ class RoleContext(BaseModel):
"""角色运行时上下文"""
env: 'Environment' = Field(default=None)
memory: Memory = Field(default_factory=Memory)
long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
state: int = Field(default=0)
todo: Action = Field(default=None)
watch: set[Type[Action]] = Field(default_factory=set)

class Config:
arbitrary_types_allowed = True

def check(self, role_id: str):
if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
self.long_term_memory.recover_memory(role_id, self)
self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation

@property
def important_memory(self) -> list[Message]:
"""获得关注动作对应的信息"""
Expand All @@ -90,6 +97,7 @@ def __init__(self, name="", profile="", goal="", constraints="", desc=""):
self._setting = RoleSetting(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc)
self._states = []
self._actions = []
self._role_id = str(self._setting)
self._rc = RoleContext()

def _reset(self):
Expand All @@ -110,6 +118,8 @@ def _init_actions(self, actions):
def _watch(self, actions: Iterable[Type[Action]]):
"""监听对应的行为"""
self._rc.watch.update(actions)
# check RoleContext after adding watch actions
self._rc.check(self._role_id)

def _set_state(self, state):
"""Update the current state."""
Expand Down Expand Up @@ -174,13 +184,7 @@ async def _observe(self) -> int:

observed = self._rc.env.memory.get_by_actions(self._rc.watch)

already_observed = self._rc.memory.get()

news: list[Message] = []
for i in observed:
if i in already_observed:
continue
news.append(i)
news = self._rc.memory.remember(observed) # remember recent exact or similar memories

for i in env_msgs:
self.recv(i)
Expand Down
75 changes: 75 additions & 0 deletions metagpt/utils/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the implement of serialization and deserialization

import copy
from typing import Tuple, List, Type, Union, Dict
import pickle
from collections import defaultdict
from pydantic import create_model

from metagpt.schema import Message
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput


def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
    """Convert the first-level `properties` of a pydantic JSON schema into a
    `create_model`-style field mapping.

    Supported property types:
      - string          -> (str, ...)
      - array of string -> (List[str], ...)
      - array of array  -> (List[Tuple[str, str]], ...)  # only Tuple[str, str] considered

    A schema looks like
    ```
    {
        "title":"prd",
        "type":"object",
        "properties":{
            "Original Requirements":{
                "title":"Original Requirements",
                "type":"string"
            },
        },
        "required":[
            "Original Requirements",
        ]
    }
    ```
    """
    mapping = dict()
    for name, prop in schema['properties'].items():
        if prop['type'] == 'string':
            mapping[name] = (str, ...)
        elif prop['type'] == 'array':
            item_type = prop['items']['type']
            if item_type == 'string':
                mapping[name] = (List[str], ...)
            elif item_type == 'array':
                # here only consider the `Tuple[str, str]` situation
                mapping[name] = (List[Tuple[str, str]], ...)
    return mapping


def serialize_message(message: Message):
    """Pickle *message*, first replacing its `instruct_content` — a model built
    dynamically with pydantic `create_model`, which pickle cannot handle — by a
    plain dict of {class, mapping, value} that `deserialize_message` can rebuild."""
    msg = copy.deepcopy(message)  # deep copy so the caller's `instruct_content` is left untouched
    if msg.instruct_content:
        ic = msg.instruct_content
        schema = ic.schema()
        msg.instruct_content = {
            'class': schema['title'],
            'mapping': actionoutout_schema_to_mapping(schema),
            'value': ic.dict()
        }
    return pickle.dumps(msg)


def deserialize_message(message_ser: bytes) -> Message:
    """Unpickle a message produced by `serialize_message`.

    Note: the parameter is the raw pickle payload, i.e. `bytes`
    (the previous `str` annotation was incorrect — `pickle.dumps` returns bytes).
    If an `instruct_content` dict is present, its dynamic pydantic model class
    is recreated from the stored class name and field mapping.
    """
    message = pickle.loads(message_ser)
    if message.instruct_content:
        ic = message.instruct_content
        # rebuild the `create_model` class that could not be pickled directly
        ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
                                                 mapping=ic['mapping'])
        ic_new = ic_obj(**ic['value'])
        message.instruct_content = ic_new

    return message
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ tiktoken==0.3.3
tqdm==4.64.0
#unstructured[local-inference]
anthropic==0.3.6
typing-inspect==0.8.0
typing_extensions==4.5.0
3 changes: 3 additions & 0 deletions tests/metagpt/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
Loading

0 comments on commit c37613c

Please sign in to comment.