feat: adds Cloudflare Workers AI
isbecker committed May 6, 2024
1 parent d7dace6 commit 3dcddbe
Showing 6 changed files with 238 additions and 3 deletions.
42 changes: 42 additions & 0 deletions docs/api/language_model_clients/Cloudflare.md
@@ -0,0 +1,42 @@
---
sidebar_position: 10
---

# dspy.CloudflareAI

### Usage

```python
lm = dspy.CloudflareAI(model="@hf/meta-llama/meta-llama-3-8b-instruct")
```

### Constructor

The constructor initializes the base class `LM` and verifies the `api_key` and `account_id` needed to use the Cloudflare Workers AI API.
The following environment variables are expected to be set or passed as arguments:

- `CLOUDFLARE_ACCOUNT_ID`: Account ID for Cloudflare.
- `CLOUDFLARE_API_KEY`: API key for Cloudflare.

```python
class CloudflareAI(LM):
def __init__(
self,
model: str = "@hf/meta-llama/meta-llama-3-8b-instruct",
account_id: Optional[str] = None,
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
**kwargs,
):
```

**Parameters:**

- `model` (_str_): Model hosted on Cloudflare. Defaults to `@hf/meta-llama/meta-llama-3-8b-instruct`.
- `account_id` (_Optional[str]_, _optional_): Account ID for Cloudflare. Defaults to None. Reads from environment variable `CLOUDFLARE_ACCOUNT_ID`.
- `api_key` (_Optional[str]_, _optional_): API key for Cloudflare. Defaults to None. Reads from environment variable `CLOUDFLARE_API_KEY`.
- `system_prompt` (_Optional[str]_, _optional_): System prompt to use for generation.
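
For example, a minimal sketch that passes the credentials explicitly instead of reading them from the environment (the credential values and system prompt below are hypothetical placeholders):

```python
import dspy

# Hypothetical placeholder credentials; replace with your own values,
# or omit them to fall back to CLOUDFLARE_ACCOUNT_ID / CLOUDFLARE_API_KEY.
lm = dspy.CloudflareAI(
    model="@hf/meta-llama/meta-llama-3-8b-instruct",
    account_id="your-account-id",
    api_key="your-api-key",
    system_prompt="You are a concise assistant.",
    max_tokens=512,
)
```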

### Methods

Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation.
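
Like the other clients, the instance is callable. Per the `__call__` implementation in `dsp/modules/cloudflare.py` below, it returns a list containing the single completion string:

```python
completions = lm("What is the capital of France?")
print(completions[0])
```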
1 change: 1 addition & 0 deletions dsp/modules/__init__.py
@@ -7,6 +7,7 @@
from .azure_openai import AzureOpenAI
from .cache_utils import *
from .clarifai import *
from .cloudflare import *
from .cohere import *
from .colbertv2 import ColBERTv2
from .databricks import *
120 changes: 120 additions & 0 deletions dsp/modules/cloudflare.py
@@ -0,0 +1,120 @@
import logging
import os
from typing import Any, Optional

import backoff
import requests
from pydantic import BaseModel, ValidationError

from dsp.modules.lm import LM


def backoff_hdlr(details) -> None:
"""Log backoff details when retries occur."""
logging.warning(
f"Backing off {details['wait']:0.1f} seconds afters {details['tries']} tries "
f"calling function {details['target']} with args {details['args']} and kwargs {details['kwargs']}",
)


def giveup_hdlr(details) -> bool:
"""Decide whether to give up on retries based on the exception."""
logging.error(
"Giving up: After {tries} tries, calling {target} failed due to {value}".format(
tries=details["tries"],
target=details["target"],
value=details.get("value", "Unknown Error"),
),
)
    return False  # Always return False so the retry logic does not give up


class LLMResponse(BaseModel):
response: str


class CloudflareAIResponse(BaseModel):
result: LLMResponse
success: bool
errors: list
messages: list


class CloudflareAI(LM):
"""Wrapper around Cloudflare Workers AI API."""

def __init__(
self,
model: str = "@hf/meta-llama/meta-llama-3-8b-instruct",
account_id: Optional[str] = None,
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
**kwargs,
):
super().__init__(model)
self.provider = "cloudflare"
self.model = model
self.account_id = os.environ.get("CLOUDFLARE_ACCOUNT_ID") if account_id is None else account_id
self.api_key = os.environ.get("CLOUDFLARE_API_KEY") if api_key is None else api_key
self.base_url = f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/run/{model}"
self.headers = {"Authorization": f"Bearer {self.api_key}"}
self.system_prompt = system_prompt
self.kwargs = {
"temperature": 0.0, # Cloudflare Workers AI does not support temperature
"max_tokens": kwargs.get("max_tokens", 256),
**kwargs,
}
self.history: list[dict[str, Any]] = []

@backoff.on_exception(
backoff.expo,
requests.exceptions.RequestException,
max_time=1000,
on_backoff=backoff_hdlr,
on_giveup=giveup_hdlr,
)
def basic_request(self, prompt: str, **kwargs): # noqa: ANN201 - Other LM implementations don't have a return type
messages = [{"role": "user", "content": prompt}]

if self.system_prompt:
messages.insert(0, {"role": "system", "content": self.system_prompt})

json_payload = {"messages": messages, "max_tokens": self.kwargs["max_tokens"], **kwargs}

response = requests.post(self.base_url, headers=self.headers, json=json_payload) # noqa: S113 - There is a backoff decorator which handles timeout
response.raise_for_status()

"""
Schema of the response:
{
"result":
{
"response": string,
"success":boolean,
"errors":[],
"messages":[]
}
}
"""
try:
res = response.json()
cf = CloudflareAIResponse.model_validate(res)
result = cf.result

history_entry = {"prompt": prompt, "response": result.response, "kwargs": kwargs}
self.history.append(history_entry)

return result.response
except ValidationError as e:
logging.error(f"Error validating response: {e}")
raise e

    def request(self, prompt: str, **kwargs):  # noqa: ANN201 - Other LM implementations don't have a return type
"""Makes an API request to Cloudflare Workers AI with error handling."""
return self.basic_request(prompt, **kwargs)

def __call__(self, prompt: str, **kwargs):
"""Retrieve the AI completion from Cloudflare Workers AI API."""
response = self.request(prompt, **kwargs)

return [response]
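
For reference, a short sketch of how the pydantic models above validate a successful response payload (the sample values are hypothetical):

```python
from dsp.modules.cloudflare import CloudflareAIResponse

# Hypothetical sample payload matching the schema documented above.
sample = {
    "result": {"response": "Paris is the capital of France."},
    "success": True,
    "errors": [],
    "messages": [],
}

parsed = CloudflareAIResponse.model_validate(sample)
print(parsed.result.response)  # Paris is the capital of France.
```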
16 changes: 13 additions & 3 deletions dsp/modules/lm.py
@@ -46,15 +46,23 @@ def inspect_history(self, n: int = 1, skip: int = 0):
prompt = x["prompt"]

if prompt != last_prompt:
if provider == "clarifai" or provider == "google" or provider == "groq" or provider == "Bedrock" or provider == "Sagemaker":
if (
provider == "clarifai"
or provider == "google"
or provider == "groq"
or provider == "Bedrock"
or provider == "Sagemaker"
):
printed.append((prompt, x["response"]))
elif provider == "anthropic":
blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"]
printed.append((prompt, blocks))
elif provider == "cohere":
printed.append((prompt, x["response"].text))
elif provider == "mistral":
printed.append((prompt, x['response'].choices))
printed.append((prompt, x["response"].choices))
elif provider == "cloudflare":
printed.append((prompt, [x["response"]]))
else:
printed.append((prompt, x["response"]["choices"]))

@@ -79,11 +87,13 @@ def inspect_history(self, n: int = 1, skip: int = 0):
elif provider == "clarifai" or provider == "claude":
text = choices
elif provider == "groq":
text = ' ' + choices
text = " " + choices
elif provider == "google":
text = choices[0].parts[0].text
elif provider == "mistral":
text = choices[0].message.content
elif provider == "cloudflare":
text = choices[0]
else:
text = choices[0]["text"]
printing_value += self.print_green(text, end="")
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -22,6 +22,7 @@
ColBERTv2 = dsp.ColBERTv2
Pyserini = dsp.PyseriniRetriever
Clarifai = dsp.ClarifaiLLM
CloudflareAI = dsp.CloudflareAI
Google = dsp.Google
GoogleVertexAI = dsp.GoogleVertexAI
GROQ = dsp.GroqLM
61 changes: 61 additions & 0 deletions tests/modules/test_cloudflare_models.py
@@ -0,0 +1,61 @@
"""Tests for Cloudflare models.
Note: Requires configuration of your Cloudflare account_id and api_key.
"""

import dspy

models = {
"@cf/qwen/qwen1.5-0.5b-chat": "https://huggingface.co/qwen/qwen1.5-0.5b-chat",
"@hf/meta-llama/meta-llama-3-8b-instruct": "https://llama.meta.com",
"@hf/nexusflow/starling-lm-7b-beta": "https://huggingface.co/Nexusflow/Starling-LM-7B-beta",
"@cf/meta/llama-3-8b-instruct": "https://llama.meta.com",
"@hf/thebloke/neural-chat-7b-v3-1-awq": "",
"@cf/meta/llama-2-7b-chat-fp16": "https://ai.meta.com/llama/",
"@cf/mistral/mistral-7b-instruct-v0.1": "https://mistral.ai/news/announcing-mistral-7b/",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0": "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"@hf/mistral/mistral-7b-instruct-v0.2": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
"@cf/fblgit/una-cybertron-7b-v2-bf16": "",
"@hf/thebloke/codellama-7b-instruct-awq": "https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-AWQ",
"@cf/thebloke/discolm-german-7b-v1-awq": "https://huggingface.co/TheBloke/DiscoLM_German_7b_v1-AWQ",
"@cf/meta/llama-2-7b-chat-int8": "https://ai.meta.com/llama/",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq": "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-AWQ",
"@hf/thebloke/openchat_3.5-awq": "",
"@cf/qwen/qwen1.5-7b-chat-awq": "https://huggingface.co/qwen/qwen1.5-7b-chat-awq",
"@hf/thebloke/llama-2-13b-chat-awq": "https://huggingface.co/TheBloke/Llama-2-13B-chat-AWQ",
"@hf/thebloke/deepseek-coder-6.7b-base-awq": "",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq": "",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq": "",
"@cf/deepseek-ai/deepseek-math-7b-instruct": "https://huggingface.co/deepseek-ai/deepseek-math-7b-instruct",
"@cf/tiiuae/falcon-7b-instruct": "https://huggingface.co/tiiuae/falcon-7b-instruct",
"@hf/nousresearch/hermes-2-pro-mistral-7b": "https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B",
"@hf/thebloke/zephyr-7b-beta-awq": "https://huggingface.co/TheBloke/zephyr-7B-beta-AWQ",
"@cf/qwen/qwen1.5-1.8b-chat": "https://huggingface.co/qwen/qwen1.5-1.8b-chat",
"@cf/defog/sqlcoder-7b-2": "https://huggingface.co/defog/sqlcoder-7b-2",
"@cf/microsoft/phi-2": "https://huggingface.co/microsoft/phi-2",
"@hf/google/gemma-7b-it": "https://ai.google.dev/gemma/docs",
}


def get_lm(name: str) -> dspy.LM:
return dspy.CloudflareAI(model=name)


def run_tests():
    """Test each of the Cloudflare Workers AI models."""
    # Configure your Cloudflare account_id and api_key (or the
    # CLOUDFLARE_ACCOUNT_ID and CLOUDFLARE_API_KEY environment
    # variables) before running this script.

predict_func = dspy.Predict("question -> answer")
    for model_name in models:
lm = get_lm(model_name)
with dspy.context(lm=lm):
question = "What is the capital of France?"
answer = predict_func(question=question).answer
print(f"Question: {question}\nAnswer: {answer}")
print("---------------------------------")
lm.inspect_history()
print("---------------------------------\n")


if __name__ == "__main__":
run_tests()
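
A minimal sketch of the configuration step the docstring above calls for, assuming the credentials are supplied via environment variables (the values are placeholders):

```python
import os

# Hypothetical placeholder values; substitute your real Cloudflare
# credentials before invoking run_tests().
os.environ["CLOUDFLARE_ACCOUNT_ID"] = "your-account-id"
os.environ["CLOUDFLARE_API_KEY"] = "your-api-key"
```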
