Ollama HTTP provider. (#33)
* Ollama HTTP provider.

* Adding the test case for Ollama Provider.

* Addressing review comments.
rohitprasad15 authored Sep 20, 2024
1 parent 120ed16 commit 2e36536
Showing 7 changed files with 119 additions and 211 deletions.
2 changes: 2 additions & 0 deletions aisuite/provider.py
@@ -24,6 +24,7 @@ class ProviderNames(str, Enum):
    GROQ = "groq"
    GOOGLE = "google"
    MISTRAL = "mistral"
    OLLAMA = "ollama"
    OPENAI = "openai"


@@ -46,6 +47,7 @@ class ProviderFactory:
            "aisuite.providers.mistral_provider",
            "MistralProvider",
        ),
        ProviderNames.OLLAMA: ("aisuite.providers.ollama_provider", "OllamaProvider"),
        ProviderNames.OPENAI: ("aisuite.providers.openai_provider", "OpenAIProvider"),
    }
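The registry above maps each ProviderNames key to a (module_path, class_name) pair so that a provider class is only imported when it is actually used. As a minimal sketch of that pattern (create_provider and _REGISTRY are illustrative names for this example, not aisuite's actual API), resolving a "provider:model" string might look like:

import importlib

_REGISTRY = {
    "ollama": ("aisuite.providers.ollama_provider", "OllamaProvider"),
}

def create_provider(model_string, **config):
    # "ollama:llama3" -> provider key "ollama", model name "llama3"
    provider_key, _, model_name = model_string.partition(":")
    module_path, class_name = _REGISTRY[provider_key]
    module = importlib.import_module(module_path)  # imported only on demand
    provider_cls = getattr(module, class_name)
    return provider_cls(**config), model_name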

1 change: 0 additions & 1 deletion aisuite/providers/__init__.py
@@ -2,6 +2,5 @@

from .fireworks_interface import FireworksInterface
from .octo_interface import OctoInterface
from .ollama_interface import OllamaInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
54 changes: 0 additions & 54 deletions aisuite/providers/ollama_interface.py

This file was deleted.

65 changes: 65 additions & 0 deletions aisuite/providers/ollama_provider.py
@@ -0,0 +1,65 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class OllamaProvider(Provider):
    """
    Ollama provider that makes HTTP calls instead of using an SDK.
    It uses the /api/chat endpoint.
    Read more here - https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
    If OLLAMA_API_URL is not set and api_url is not passed in the config, it defaults to "http://localhost:11434".
    """

    _CHAT_COMPLETION_ENDPOINT = "/api/chat"
    _CONNECT_ERROR_MESSAGE = "Ollama is likely not running. Start Ollama by running `ollama serve` on your host."

    def __init__(self, **config):
        """
        Initialize the Ollama provider with the given configuration.
        """
        self.url = config.get("api_url") or os.getenv(
            "OLLAMA_API_URL", "http://localhost:11434"
        )

        # Optionally set a custom timeout (defaults to 30s).
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the chat completions endpoint using httpx.
        """
        kwargs["stream"] = False
        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API.
        }

        try:
            response = httpx.post(
                self.url.rstrip("/") + self._CHAT_COMPLETION_ENDPOINT,
                json=data,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except httpx.ConnectError:  # Handle connection errors.
            raise LLMError(f"Connection failed: {self._CONNECT_ERROR_MESSAGE}")
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"Ollama request failed: {http_err}")
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")

        # Return the normalized response.
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the API response to a common format (ChatCompletionResponse).
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["message"][
            "content"
        ]
        return normalized_response
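For reference, a minimal sketch of exercising the new provider directly, assuming a local Ollama server started with `ollama serve` and a model already pulled (e.g. `ollama pull llama3`):

from aisuite.providers.ollama_provider import OllamaProvider

# api_url and timeout are the two config keys the constructor reads.
provider = OllamaProvider(api_url="http://localhost:11434", timeout=60)

response = provider.chat_completions_create(
    model="llama3",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.choices[0].message.content)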
108 changes: 7 additions & 101 deletions examples/client.ipynb
@@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
Expand All @@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "f75736ee",
"metadata": {},
"outputs": [],
@@ -73,16 +73,6 @@
{
"cell_type": "code",
"execution_count": null,
"id": "744c5c15",
"metadata": {},
"outputs": [],
"source": [
"print(os.environ[\"AWS_SECRET_ACCESS_KEY\"])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4de3a24f",
"metadata": {
"ExecuteTime": {
@@ -131,23 +121,10 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "7e46c20a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<aisuite.framework.chat_completion_response.ChatCompletionResponse object at 0x10944bdd0>\n",
"Arrr, listen close me hearties! Here be a joke for ye:\n",
"\n",
"Why did Captain Jack Sparrow go to the doctor?\n",
"\n",
"Because he had a bit o' a \"crabby\" day! (get it? crabby? like a crustacean, but also feeling grumpy? Ah, never mind, matey, ye landlubbers wouldn't understand...\n"
]
}
],
"outputs": [],
"source": [
"client2 = ai.Client({\"azure\" : {\n",
" \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
Expand All @@ -157,61 +134,6 @@
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5388efc4-3fd2-4dc6-ab58-7b179ce07943",
"metadata": {},
"outputs": [],
"source": [
"octo_llama3_8b = \"octo:meta-llama-3-8b-instruct\"\n",
"#octo_llama3_70b = \"octo:meta-llama-3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=octo_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b3e6c41-070d-4041-9ed9-c8977790fe18",
"metadata": {},
"outputs": [],
"source": [
"together_llama3_8b = \"together:meta-llama/Llama-3-8b-chat-hf\"\n",
"#together_llama3_70b = \"together:meta-llama/Llama-3-70b-chat-hf\"\n",
"\n",
"response = client.chat.completions.create(model=together_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "668a6cfa-9011-480a-ae1b-6dbd6a51e716",
"metadata": {},
"outputs": [],
"source": [
"#!pip install fireworks-ai"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9900fdf3-a113-40fd-b42f-0e6d866838be",
"metadata": {},
"outputs": [],
"source": [
"fireworks_llama3_8b = \"fireworks:accounts/fireworks/models/llama-v3-8b-instruct\"\n",
"#fireworks_llama3_70b = \"fireworks:accounts/fireworks/models/llama-v3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=fireworks_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -227,32 +149,16 @@
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6baf88b8-2ecb-4bdf-9263-4af949668d16",
"metadata": {},
"outputs": [],
"source": [
"replicate_llama3_8b = \"replicate:meta/meta-llama-3-8b-instruct\"\n",
"#replicate_llama3_70b = \"replicate:meta/meta-llama-3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=replicate_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6819ac17",
"metadata": {},
"outputs": [],
"source": [
"ollama_llama3 = \"ollama:llama3\"\n",
"\n",
"response = client.chat.completions.create(model=ollama_llama3, messages=messages, temperature=0.75)\n",
"\n",
"ollama_tinyllama = \"ollama:tinyllama\"\n",
"ollama_phi3mini = \"ollama:phi3:mini\"\n",
"response = client.chat.completions.create(model=ollama_phi3mini, messages=messages, temperature=0.75)\n",
"print(response.choices[0].message.content)"
]
},
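Under the hood, the notebook cell above reduces to a single POST to Ollama's /api/chat endpoint. A rough httpx equivalent of the request the provider sends (the messages content is an assumed example; note that the provider forwards extra keyword arguments such as temperature verbatim into the payload and always forces stream to False):

import httpx

payload = {
    "model": "phi3:mini",
    "messages": [{"role": "user", "content": "Tell me a joke."}],  # example, assumed
    "temperature": 0.75,  # forwarded as-is from **kwargs
    "stream": False,      # always forced by the provider
}
resp = httpx.post("http://localhost:11434/api/chat", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json()["message"]["content"])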
55 changes: 0 additions & 55 deletions tests/providers/test_ollama_interface.py

This file was deleted.
