diff --git a/examples/demo_user_feedback.py b/examples/demo_user_feedback.py index 1186553..0fe7f8f 100644 --- a/examples/demo_user_feedback.py +++ b/examples/demo_user_feedback.py @@ -25,6 +25,8 @@ from meadow.client.api.anthropic import AnthropicClient from meadow.client.api.api_client import APIClient from meadow.client.api.openai import OpenAIClient +from meadow.client.api.samba import SambaClient +from meadow.client.api.together import TogetherClient from meadow.client.schema import LLMConfig from meadow.database.connector.connector import Connector from meadow.database.connector.duckdb import DuckDBConnector @@ -230,6 +232,10 @@ def run_meadow( api_client: APIClient = AnthropicClient(api_key=api_key) elif api_provider == "openai": api_client = OpenAIClient(api_key=api_key) + elif api_provider == "together": + api_client = TogetherClient(api_key=api_key) + elif api_provider == "samba": + api_client = SambaClient(api_key=api_key) else: raise ValueError(f"Unknown API provider: {api_provider}") client = Client( @@ -237,7 +243,7 @@ def run_meadow( api_client=api_client, model=model, ) - big_client = Client(cache=cache, api_client=api_client, model="gpt-4o") + big_client = Client(cache=cache, api_client=api_client, model=model) llm_config = LLMConfig( max_tokens=1000, temperature=0.0, @@ -288,7 +294,7 @@ def run_meadow( argparser.add_argument( "--model", type=str, - help="Anthropic model.", + help="Model.", default="gpt-4o", ) argparser.add_argument( diff --git a/meadow/agent/planner.py b/meadow/agent/planner.py index 598da22..19d99b2 100644 --- a/meadow/agent/planner.py +++ b/meadow/agent/planner.py @@ -133,6 +133,8 @@ def parse_plan( message += ">" if "<steps>" not in message: error_message = "Plan not found in the response. Please provide a plan in the format <steps>...</steps>." 
+ if "" not in message: + message += "" if not error_message: inner_steps = re.search(r"(.*)", message, re.DOTALL).group(1) parsed_steps = parse_steps(inner_steps) diff --git a/meadow/client/api/samba.py b/meadow/client/api/samba.py new file mode 100644 index 0000000..c109281 --- /dev/null +++ b/meadow/client/api/samba.py @@ -0,0 +1,128 @@ +import json +import logging +import os +import uuid +from typing import Any + +import requests + +from meadow.client.api.api_client import APIClient +from meadow.client.schema import ( + ChatMessage, + ChatRequest, + ChatResponse, + Choice, + Usage, +) + +logger = logging.getLogger(__name__) + + +def send_request(api_key: str, url: str, data: dict[str, Any]) -> dict: + """ + Sends a POST request to the specified URL with headers and JSON data. + """ + + # Headers including the Authorization and Content-Type + headers = {"Authorization": f"Basic {api_key}", "Content-Type": "application/json"} + + # Sending the POST request + response = requests.post(url, headers=headers, json=data) + + after_end_pair = response.text.rsplit("event: end_event", 1) + if len(after_end_pair) != 2: + raise ValueError("end_event not found in response") + after_end = after_end_pair[1] + text = after_end.split("data: ", 1)[1] + response_dict = json.loads(text) + return response_dict + + +class SambaClient(APIClient): + """Samba LLM Client.""" + + def __init__(self, api_key: str | None, url: str | None = None) -> None: + """Initialize the client. + + Args: + api_key: Samba API key. If no api_key is provided, the client will read + the SAMBA_API_KEY environment variable. + url: Samba URL. If no url is provided, the client will read the SAMBA_URL + """ + self.url = url or os.environ.get("SAMBA_URL") + self.api_key = api_key or os.environ.get("SAMBA_API_KEY") + if self.url is None: + raise ValueError( + "Samba URL not provided. Please pass in or set SAMBA_URL env variable" + ) + if self.api_key is None: + raise ValueError( + "Samba API key not provided. 
Please pass in or set SAMBA_API_KEY env variable" + ) + + def convert_request_for_samba(self, request: ChatRequest) -> dict[str, str]: + """Convert a ChatRequest to a dict for Samba.""" + request_dict = request.model_dump(exclude_none=True) + # perform renamings + key_renamings = { + "messages": "inputs", + "max_tokens": "max_tokens_to_generate", + "stop": "stop_sequences", + } + keys_to_delete = { + "n", + "seed", + "presence_penalty", + "frequency_penalty", + "response_format", + } + for old_key, new_key in key_renamings.items(): + if old_key in request_dict: + request_dict[new_key] = request_dict.pop(old_key) + for key in keys_to_delete: + if key in request_dict: + del request_dict[key] + # now reformat tools to fit Anthropic's format + if "tools" in request_dict: + raise NotImplementedError("Tools are not yet supported for Samba.") + return request_dict + + def convert_samba_to_response(self, samba_response: dict) -> ChatResponse: + """Convert an Samba response to a ChatResponse.""" + choices = [] + tool_calls = None + choices.append( + Choice( + index=0, + message=ChatMessage( + content=samba_response["completion"], + role="assistant", + tool_calls=tool_calls, + ), + ) + ) + response = ChatResponse( + # Samba doesn't have an id so assign one + id=str(uuid.uuid4()), + cached=False, + choices=choices, + created=int(samba_response["start_time"]), + model=samba_response["model"], + usage=Usage( + completion_tokens=samba_response["completion_tokens_count"], + prompt_tokens=samba_response["prompt_tokens_count"], + total_tokens=samba_response["total_tokens_count"], + ), + ) + return response + + async def arun_chat(self, request: ChatRequest) -> ChatResponse: + """Send a chat request.""" + # convert the request to Samba format + samba_request = self.convert_request_for_samba(request) + # send the request to Samba + # TODO: make async + samba_response = send_request(self.api_key, self.url, samba_request) + # convert the response to ChatResponse + response = 
self.convert_samba_to_response(samba_response) + return response diff --git a/meadow/client/api/together.py b/meadow/client/api/together.py index a578483..7fdb515 100644 --- a/meadow/client/api/together.py +++ b/meadow/client/api/together.py @@ -16,9 +16,6 @@ logger = logging.getLogger(__name__) -DEFAULT_MAX_TOKENS = 100 - - class TogetherClient(APIClient): """Together LLM Client.""" @@ -34,10 +31,7 @@ def __init__(self, api_key: str | None) -> None: ) def convert_request_for_together(self, request: ChatRequest) -> dict[str, str]: - """Convert a ChatRequest to a dict for Together. - - Together requires max_tokens to be set. If None is provided, will default. - """ + """Convert a ChatRequest to a dict for Together.""" request_dict = request.model_dump(exclude_none=True) keys_to_delete = {"seed"} for key in keys_to_delete: diff --git a/meadow/client/schema.py b/meadow/client/schema.py index 970498a..ca4f315 100644 --- a/meadow/client/schema.py +++ b/meadow/client/schema.py @@ -117,7 +117,7 @@ class LLMConfig(BaseModel): """Stop sequences.""" stop: list[str] | None = None - """Penalize resence.""" + """Penalize presence.""" presence_penalty: float | None = None """Penalize frequency.""" diff --git a/poetry.lock b/poetry.lock index 204b69f..b4745c8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2810,6 +2810,20 @@ files = [ {file = "types_pytz-2024.1.0.20240417-py3-none-any.whl", hash = "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659"}, ] +[[package]] +name = "types-requests" +version = "2.32.0.20240712" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.32.0.20240712.tar.gz", hash = "sha256:90c079ff05e549f6bf50e02e910210b98b8ff1ebdd18e19c873cd237737c1358"}, + {file = "types_requests-2.32.0.20240712-py3-none-any.whl", hash = "sha256:f754283e152c752e46e70942fa2a146b5bc70393522257bb85bd1ef7e019dcc3"}, +] + +[package.dependencies] +urllib3 = ">=2" + [[package]] name = 
"typing-extensions" version = "4.11.0" @@ -2986,4 +3000,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "43cd260562328b2192854aa62b08e8eafb88a5ea006ba9300e11a5e1d91d0309" +content-hash = "92b505616378e5aa3847f6bd529766154199f26f097d74f986119952718dbd02" diff --git a/pyproject.toml b/pyproject.toml index 1e76012..efd7468 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ pytest-cov = "^5.0.0" pytest-asyncio = "^0.23.6" types-colorama = "^0.4.15.20240311" pandas-stubs = "^2.2.1.240316" +types-requests = "^2.32.0.20240712" [build-system] requires = ["poetry-core"]