forked from geekan/MetaGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request geekan#70 from better629/feat_ltmem
add memory_storage using ann to avoid similar idea repetitive execution
- Loading branch information
Showing
16 changed files
with
526 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,3 +7,5 @@ | |
""" | ||
|
||
from metagpt.memory.memory import Memory | ||
from metagpt.memory.longterm_memory import LongTermMemory | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# @Desc : the implementation of Long-term memory | ||
|
||
from typing import Iterable, Type | ||
|
||
from metagpt.logs import logger | ||
from metagpt.schema import Message | ||
from metagpt.memory import Memory | ||
from metagpt.memory.memory_storage import MemoryStorage | ||
|
||
|
||
class LongTermMemory(Memory):
    """
    The Long-term memory for Roles
    - recover memory when it starts up
    - update memory when it changes

    Messages are kept in the parent `Memory` and, for the actions the role is
    watching, are additionally written to a `MemoryStorage`.
    """

    def __init__(self):
        """Create an empty long-term memory that is not yet bound to a role."""
        self.memory_storage: MemoryStorage = MemoryStorage()
        super().__init__()
        self.rc = None  # RoleContext, set by `recover_memory`
        # True only while `recover_memory` is replaying persisted messages, so
        # that `add` does not write them back into `memory_storage` again.
        self.msg_from_recover = False

    def recover_memory(self, role_id: str, rc: "RoleContext"):
        """
        Bind this memory to a role and reload what `memory_storage` recovered.

        :param role_id: identifier handed to `MemoryStorage.recover_memory`
        :param rc: the role's context; `add` reads its `watch` collection
        """
        messages = self.memory_storage.recover_memory(role_id)
        self.rc = rc
        if not self.memory_storage.is_initialized:
            logger.warning(f'It may the first time to run Agent {role_id}, the long-term memory is empty')
        else:
            logger.warning(f'Agent {role_id} has existed memory storage with {len(messages)} messages '
                           f'and has recovered them.')
        self.msg_from_recover = True
        try:
            self.add_batch(messages)
        finally:
            # always reset the flag, otherwise a failure during the replay
            # would silently disable persisting of every later message
            self.msg_from_recover = False

    def add(self, message: Message):
        """
        Add `message` to the memory and, when it was caused by one of the
        role's watched actions, to `memory_storage` as well.

        Nothing is written to `memory_storage` while messages are being
        replayed by `recover_memory`, or before a role context has been set.
        """
        super().add(message)
        if self.msg_from_recover or self.rc is None:
            # replayed messages are already stored; without a role context
            # there is no watch list to decide what is worth storing
            return
        # currently, only add role's watching messages to its memory_storage,
        # and store the message at most once
        if any(message.cause_by == action for action in self.rc.watch):
            self.memory_storage.add(message)

    def remember(self, observed: list[Message], k=10) -> list[Message]:
        """
        Return at most the last k news from the observed Messages, return all when k=0
            1. remember the short-term memory(stm) news
            2. integrate the stm news with ltm(long-term memory) news

        When `memory_storage` has not been initialized the stm news are
        returned unchanged (k is not applied in that case).
        """
        stm_news = super().remember(observed)  # short-term memory news
        if not self.memory_storage.is_initialized:
            # memory_storage hasn't initialized, use default `remember` to get stm_news
            return stm_news

        # integrate stm & ltm: keep a message when the storage search returns
        # at least one hit for it.
        # NOTE(review): `MemoryStorage.search` returns the *dissimilar* stored
        # messages, so a near-duplicate is only dropped when every neighbour is
        # below the threshold -- confirm this is the intended filter.
        ltm_news: list[Message] = [mem for mem in stm_news if self.memory_storage.search(mem)]
        # `-0` slices from the start, which is what makes k=0 return everything
        return ltm_news[-k:]

    def delete(self, message: Message):
        """Delete `message` from the memory (not yet from `memory_storage`)."""
        super().delete(message)
        # TODO delete message in memory_storage

    def clear(self):
        """Clear the memory and clean `memory_storage`."""
        super().clear()
        self.memory_storage.clean()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# @Desc : the implementation of memory storage | ||
|
||
from typing import List | ||
from pathlib import Path | ||
|
||
from langchain.vectorstores.faiss import FAISS | ||
|
||
from metagpt.const import DATA_PATH, MEM_TTL | ||
from metagpt.logs import logger | ||
from metagpt.schema import Message | ||
from metagpt.utils.serialize import serialize_message, deserialize_message | ||
from metagpt.document_store.faiss_store import FaissStore | ||
|
||
|
||
class MemoryStorage(FaissStore):
    """
    The memory storage with Faiss as ANN search engine

    An instance is bound to one role by `recover_memory`, which also decides
    where the index and pickle files of that role live.
    """

    def __init__(self, mem_ttl: int = MEM_TTL):
        """
        Create a storage that is not yet bound to a role.

        :param mem_ttl: stored for later use; not read anywhere in this class
        """
        # NOTE(review): FaissStore.__init__ is not called here -- presumably
        # because the store is created lazily per role; confirm.
        self.role_id: str = None  # set by `recover_memory`
        self.role_mem_path: str = None  # set to a Path by `recover_memory`
        self.mem_ttl: int = mem_ttl  # later use
        self.threshold: float = 0.1  # experience value. TODO The threshold to filter similar memories
        self._initialized: bool = False

        self.store: FAISS = None  # Faiss engine

    @property
    def is_initialized(self) -> bool:
        """Whether a Faiss store has been loaded or created for the role."""
        return self._initialized

    def recover_memory(self, role_id: str) -> List[Message]:
        """
        Bind the storage to `role_id` and load its persisted messages.

        The role's directory under `DATA_PATH/role_mem/` is created when
        missing.

        :return: the deserialized messages of the loaded store, or an empty
                 list when `_load` yields no store
        """
        self.role_id = role_id
        self.role_mem_path = DATA_PATH / f'role_mem/{self.role_id}/'
        self.role_mem_path.mkdir(parents=True, exist_ok=True)

        self.store = self._load()
        messages = []
        if not self.store:
            # TODO init `self.store` under here with raw faiss api instead under `add`
            pass
        else:
            # NOTE(review): relies on the private `_dict` of the docstore
            for document in self.store.docstore._dict.values():
                messages.append(deserialize_message(document.metadata.get("message_ser")))
            self._initialized = True

        return messages

    def _get_index_and_store_fname(self):
        """
        Return the `(index, pickle)` file paths of the bound role.

        Logs an error and returns `(None, None)` when `recover_memory` has not
        been called yet.
        """
        if not self.role_mem_path:
            logger.error(f'You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory')
            return None, None
        index_fpath = Path(self.role_mem_path / f'{self.role_id}.index')
        storage_fpath = Path(self.role_mem_path / f'{self.role_id}.pkl')
        return index_fpath, storage_fpath

    def persist(self):
        """Persist the store through `FaissStore.persist` and log it."""
        super(MemoryStorage, self).persist()
        logger.debug(f'Agent {self.role_id} persist memory into local')

    def add(self, message: Message) -> bool:
        """
        Add message into memory storage and persist it.

        The Faiss store is created from the first message when none exists.

        :return: True once the message has been added and persisted
        """
        docs = [message.content]
        # the whole message travels as metadata so it can be rebuilt on search/recover
        metadatas = [{"message_ser": serialize_message(message)}]
        if not self.store:
            # init Faiss
            self.store = self._write(docs, metadatas)
            self._initialized = True
        else:
            self.store.add_texts(texts=docs, metadatas=metadatas)
        self.persist()
        logger.info(f"Agent {self.role_id}'s memory_storage add a message")
        return True

    def search(self, message: Message, k=4) -> List[Message]:
        """
        Search for dissimilar messages.

        :param message: its `content` is used as the query text
        :param k: number of neighbours requested from the store
        :return: the neighbours whose score is not below `self.threshold`,
                 deserialized; an empty list when there is no store
        """
        if not self.store:
            return []

        resp = self.store.similarity_search_with_score(
            query=message.content,
            k=k
        )
        # filter the result which score is smaller than the threshold
        filtered_resp = []
        for item, score in resp:
            # the smaller score means more similar relation
            if score < self.threshold:
                continue
            # convert search result into Memory
            metadata = item.metadata
            new_mem = deserialize_message(metadata.get("message_ser"))
            filtered_resp.append(new_mem)
        return filtered_resp

    def clean(self):
        """Delete the role's index and pickle files and drop the in-memory store."""
        index_fpath, storage_fpath = self._get_index_and_store_fname()
        if index_fpath and index_fpath.exists():
            index_fpath.unlink(missing_ok=True)
        if storage_fpath and storage_fpath.exists():
            storage_fpath.unlink(missing_ok=True)

        self.store = None
        self._initialized = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# @Desc : the implementation of serialization and deserialization | ||
