Plugins/fix backend ci errors #12615

Merged · 14 commits · Jan 10, 2025
@@ -107,11 +107,46 @@ def invoke(
content_list = []
usage = LLMUsage.empty_usage()
system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []

def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
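# Merge a batch of streamed tool-call deltas into the accumulated `tools_calls` list, keyed by function name.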
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]

tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)

return tool_call

for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments

for chunk in result:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls:
increase_tool_call(chunk.delta.message.tool_calls)

usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint
@@ -120,7 +155,10 @@ def invoke(
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
- message=AssistantPromptMessage(content=content or content_list),
+ message=AssistantPromptMessage(
+     content=content or content_list,
+     tool_calls=tools_calls,
+ ),
usage=usage,
system_fingerprint=system_fingerprint,
)
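For readers unfamiliar with streamed tool calls, here is a minimal, self-contained sketch of the accumulation behaviour the new `increase_tool_call` helper implements. The dataclasses below only stand in for the shape of `AssistantPromptMessage.ToolCall`, and the values are invented for illustration; this is not code from the PR.

```python
# Illustrative stand-ins mirroring the shape of AssistantPromptMessage.ToolCall;
# the real entities live in core.model_runtime.entities.message_entities.
from dataclasses import dataclass, field


@dataclass
class ToolCallFunction:
    name: str = ""
    arguments: str = ""


@dataclass
class ToolCall:
    id: str = ""
    type: str = ""
    function: ToolCallFunction = field(default_factory=ToolCallFunction)


def merge_tool_call_deltas(batches: list[list[ToolCall]]) -> list[ToolCall]:
    merged: list[ToolCall] = []
    for batch in batches:
        for new in batch:
            if new.function.name:
                current = next((c for c in merged if c.function.name == new.function.name), None)
            else:
                current = merged[-1]  # a nameless delta extends the most recent call
            if current is None:
                current = ToolCall(function=ToolCallFunction(name=new.function.name))
                merged.append(current)
            if new.id:
                current.id = new.id
            if new.type:
                current.type = new.type
            if new.function.arguments:
                current.function.arguments += new.function.arguments
    return merged


# Arguments for one tool call arrive split across two stream chunks:
calls = merge_tool_call_deltas([
    [ToolCall(id="call_1", type="function",
              function=ToolCallFunction(name="get_weather", arguments='{"city": '))],
    [ToolCall(function=ToolCallFunction(name="get_weather", arguments='"Paris"}'))],
])
print(calls[0].function.arguments)  # {"city": "Paris"}
```

The point of the merge is that argument fragments belonging to the same function name are concatenated across chunks, so the final `arguments` string is complete JSON once the stream ends.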
@@ -48,6 +48,6 @@ def timezone_convert(current_time: str, source_timezone: str, target_timezone: s
datetime_with_tz = input_timezone.localize(local_time)
# timezone convert
converted_datetime = datetime_with_tz.astimezone(output_timezone)
- return converted_datetime.strftime(format=time_format)
+ return converted_datetime.strftime(format=time_format)  # type: ignore
except Exception as e:
raise ToolInvokeError(str(e))
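The added `# type: ignore` only suppresses a type-checker error on the return line; runtime behaviour is unchanged. For reference, a standalone sketch of the pytz localize/astimezone pattern the tool relies on — the timestamps and zone names below are illustrative, not taken from the tool:

```python
# Standalone sketch of the same conversion pattern (illustrative values):
from datetime import datetime

import pytz

naive = datetime.strptime("2025-01-10 08:00:00", "%Y-%m-%d %H:%M:%S")
source = pytz.timezone("UTC")
target = pytz.timezone("Asia/Shanghai")

aware = source.localize(naive)        # attach the source timezone to the naive datetime
converted = aware.astimezone(target)  # shift into the target timezone
print(converted.strftime("%Y-%m-%d %H:%M:%S %Z"))  # 2025-01-10 16:00:00 CST
```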
@@ -5,4 +5,7 @@

class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
Validate credentials
"""
pass
api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py (44 additions, 0 deletions)
@@ -0,0 +1,44 @@
import os
from collections.abc import Callable

import pytest

# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch

from core.plugin.manager.model import PluginModelManager
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass


def mock_plugin_daemon(
monkeypatch: MonkeyPatch,
) -> Callable[[], None]:
"""
Mock the plugin daemon's PluginModelManager by patching its methods with MockModelClass

:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""

def unpatch() -> None:
monkeypatch.undo()

monkeypatch.setattr(PluginModelManager, "invoke_llm", MockModelClass.invoke_llm)
monkeypatch.setattr(PluginModelManager, "fetch_model_providers", MockModelClass.fetch_model_providers)
monkeypatch.setattr(PluginModelManager, "get_model_schema", MockModelClass.get_model_schema)

return unpatch


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_model_mock(monkeypatch):
if MOCK:
unpatch = mock_plugin_daemon(monkeypatch)

yield

if MOCK:
unpatch()
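How a test opts into this mock is not shown in the diff. Below is a hedged sketch of the intended usage, assuming the fixture is imported into the test module (or exposed via a conftest) and that `PluginModelManager` can be constructed without arguments here — both assumptions for illustration, not part of this PR:

```python
# Hypothetical test module; run with:  MOCK_SWITCH=true pytest -k mocked
from core.plugin.manager.model import PluginModelManager
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock  # noqa: F401


def test_fetch_model_providers_mocked(setup_model_mock):
    # With MOCK_SWITCH=true the patched method returns the canned "openai"
    # provider from MockModelClass instead of calling a real plugin daemon.
    providers = PluginModelManager().fetch_model_providers(tenant_id="test-tenant")
    assert providers[0].provider == "openai"
```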
api/tests/integration_tests/model_runtime/__mock/plugin_model.py (249 additions, 0 deletions)
@@ -0,0 +1,249 @@
import datetime
import uuid
from collections.abc import Generator, Sequence
from decimal import Decimal
from json import dumps

from typing import Optional

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
)
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager


class MockModelClass(PluginModelManager):
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
"""
Fetch model providers for the given tenant.
"""
return [
PluginModelProviderEntity(
id=uuid.uuid4().hex,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
provider="openai",
tenant_id=tenant_id,
plugin_unique_identifier="langgenius/openai/openai",
plugin_id="langgenius/openai",
declaration=ProviderEntity(
provider="openai",
label=I18nObject(
en_US="OpenAI",
zh_Hans="OpenAI",
),
description=I18nObject(
en_US="OpenAI",
zh_Hans="OpenAI",
),
icon_small=I18nObject(
en_US="https://example.com/icon_small.png",
zh_Hans="https://example.com/icon_small.png",
),
icon_large=I18nObject(
en_US="https://example.com/icon_large.png",
zh_Hans="https://example.com/icon_large.png",
),
supported_model_types=[ModelType.LLM],
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL],
models=[
AIModelEntity(
model="gpt-3.5-turbo",
label=I18nObject(
en_US="gpt-3.5-turbo",
zh_Hans="gpt-3.5-turbo",
),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL],
),
AIModelEntity(
model="gpt-3.5-turbo-instruct",
label=I18nObject(
en_US="gpt-3.5-turbo-instruct",
zh_Hans="gpt-3.5-turbo-instruct",
),
model_type=ModelType.LLM,
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.COMPLETION,
},
features=[],
),
],
),
)
]

def get_model_schema(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model_type: str,
model: str,
credentials: dict,
) -> AIModelEntity | None:
"""
Get model schema
"""
return AIModelEntity(
model=model,
label=I18nObject(
en_US="OpenAI",
zh_Hans="OpenAI",
),
model_type=ModelType(model_type),
fetch_from=FetchFrom.PREDEFINED_MODEL,
model_properties={},
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL] if model == "gpt-3.5-turbo" else [],
)

@staticmethod
def generate_function_call(
tools: Optional[list[PromptMessageTool]],
) -> Optional[AssistantPromptMessage.ToolCall]:
if not tools or len(tools) == 0:
return None
function: PromptMessageTool = tools[0]
function_name = function.name
function_parameters = function.parameters
function_parameters_type = function_parameters["type"]
if function_parameters_type != "object":
return None
function_parameters_properties = function_parameters["properties"]
function_parameters_required = function_parameters["required"]
parameters = {}
for parameter_name, parameter in function_parameters_properties.items():
if parameter_name not in function_parameters_required:
continue
parameter_type = parameter["type"]
if parameter_type == "string":
if "enum" in parameter:
if len(parameter["enum"]) == 0:
continue
parameters[parameter_name] = parameter["enum"][0]
else:
parameters[parameter_name] = "kawaii"
elif parameter_type == "integer":
parameters[parameter_name] = 114514
elif parameter_type == "number":
parameters[parameter_name] = 1919810.0
elif parameter_type == "boolean":
parameters[parameter_name] = True

return AssistantPromptMessage.ToolCall(
id=str(uuid.uuid4()),
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=function_name,
arguments=dumps(parameters),
),
)

@staticmethod
def mocked_chat_create_sync(
model: str,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> LLMResult:
tool_call = MockModelClass.generate_function_call(tools=tools)

return LLMResult(
id=str(uuid.uuid4()),
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content="elaina", tool_calls=[tool_call] if tool_call else []),
usage=LLMUsage(
prompt_tokens=2,
completion_tokens=1,
total_tokens=3,
prompt_unit_price=Decimal(0.0001),
completion_unit_price=Decimal(0.0002),
prompt_price_unit=Decimal(1),
prompt_price=Decimal(0.0001),
completion_price_unit=Decimal(1),
completion_price=Decimal(0.0002),
total_price=Decimal(0.0003),
currency="USD",
latency=0.001,
),
)

@staticmethod
def mocked_chat_create_stream(
model: str,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
) -> Generator[LLMResultChunk, None, None]:
tool_call = MockModelClass.generate_function_call(tools=tools)

full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
for i in range(0, len(full_text) + 1):
if i == len(full_text):
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content="",
tool_calls=[tool_call] if tool_call else [],
),
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=full_text[i],
tool_calls=[tool_call] if tool_call else [],
),
usage=LLMUsage(
prompt_tokens=2,
completion_tokens=17,
total_tokens=19,
prompt_unit_price=Decimal(0.0001),
completion_unit_price=Decimal(0.0002),
prompt_price_unit=Decimal(1),
prompt_price=Decimal(0.0001),
completion_price_unit=Decimal(1),
completion_price=Decimal(0.0002),
total_price=Decimal(0.0003),
currency="USD",
latency=0.001,
),
),
)

def invoke_llm(
self: PluginModelManager,
*,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
):
return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)
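To see how the mock behaves end to end, one can drive the stream directly. The snippet below is not part of the PR: the tool definition is illustrative and its field names are assumed from the imported `PromptMessageTool` entity.

```python
# Illustrative: consume the mocked stream with a single string-typed tool.
from core.model_runtime.entities.message_entities import PromptMessageTool, UserPromptMessage

weather_tool = PromptMessageTool(
    name="get_weather",
    description="Look up the weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

chunks = MockModelClass.mocked_chat_create_stream(
    model="gpt-3.5-turbo",
    prompt_messages=[UserPromptMessage(content="What's the weather in Paris?")],
    tools=[weather_tool],
)

text = "".join(
    chunk.delta.message.content
    for chunk in chunks
    if isinstance(chunk.delta.message.content, str)
)
# `text` re-assembles the full_text defined above, one character per chunk.
# Each chunk also carries the synthesized tool call, whose arguments are
# '{"city": "kawaii"}' given the string parameter without an enum.
```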