Showing 6 changed files with 238 additions and 3 deletions.
@@ -0,0 +1,42 @@
---
sidebar_position: 10
---

# dspy.CloudflareAI

### Usage

```python
lm = dspy.CloudflareAI(model="@hf/meta-llama/meta-llama-3-8b-instruct")
```

### Constructor

The constructor initializes the base class `LM` and verifies the `api_key` and `account_id` needed to use the Cloudflare AI API.
The following environment variables are expected to be set or passed as arguments (see the example after this list):

- `CLOUDFLARE_ACCOUNT_ID`: Account ID for Cloudflare.
- `CLOUDFLARE_API_KEY`: API key for Cloudflare.
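
For example, the credentials can be supplied through the environment or passed explicitly; the values below are placeholders, not real credentials:

```python
import os

# Placeholder credentials: substitute your own Cloudflare values.
os.environ["CLOUDFLARE_ACCOUNT_ID"] = "your-account-id"
os.environ["CLOUDFLARE_API_KEY"] = "your-api-key"

lm = dspy.CloudflareAI()  # reads both variables from the environment

# Equivalently, pass them as constructor arguments:
lm = dspy.CloudflareAI(account_id="your-account-id", api_key="your-api-key")
```

The full constructor signature: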

```python
class CloudflareAI(LM):
    def __init__(
        self,
        model: str = "@hf/meta-llama/meta-llama-3-8b-instruct",
        account_id: Optional[str] = None,
        api_key: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
```

**Parameters:**

- `model` (_str_): Model hosted on Cloudflare. Defaults to `@hf/meta-llama/meta-llama-3-8b-instruct`.
- `account_id` (_Optional[str]_, _optional_): Account ID for Cloudflare. Defaults to `None`; reads from the environment variable `CLOUDFLARE_ACCOUNT_ID`.
- `api_key` (_Optional[str]_, _optional_): API key for Cloudflare. Defaults to `None`; reads from the environment variable `CLOUDFLARE_API_KEY`.
- `system_prompt` (_Optional[str]_, _optional_): System prompt to use for generation.

### Methods

Refer to the [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation.
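
A minimal usage sketch, assuming valid Cloudflare credentials are already configured; the question text is only an illustration:

```python
import dspy

lm = dspy.CloudflareAI(model="@hf/meta-llama/meta-llama-3-8b-instruct")
dspy.settings.configure(lm=lm)

# Ask a question through a simple Predict module.
qa = dspy.Predict("question -> answer")
print(qa(question="What is the capital of France?").answer)

# Calling the client directly returns a list containing one completion.
completions = lm("What is the capital of France?")
print(completions[0])
```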
@@ -0,0 +1,120 @@
import logging
import os
from typing import Any, Optional

import backoff
import requests
from pydantic import BaseModel, ValidationError

from dsp.modules.lm import LM

def backoff_hdlr(details) -> None:
    """Log backoff details when retries occur."""
    logging.warning(
        f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
        f"calling function {details['target']} with args {details['args']} and kwargs {details['kwargs']}",
    )


def giveup_hdlr(details) -> bool:
    """Log the final failure once the retry policy stops retrying."""
    logging.error(
        "Giving up: after {tries} tries, calling {target} failed due to {value}".format(
            tries=details["tries"],
            target=details["target"],
            value=details.get("value", "Unknown Error"),
        ),
    )
    return False


class LLMResponse(BaseModel):
    """The generated text nested under the `result` key."""

    response: str


class CloudflareAIResponse(BaseModel):
    """Top-level envelope returned by the Cloudflare Workers AI API."""

    result: LLMResponse
    success: bool
    errors: list
    messages: list


class CloudflareAI(LM):
    """Wrapper around the Cloudflare Workers AI API."""

    def __init__(
        self,
        model: str = "@hf/meta-llama/meta-llama-3-8b-instruct",
        account_id: Optional[str] = None,
        api_key: Optional[str] = None,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(model)
        self.provider = "cloudflare"
        self.model = model
        self.account_id = os.environ.get("CLOUDFLARE_ACCOUNT_ID") if account_id is None else account_id
        self.api_key = os.environ.get("CLOUDFLARE_API_KEY") if api_key is None else api_key
        self.base_url = f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/run/{model}"
        self.headers = {"Authorization": f"Bearer {self.api_key}"}
        self.system_prompt = system_prompt
        self.kwargs = {
            "temperature": 0.0,  # Cloudflare Workers AI does not support temperature
            "max_tokens": kwargs.get("max_tokens", 256),
            **kwargs,
        }
        self.history: list[dict[str, Any]] = []
    @backoff.on_exception(
        backoff.expo,
        requests.exceptions.RequestException,
        max_time=1000,
        on_backoff=backoff_hdlr,
        on_giveup=giveup_hdlr,
    )
    def basic_request(self, prompt: str, **kwargs):  # noqa: ANN201 - Other LM implementations don't have a return type
        messages = [{"role": "user", "content": prompt}]

        if self.system_prompt:
            messages.insert(0, {"role": "system", "content": self.system_prompt})

        json_payload = {"messages": messages, "max_tokens": self.kwargs["max_tokens"], **kwargs}

        response = requests.post(self.base_url, headers=self.headers, json=json_payload)  # noqa: S113 - There is a backoff decorator which handles timeout
        response.raise_for_status()
""" | ||
Schema of the response: | ||
{ | ||
"result": | ||
{ | ||
"response": string, | ||
"success":boolean, | ||
"errors":[], | ||
"messages":[] | ||
} | ||
} | ||
""" | ||
        try:
            res = response.json()
            cf = CloudflareAIResponse.model_validate(res)
            result = cf.result

            history_entry = {"prompt": prompt, "response": result.response, "kwargs": kwargs}
            self.history.append(history_entry)

            return result.response
        except ValidationError as e:
            logging.error(f"Error validating response: {e}")
            raise
    def request(self, prompt: str, **kwargs):  # noqa: ANN201 - Other LM implementations don't have a return type
        """Make an API request to Cloudflare Workers AI with error handling."""
        return self.basic_request(prompt, **kwargs)

    def __call__(self, prompt: str, **kwargs):
        """Retrieve the AI completion from the Cloudflare Workers AI API as a single-element list."""
        response = self.request(prompt, **kwargs)

        return [response]
@@ -0,0 +1,61 @@
"""Tests for Cloudflare models. | ||
Note: Requires configuration of your Cloudflare account_id and api_key. | ||
""" | ||
|
||
import dspy | ||
|
||
models = { | ||
"@cf/qwen/qwen1.5-0.5b-chat": "https://huggingface.co/qwen/qwen1.5-0.5b-chat", | ||
"@hf/meta-llama/meta-llama-3-8b-instruct": "https://llama.meta.com", | ||
"@hf/nexusflow/starling-lm-7b-beta": "https://huggingface.co/Nexusflow/Starling-LM-7B-beta", | ||
"@cf/meta/llama-3-8b-instruct": "https://llama.meta.com", | ||
"@hf/thebloke/neural-chat-7b-v3-1-awq": "", | ||
"@cf/meta/llama-2-7b-chat-fp16": "https://ai.meta.com/llama/", | ||
"@cf/mistral/mistral-7b-instruct-v0.1": "https://mistral.ai/news/announcing-mistral-7b/", | ||
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0": "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0", | ||
"@hf/mistral/mistral-7b-instruct-v0.2": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2", | ||
"@cf/fblgit/una-cybertron-7b-v2-bf16": "", | ||
"@hf/thebloke/codellama-7b-instruct-awq": "https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-AWQ", | ||
"@cf/thebloke/discolm-german-7b-v1-awq": "https://huggingface.co/TheBloke/DiscoLM_German_7b_v1-AWQ", | ||
"@cf/meta/llama-2-7b-chat-int8": "https://ai.meta.com/llama/", | ||
"@hf/thebloke/mistral-7b-instruct-v0.1-awq": "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-AWQ", | ||
"@hf/thebloke/openchat_3.5-awq": "", | ||
"@cf/qwen/qwen1.5-7b-chat-awq": "https://huggingface.co/qwen/qwen1.5-7b-chat-awq", | ||
"@hf/thebloke/llama-2-13b-chat-awq": "https://huggingface.co/TheBloke/Llama-2-13B-chat-AWQ", | ||
"@hf/thebloke/deepseek-coder-6.7b-base-awq": "", | ||
"@hf/thebloke/openhermes-2.5-mistral-7b-awq": "", | ||
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq": "", | ||
"@cf/deepseek-ai/deepseek-math-7b-instruct": "https://huggingface.co/deepseek-ai/deepseek-math-7b-instruct", | ||
"@cf/tiiuae/falcon-7b-instruct": "https://huggingface.co/tiiuae/falcon-7b-instruct", | ||
"@hf/nousresearch/hermes-2-pro-mistral-7b": "https://huggingface.co/NousResearch/Hermes-2-Pro-Mistral-7B", | ||
"@hf/thebloke/zephyr-7b-beta-awq": "https://huggingface.co/TheBloke/zephyr-7B-beta-AWQ", | ||
"@cf/qwen/qwen1.5-1.8b-chat": "https://huggingface.co/qwen/qwen1.5-1.8b-chat", | ||
"@cf/defog/sqlcoder-7b-2": "https://huggingface.co/defog/sqlcoder-7b-2", | ||
"@cf/microsoft/phi-2": "https://huggingface.co/microsoft/phi-2", | ||
"@hf/google/gemma-7b-it": "https://ai.google.dev/gemma/docs", | ||
} | ||
|
||
|
||
def get_lm(name: str) -> dspy.LM: | ||
return dspy.CloudflareAI(model=name) | ||
|
||
|
||
def run_tests():
    """Test each Cloudflare-hosted model with a simple question."""
    # Configure your Cloudflare credentials (CLOUDFLARE_ACCOUNT_ID and
    # CLOUDFLARE_API_KEY) before running this script.
    predict_func = dspy.Predict("question -> answer")
    for model_name in models:
        lm = get_lm(model_name)
        with dspy.context(lm=lm):
            question = "What is the capital of France?"
            answer = predict_func(question=question).answer
            print(f"Question: {question}\nAnswer: {answer}")
            print("---------------------------------")
            lm.inspect_history()
            print("---------------------------------\n")


if __name__ == "__main__":
    run_tests()