Async Ollama client
ggozad committed Oct 13, 2023
1 parent 39c5c1f commit 24b80e9
Showing 4 changed files with 134 additions and 36 deletions.
46 changes: 20 additions & 26 deletions oterm/ollama.py
@@ -1,7 +1,5 @@
 import json
-from typing import List, Tuple
-
-import requests
+import httpx
 
 from oterm.config import Config
 
@@ -15,20 +13,20 @@ def __init__(self, model="nous-hermes:13b", template="", system=""):
         self.model = model
         self.template = template
         self.system = system
-        self.context: List[int] = []
+        self.context: list[int] = []
 
-    def completion(
-        self,
-        prompt: str,
-    ) -> str:
-        response, context = self._generate(
+    async def completion(self, prompt: str) -> str:
+        response, context = await self._agenerate(
             prompt=prompt,
             context=self.context,
         )
         self.context = context
         return response
 
-    def _generate(self, prompt: str, context: List[int]) -> Tuple[str, List[int]]:
+    async def _agenerate(
+        self, prompt: str, context: list[int]
+    ) -> tuple[str, list[int]]:
+        client = httpx.AsyncClient()
         jsn = {
             "model": self.model,
             "prompt": prompt,
@@ -39,21 +37,17 @@ def _generate(self, prompt: str, context: List[int]) -> Tuple[str, List[int]]:
         if self.template:
             jsn["template"] = self.template
 
-        r = requests.post(
-            f"{Config.OLLAMA_URL}/generate",
-            json=jsn,
-            stream=True,
-        )
-        r.raise_for_status()
-
-        response = ""
-        for line in r.iter_lines():
-            body = json.loads(line)
-            response += body.get("response", "")
+        res = ""
+        async with client.stream(
+            "POST", f"{Config.OLLAMA_URL}/generate", json=jsn
+        ) as response:
+            async for line in response.aiter_lines():
+                body = json.loads(line)
+                res += body.get("response", "")
 
-            if "error" in body:
-                raise OllamaError(body["error"])
+                if "error" in body:
+                    raise OllamaError(body["error"])
 
-            if body.get("done", False):
-                return response, body["context"]
-        return response, []
+                if body.get("done", False):
+                    return res, body["context"]
+        return res, []
100 changes: 99 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ textual = { version = "^0.40.0", extras = ["dev", "syntax"] }
 typer = "^0.9.0"
 python-dotenv = "^1.0.0"
 pydantic = "^2.4.2"
+httpx = "^0.25.0"
 
 [tool.poetry.group.dev.dependencies]
 pdbpp = "^0.10.3"
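One dependency note, not visible in the rendered hunks: the updated tests below use the @pytest.mark.asyncio decorator, which requires the pytest-asyncio plugin. If it is not already part of the project, it would be declared as a dev dependency roughly as follows (the version pin is illustrative, not taken from this commit):

[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^0.21.0"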
23 changes: 14 additions & 9 deletions tests/test_api_client.py
@@ -1,24 +1,29 @@
+import pytest
+
 from oterm.ollama import OlammaLLM, OllamaError
 
 
-def test_generate():
+@pytest.mark.asyncio
+async def test_generate():
     llm = OlammaLLM()
-    res = llm.completion("Please add 2 and 2.")
+    res = await llm.completion("Please add 2 and 2.")
     assert "4" in res
 
 
-def test_llm_context():
+@pytest.mark.asyncio
+async def test_llm_context():
     llm = OlammaLLM()
-    llm.completion("I am testing oterm, a python client for Ollama.")
+    await llm.completion("I am testing oterm, a python client for Ollama.")
     # There should now be a context saved for the conversation.
     assert llm.context
-    res = llm.completion("Do you remember what I am testing?")
+    res = await llm.completion("Do you remember what I am testing?")
     assert "oterm" in res
 
 
-def test_errors():
+@pytest.mark.asyncio
+async def test_errors():
     llm = OlammaLLM(model="non-existent-model")
     try:
-        llm.completion("This should fail.")
-    except Exception as e:
-        assert "Error" in str(e)
+        await llm.completion("This should fail.")
+    except OllamaError as e:
+        assert "no such file or directory" in str(e)

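A side note on test_errors: the try/except form passes silently when no exception is raised at all. pytest.raises makes the expectation explicit and fails the test if the error never fires. A possible tightening, sketched here rather than taken from the commit:

import pytest

from oterm.ollama import OlammaLLM, OllamaError


@pytest.mark.asyncio
async def test_errors() -> None:
    llm = OlammaLLM(model="non-existent-model")
    with pytest.raises(OllamaError) as excinfo:
        await llm.completion("This should fail.")
    assert "no such file or directory" in str(excinfo.value)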