Merge pull request geekan#1058 from iorisa/fixbug/issues/1016
fixbug: geekan#1016
geekan authored Mar 21, 2024
2 parents dd348d0 + f6a11d5 commit 91053f0
Showing 6 changed files with 57 additions and 34 deletions.
4 changes: 2 additions & 2 deletions metagpt/actions/di/write_analysis_code.py
@@ -18,7 +18,7 @@
    STRUCTUAL_PROMPT,
)
from metagpt.schema import Message, Plan
-from metagpt.utils.common import CodeParser, process_message, remove_comments
+from metagpt.utils.common import CodeParser, remove_comments


class WriteAnalysisCode(Action):
@@ -50,7 +50,7 @@ async def run(
        )

        working_memory = working_memory or []
-        context = process_message([Message(content=structual_prompt, role="user")] + working_memory)
+        context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory)

        # LLM call
        if use_reflection:
22 changes: 22 additions & 0 deletions metagpt/provider/base_llm.py
@@ -73,6 +73,28 @@ def _assistant_msg(self, msg: str) -> dict[str, str]:
    def _system_msg(self, msg: str) -> dict[str, str]:
        return {"role": "system", "content": msg}

+    def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
+        """Convert messages to list[dict]."""
+        from metagpt.schema import Message
+
+        if not isinstance(messages, list):
+            messages = [messages]
+
+        processed_messages = []
+        for msg in messages:
+            if isinstance(msg, str):
+                processed_messages.append({"role": "user", "content": msg})
+            elif isinstance(msg, dict):
+                assert set(msg.keys()) == set(["role", "content"])
+                processed_messages.append(msg)
+            elif isinstance(msg, Message):
+                processed_messages.append(msg.to_dict())
+            else:
+                raise ValueError(
+                    f"Only supported message types are: str, Message, dict, but got {type(msg).__name__}!"
+                )
+        return processed_messages

    def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]:
        return [self._system_msg(msg) for msg in msgs]
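
The new BaseLLM.format_msg gives every provider one shared normalizer that accepts a bare string, an OpenAI-style dict, a Message, or a list mixing all three; and because call sites now go through self (or self.llm), a provider's own override is used. A minimal usage sketch, assuming an OpenAI-compatible LLM instance named `llm` (the variable name is illustrative):

    from metagpt.schema import Message

    context = llm.format_msg(
        [
            "analyze the dataset",                        # str -> {"role": "user", "content": ...}
            {"role": "system", "content": "be concise"},  # dict passes through (keys are asserted)
            Message(content="done", role="assistant"),    # Message -> msg.to_dict()
        ]
    )
    # [{'role': 'user', 'content': 'analyze the dataset'},
    #  {'role': 'system', 'content': 'be concise'},
    #  {'role': 'assistant', 'content': 'done'}]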

30 changes: 30 additions & 0 deletions metagpt/provider/google_gemini_api.py
@@ -18,6 +18,7 @@
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.schema import Message


class GeminiGenerativeModel(GenerativeModel):
@@ -61,6 +62,35 @@ def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) ->
    def _assistant_msg(self, msg: str) -> dict[str, str]:
        return {"role": "model", "parts": [msg]}

+    def _system_msg(self, msg: str) -> dict[str, str]:
+        return {"role": "user", "parts": [msg]}
+
+    def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
+        """Convert messages to list[dict]."""
+        from metagpt.schema import Message
+
+        if not isinstance(messages, list):
+            messages = [messages]
+
+        # REF: https://ai.google.dev/tutorials/python_quickstart
+        # As a dictionary, a message requires `role` and `parts` keys.
+        # The role in a conversation can be either `user`, which provides the prompts,
+        # or `model`, which provides the responses.
+        processed_messages = []
+        for msg in messages:
+            if isinstance(msg, str):
+                processed_messages.append({"role": "user", "parts": [msg]})
+            elif isinstance(msg, dict):
+                assert set(msg.keys()) == set(["role", "parts"])
+                processed_messages.append(msg)
+            elif isinstance(msg, Message):
+                processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]})
+            else:
+                raise ValueError(
+                    f"Only supported message types are: str, Message, dict, but got {type(msg).__name__}!"
+                )
+        return processed_messages

    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
        kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
        return kwargs
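
This override exists because Gemini's chat API takes `role`/`parts` entries with roles `user` and `model`, not OpenAI-style `role`/`content`. A sketch of the resulting difference, assuming a Gemini-backed LLM instance named `gemini` (the name is illustrative):

    from metagpt.schema import Message

    msgs = ["hello", Message(content="hi there", role="assistant")]

    gemini.format_msg(msgs)
    # Gemini override: non-user roles map to "model"
    # [{'role': 'user', 'parts': ['hello']}, {'role': 'model', 'parts': ['hi there']}]

    # The base-class version would instead produce OpenAI-style dicts:
    # [{'role': 'user', 'content': 'hello'}, {'role': 'assistant', 'content': 'hi there'}]

The _const_kwargs above then passes this list straight through as the `contents` argument of the Gemini SDK call.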
9 changes: 2 additions & 7 deletions metagpt/provider/openai_api.py
@@ -29,12 +29,7 @@
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.utils.common import (
-    CodeParser,
-    decode_image,
-    log_and_reraise,
-    process_message,
-)
+from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
@@ -150,7 +145,7 @@ async def acompletion_text(self, messages: list[dict], stream=False, timeout=3)
    async def _achat_completion_function(
        self, messages: list[dict], timeout: int = 3, **chat_configs
    ) -> ChatCompletion:
-        messages = process_message(messages)
+        messages = self.format_msg(messages)
        kwargs = self._cons_kwargs(messages=messages, timeout=timeout, **chat_configs)
        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
        self._update_costs(rsp.usage)
23 changes: 0 additions & 23 deletions metagpt/utils/common.py
@@ -802,29 +802,6 @@ def decode_image(img_url_or_b64: str) -> Image:
    return img


-def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
-    """convert messages to list[dict]."""
-    from metagpt.schema import Message
-
-    # convert everything to a list
-    if not isinstance(messages, list):
-        messages = [messages]
-
-    # convert to list[dict]
-    processed_messages = []
-    for msg in messages:
-        if isinstance(msg, str):
-            processed_messages.append({"role": "user", "content": msg})
-        elif isinstance(msg, dict):
-            assert set(msg.keys()) == set(["role", "content"])
-            processed_messages.append(msg)
-        elif isinstance(msg, Message):
-            processed_messages.append(msg.to_dict())
-        else:
-            raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
-    return processed_messages


def log_and_reraise(retry_state: RetryCallState):
    logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
    logger.warning(
3 changes: 1 addition & 2 deletions tests/mock/mock_llm.py
@@ -8,7 +8,6 @@
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message
-from metagpt.utils.common import process_message

OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM

@@ -105,7 +104,7 @@ async def aask_batch(self, msgs: list, timeout=3) -> str:
        return rsp

    async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
-        msg_key = json.dumps(process_message(messages), ensure_ascii=False)
+        msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False)
        rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs)
        return rsp
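
Since format_msg maps equivalent inputs to the same list[dict], the JSON dump above serves as a stable cache key: a str, dict, or Message spelling of the same conversation should all hit the same recorded response. A sketch of the idea (make_cache_key is a hypothetical helper, not part of this diff):

    import json

    def make_cache_key(llm, messages) -> str:
        # Normalize first, then serialize; key order inside each dict is
        # stable, so identical conversations serialize identically.
        return json.dumps(llm.format_msg(messages), ensure_ascii=False)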