|
||
import copy | ||
from typing import Tuple, List, Type, Union, Dict | ||
import pickle | ||
from collections import defaultdict | ||
from pydantic import create_model | ||
|
||
from metagpt.schema import Message | ||
from metagpt.actions.action import Action | ||
from metagpt.actions.action_output import ActionOutput | ||
|
||
|
||
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
    """
    Build a `{field: (type, ...)}` mapping from a JSON-schema dict.

    Directly traverse the `properties` in the first level.
    schema structure likes
    ```
    {
        "title":"prd",
        "type":"object",
        "properties":{
            "Original Requirements":{
                "title":"Original Requirements",
                "type":"string"
            },
        },
        "required":[
            "Original Requirements",
        ]
    }
    ```

    Supported property shapes are `string`, array of `string` and array of
    `array` (treated as `List[Tuple[str, str]]`). Every other property --
    including one without a `type` key, such as a `$ref` -- is left out of
    the mapping instead of raising.

    :param schema: the schema dict; a missing `properties` key yields `{}`
    :return: mapping of field name to `(python_type, Ellipsis)`
    """
    mapping = dict()
    for field, prop in schema.get('properties', {}).items():
        prop_type = prop.get('type')
        if prop_type == 'string':
            mapping[field] = (str, ...)
        elif prop_type == 'array':
            items = prop.get('items')
            # `items` may be absent or a list (tuple-style schema); only a
            # dict carries the single element type inspected here
            item_type = items.get('type') if isinstance(items, dict) else None
            if item_type == 'string':
                mapping[field] = (List[str], ...)
            elif item_type == 'array':
                # here only consider the `Tuple[str, str]` situation
                mapping[field] = (List[Tuple[str, str]], ...)
    return mapping
