Skip to content

Commit

Permalink
Samba API
Browse files Browse the repository at this point in the history
  • Loading branch information
lorr1 committed Jul 12, 2024
1 parent f6be905 commit 39b8a23
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 11 deletions.
10 changes: 8 additions & 2 deletions examples/demo_user_feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from meadow.client.api.anthropic import AnthropicClient
from meadow.client.api.api_client import APIClient
from meadow.client.api.openai import OpenAIClient
from meadow.client.api.samba import SambaClient
from meadow.client.api.together import TogetherClient
from meadow.client.schema import LLMConfig
from meadow.database.connector.connector import Connector
from meadow.database.connector.duckdb import DuckDBConnector
Expand Down Expand Up @@ -230,14 +232,18 @@ def run_meadow(
api_client: APIClient = AnthropicClient(api_key=api_key)
elif api_provider == "openai":
api_client = OpenAIClient(api_key=api_key)
elif api_provider == "together":
api_client = TogetherClient(api_key=api_key)
elif api_provider == "samba":
api_client = SambaClient(api_key=api_key)
else:
raise ValueError(f"Unknown API provider: {api_provider}")
client = Client(
cache=cache,
api_client=api_client,
model=model,
)
big_client = Client(cache=cache, api_client=api_client, model="gpt-4o")
big_client = Client(cache=cache, api_client=api_client, model=model)
llm_config = LLMConfig(
max_tokens=1000,
temperature=0.0,
Expand Down Expand Up @@ -288,7 +294,7 @@ def run_meadow(
argparser.add_argument(
"--model",
type=str,
help="Anthropic model.",
help="Model.",
default="gpt-4o",
)
argparser.add_argument(
Expand Down
2 changes: 2 additions & 0 deletions meadow/agent/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def parse_plan(
message += ">"
if "<steps>" not in message:
error_message = "Plan not found in the response. Please provide a plan in the format <steps>...</steps>."
if "</steps>" not in message:
message += "</steps>"
if not error_message:
inner_steps = re.search(r"(<steps>.*</steps>)", message, re.DOTALL).group(1)
parsed_steps = parse_steps(inner_steps)
Expand Down
128 changes: 128 additions & 0 deletions meadow/client/api/samba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import json
import logging
import os
import uuid
from typing import Any

import requests

from meadow.client.api.api_client import APIClient
from meadow.client.schema import (
ChatMessage,
ChatRequest,
ChatResponse,
Choice,
Usage,
)

logger = logging.getLogger(__name__)


def send_request(api_key: str, url: str, data: dict[str, Any]) -> dict:
    """
    Sends a POST request to the specified URL with headers and JSON data.

    Args:
        api_key: API key, sent as a Basic authorization header.
        url: Endpoint to POST to.
        data: JSON-serializable request payload.

    Returns:
        The parsed JSON payload that follows the final ``end_event`` marker
        in the streamed (server-sent-events style) response body.

    Raises:
        ValueError: If the response body does not contain an ``end_event``
            marker or a ``data:`` payload after it.
        requests.HTTPError: If the server returns a non-2xx status.
    """

    # Headers including the Authorization and Content-Type
    headers = {"Authorization": f"Basic {api_key}", "Content-Type": "application/json"}

    # Sending the POST request. A timeout prevents hanging forever on a
    # stalled connection; generation can be slow, so it is generous.
    response = requests.post(url, headers=headers, json=data, timeout=300)
    # Fail loudly on HTTP errors instead of surfacing a confusing parse
    # error on an error-page body below.
    response.raise_for_status()

    # The endpoint streams server-sent events; the final payload appears
    # after the last "event: end_event" marker as a "data: ..." line.
    after_end_pair = response.text.rsplit("event: end_event", 1)
    if len(after_end_pair) != 2:
        raise ValueError("end_event not found in response")
    after_end = after_end_pair[1]
    data_pair = after_end.split("data: ", 1)
    if len(data_pair) != 2:
        raise ValueError("data payload not found after end_event in response")
    response_dict = json.loads(data_pair[1])
    return response_dict


class SambaClient(APIClient):
    """Samba LLM Client."""

    def __init__(self, api_key: str | None, url: str | None = None) -> None:
        """Initialize the client.

        Args:
            api_key: Samba API key. If no api_key is provided, the client will read
                the SAMBA_API_KEY environment variable.
            url: Samba URL. If no url is provided, the client will read the SAMBA_URL
                environment variable.

        Raises:
            ValueError: If the URL or API key is supplied neither as an argument
                nor via the corresponding environment variable.
        """
        self.url = url or os.environ.get("SAMBA_URL")
        self.api_key = api_key or os.environ.get("SAMBA_API_KEY")
        if self.url is None:
            raise ValueError(
                "Samba URL not provided. Please pass in or set SAMBA_URL env variable"
            )
        if self.api_key is None:
            raise ValueError(
                "Samba API key not provided. Please pass in or set SAMBA_API_KEY env variable"
            )

    def convert_request_for_samba(self, request: ChatRequest) -> dict[str, Any]:
        """Convert a ChatRequest to a dict for Samba.

        Renames OpenAI-style parameters to their Samba equivalents and drops
        parameters the Samba API does not accept.

        Raises:
            NotImplementedError: If the request contains tools; Samba does not
                support tool calling yet.
        """
        request_dict = request.model_dump(exclude_none=True)
        # Rename OpenAI-style parameters to the names the Samba API expects.
        key_renamings = {
            "messages": "inputs",
            "max_tokens": "max_tokens_to_generate",
            "stop": "stop_sequences",
        }
        # Parameters with no Samba equivalent are dropped entirely.
        keys_to_delete = {
            "n",
            "seed",
            "presence_penalty",
            "frequency_penalty",
            "response_format",
        }
        for old_key, new_key in key_renamings.items():
            if old_key in request_dict:
                request_dict[new_key] = request_dict.pop(old_key)
        for key in keys_to_delete:
            if key in request_dict:
                del request_dict[key]
        # Tool calling is not supported by the Samba API yet.
        if "tools" in request_dict:
            raise NotImplementedError("Tools are not yet supported for Samba.")
        return request_dict

    def convert_samba_to_response(self, samba_response: dict) -> ChatResponse:
        """Convert a Samba response to a ChatResponse."""
        choices = []
        # Samba returns a single completion and no tool calls.
        tool_calls = None
        choices.append(
            Choice(
                index=0,
                message=ChatMessage(
                    content=samba_response["completion"],
                    role="assistant",
                    tool_calls=tool_calls,
                ),
            )
        )
        response = ChatResponse(
            # Samba doesn't have an id so assign one
            id=str(uuid.uuid4()),
            cached=False,
            choices=choices,
            created=int(samba_response["start_time"]),
            model=samba_response["model"],
            usage=Usage(
                completion_tokens=samba_response["completion_tokens_count"],
                prompt_tokens=samba_response["prompt_tokens_count"],
                total_tokens=samba_response["total_tokens_count"],
            ),
        )
        return response

    async def arun_chat(self, request: ChatRequest) -> ChatResponse:
        """Send a chat request.

        Args:
            request: The provider-agnostic chat request to send.

        Returns:
            The Samba completion converted to a ChatResponse.
        """
        # convert the request to Samba format
        samba_request = self.convert_request_for_samba(request)
        # send the request to Samba
        # TODO: make async -- send_request blocks the event loop for now
        samba_response = send_request(self.api_key, self.url, samba_request)
        # convert the response to ChatResponse
        response = self.convert_samba_to_response(samba_response)
        return response
8 changes: 1 addition & 7 deletions meadow/client/api/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
logger = logging.getLogger(__name__)


DEFAULT_MAX_TOKENS = 100


class TogetherClient(APIClient):
"""Together LLM Client."""

Expand All @@ -34,10 +31,7 @@ def __init__(self, api_key: str | None) -> None:
)

def convert_request_for_together(self, request: ChatRequest) -> dict[str, str]:
"""Convert a ChatRequest to a dict for Together.
Together requires max_tokens to be set. If None is provided, will default.
"""
"""Convert a ChatRequest to a dict for Together."""
request_dict = request.model_dump(exclude_none=True)
keys_to_delete = {"seed"}
for key in keys_to_delete:
Expand Down
2 changes: 1 addition & 1 deletion meadow/client/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class LLMConfig(BaseModel):
"""Stop sequences."""
stop: list[str] | None = None

"""Penalize resence."""
"""Penalize presence."""
presence_penalty: float | None = None

"""Penalize frequency."""
Expand Down
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pytest-cov = "^5.0.0"
pytest-asyncio = "^0.23.6"
types-colorama = "^0.4.15.20240311"
pandas-stubs = "^2.2.1.240316"
types-requests = "^2.32.0.20240712"

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit 39b8a23

Please sign in to comment.