diff --git a/examples/demo_user_feedback.py b/examples/demo_user_feedback.py
index 1186553..0fe7f8f 100644
--- a/examples/demo_user_feedback.py
+++ b/examples/demo_user_feedback.py
@@ -25,6 +25,8 @@
from meadow.client.api.anthropic import AnthropicClient
from meadow.client.api.api_client import APIClient
from meadow.client.api.openai import OpenAIClient
+from meadow.client.api.samba import SambaClient
+from meadow.client.api.together import TogetherClient
from meadow.client.schema import LLMConfig
from meadow.database.connector.connector import Connector
from meadow.database.connector.duckdb import DuckDBConnector
@@ -230,6 +232,10 @@ def run_meadow(
api_client: APIClient = AnthropicClient(api_key=api_key)
elif api_provider == "openai":
api_client = OpenAIClient(api_key=api_key)
+ elif api_provider == "together":
+ api_client = TogetherClient(api_key=api_key)
+ elif api_provider == "samba":
+ api_client = SambaClient(api_key=api_key)
else:
raise ValueError(f"Unknown API provider: {api_provider}")
client = Client(
@@ -237,7 +243,7 @@ def run_meadow(
api_client=api_client,
model=model,
)
- big_client = Client(cache=cache, api_client=api_client, model="gpt-4o")
+ big_client = Client(cache=cache, api_client=api_client, model=model)
llm_config = LLMConfig(
max_tokens=1000,
temperature=0.0,
@@ -288,7 +294,7 @@ def run_meadow(
argparser.add_argument(
"--model",
type=str,
- help="Anthropic model.",
+ help="Model.",
default="gpt-4o",
)
argparser.add_argument(
diff --git a/meadow/agent/planner.py b/meadow/agent/planner.py
index 598da22..19d99b2 100644
--- a/meadow/agent/planner.py
+++ b/meadow/agent/planner.py
@@ -133,6 +133,8 @@ def parse_plan(
message += ">"
if "" not in message:
error_message = "Plan not found in the response. Please provide a plan in the format ...."
+ if "" not in message:
+ message += ""
if not error_message:
inner_steps = re.search(r"(.*)", message, re.DOTALL).group(1)
parsed_steps = parse_steps(inner_steps)
diff --git a/meadow/client/api/samba.py b/meadow/client/api/samba.py
new file mode 100644
index 0000000..c109281
--- /dev/null
+++ b/meadow/client/api/samba.py
@@ -0,0 +1,128 @@
+import json
+import logging
+import os
+import uuid
+from typing import Any
+
+import requests
+
+from meadow.client.api.api_client import APIClient
+from meadow.client.schema import (
+ ChatMessage,
+ ChatRequest,
+ ChatResponse,
+ Choice,
+ Usage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def send_request(api_key: str, url: str, data: dict[str, Any]) -> dict:
+ """
+ Sends a POST request to the specified URL with headers and JSON data.
+ """
+
+ # Headers including the Authorization and Content-Type
+ headers = {"Authorization": f"Basic {api_key}", "Content-Type": "application/json"}
+
+ # Sending the POST request
+ response = requests.post(url, headers=headers, json=data)
+
+ after_end_pair = response.text.rsplit("event: end_event", 1)
+ if len(after_end_pair) != 2:
+ raise ValueError("end_event not found in response")
+ after_end = after_end_pair[1]
+ text = after_end.split("data: ", 1)[1]
+ response_dict = json.loads(text)
+ return response_dict
+
+
+class SambaClient(APIClient):
+ """Samba LLM Client."""
+
+ def __init__(self, api_key: str | None, url: str | None = None) -> None:
+ """Initialize the client.
+
+ Args:
+ api_key: Samba API key. If no api_key is provided, the client will read
+ the SAMBA_API_KEY environment variable.
+            url: Samba URL. If no url is provided, the client will read the SAMBA_URL env variable.
+ """
+ self.url = url or os.environ.get("SAMBA_URL")
+ self.api_key = api_key or os.environ.get("SAMBA_API_KEY")
+ if self.url is None:
+ raise ValueError(
+ "Samba URL not provided. Please pass in or set SAMBA_URL env variable"
+ )
+ if self.api_key is None:
+ raise ValueError(
+ "Samba API key not provided. Please pass in or set SAMBA_API_KEY env variable"
+ )
+
+ def convert_request_for_samba(self, request: ChatRequest) -> dict[str, str]:
+ """Convert a ChatRequest to a dict for Samba."""
+ request_dict = request.model_dump(exclude_none=True)
+ # perform renamings
+ key_renamings = {
+ "messages": "inputs",
+ "max_tokens": "max_tokens_to_generate",
+ "stop": "stop_sequences",
+ }
+ keys_to_delete = {
+ "n",
+ "seed",
+ "presence_penalty",
+ "frequency_penalty",
+ "response_format",
+ }
+ for old_key, new_key in key_renamings.items():
+ if old_key in request_dict:
+ request_dict[new_key] = request_dict.pop(old_key)
+ for key in keys_to_delete:
+ if key in request_dict:
+ del request_dict[key]
+        # tools are not supported by the Samba API yet; fail fast rather than send a malformed request
+ if "tools" in request_dict:
+ raise NotImplementedError("Tools are not yet supported for Samba.")
+ return request_dict
+
+ def convert_samba_to_response(self, samba_response: dict) -> ChatResponse:
+ """Convert an Samba response to a ChatResponse."""
+ choices = []
+ tool_calls = None
+ choices.append(
+ Choice(
+ index=0,
+ message=ChatMessage(
+ content=samba_response["completion"],
+ role="assistant",
+ tool_calls=tool_calls,
+ ),
+ )
+ )
+ response = ChatResponse(
+ # Samba doesn't have an id so assign one
+ id=str(uuid.uuid4()),
+ cached=False,
+ choices=choices,
+ created=int(samba_response["start_time"]),
+ model=samba_response["model"],
+ usage=Usage(
+ completion_tokens=samba_response["completion_tokens_count"],
+ prompt_tokens=samba_response["prompt_tokens_count"],
+ total_tokens=samba_response["total_tokens_count"],
+ ),
+ )
+ return response
+
+ async def arun_chat(self, request: ChatRequest) -> ChatResponse:
+ """Send a chat request."""
+ # convert the request to Samba format
+ samba_request = self.convert_request_for_samba(request)
+ # send the request to Samba
+ # TODO: make async
+ samba_response = send_request(self.api_key, self.url, samba_request)
+ # convert the response to ChatResponse
+ response = self.convert_samba_to_response(samba_response)
+ return response
diff --git a/meadow/client/api/together.py b/meadow/client/api/together.py
index a578483..7fdb515 100644
--- a/meadow/client/api/together.py
+++ b/meadow/client/api/together.py
@@ -16,9 +16,6 @@
logger = logging.getLogger(__name__)
-DEFAULT_MAX_TOKENS = 100
-
-
class TogetherClient(APIClient):
"""Together LLM Client."""
@@ -34,10 +31,7 @@ def __init__(self, api_key: str | None) -> None:
)
def convert_request_for_together(self, request: ChatRequest) -> dict[str, str]:
- """Convert a ChatRequest to a dict for Together.
-
- Together requires max_tokens to be set. If None is provided, will default.
- """
+ """Convert a ChatRequest to a dict for Together."""
request_dict = request.model_dump(exclude_none=True)
keys_to_delete = {"seed"}
for key in keys_to_delete:
diff --git a/meadow/client/schema.py b/meadow/client/schema.py
index 970498a..ca4f315 100644
--- a/meadow/client/schema.py
+++ b/meadow/client/schema.py
@@ -117,7 +117,7 @@ class LLMConfig(BaseModel):
"""Stop sequences."""
stop: list[str] | None = None
- """Penalize resence."""
+ """Penalize presence."""
presence_penalty: float | None = None
"""Penalize frequency."""
diff --git a/poetry.lock b/poetry.lock
index 204b69f..b4745c8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2810,6 +2810,20 @@ files = [
{file = "types_pytz-2024.1.0.20240417-py3-none-any.whl", hash = "sha256:8335d443310e2db7b74e007414e74c4f53b67452c0cb0d228ca359ccfba59659"},
]
+[[package]]
+name = "types-requests"
+version = "2.32.0.20240712"
+description = "Typing stubs for requests"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "types-requests-2.32.0.20240712.tar.gz", hash = "sha256:90c079ff05e549f6bf50e02e910210b98b8ff1ebdd18e19c873cd237737c1358"},
+ {file = "types_requests-2.32.0.20240712-py3-none-any.whl", hash = "sha256:f754283e152c752e46e70942fa2a146b5bc70393522257bb85bd1ef7e019dcc3"},
+]
+
+[package.dependencies]
+urllib3 = ">=2"
+
[[package]]
name = "typing-extensions"
version = "4.11.0"
@@ -2986,4 +3000,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "43cd260562328b2192854aa62b08e8eafb88a5ea006ba9300e11a5e1d91d0309"
+content-hash = "92b505616378e5aa3847f6bd529766154199f26f097d74f986119952718dbd02"
diff --git a/pyproject.toml b/pyproject.toml
index 1e76012..efd7468 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,7 @@ pytest-cov = "^5.0.0"
pytest-asyncio = "^0.23.6"
types-colorama = "^0.4.15.20240311"
pandas-stubs = "^2.2.1.240316"
+types-requests = "^2.32.0.20240712"
[build-system]
requires = ["poetry-core"]