|
||
|
||
def serialize_message(message: Message):
    """
    Pickle a copy of `message`, returning the pickled bytes.

    A non-empty `instruct_content` is first replaced on the copy by a plain
    dict with the keys `class`, `mapping` and `value`, because a model created
    by pydantic `create_model` (like `pydantic.main.prd`) can't be pickled
    directly. The passed-in message itself is left untouched.
    """
    # work on a deep copy so the caller's `instruct_content` is not replaced by reference
    snapshot = copy.deepcopy(message)
    instruct = snapshot.instruct_content
    if not instruct:
        return pickle.dumps(snapshot)

    model_schema = instruct.schema()
    snapshot.instruct_content = {
        'class': model_schema['title'],
        'mapping': actionoutout_schema_to_mapping(model_schema),
        'value': instruct.dict()
    }
    return pickle.dumps(snapshot)
|
||
|
||
def deserialize_message(message_ser: bytes) -> Message:
    """
    Rebuild a Message from the bytes produced by `serialize_message`.

    When the unpickled message carries an `instruct_content` dict (keys
    `class`, `mapping`, `value`), the model class is recreated through
    `ActionOutput.create_model_class` and instantiated with the stored values.

    :param message_ser: pickled message bytes
    :return: the message with `instruct_content` restored to a model instance
    """
    # SECURITY: pickle.loads can execute arbitrary code; only feed it data
    # this application wrote itself, never bytes from an untrusted source.
    message = pickle.loads(message_ser)
    ic = message.instruct_content
    if ic:
        ic_obj = ActionOutput.create_model_class(class_name=ic['class'],
                                                 mapping=ic['mapping'])
        message.instruct_content = ic_obj(**ic['value'])

    return message
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# @Desc : |
Oops, something went wrong.