diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py
index 79230dde..d3c01fc3 100644
--- a/ragna/assistants/_ai21labs.py
+++ b/ragna/assistants/_ai21labs.py
@@ -1,6 +1,6 @@
-from typing import AsyncIterator, cast
+from typing import Any, AsyncIterator, Union, cast
 
-from ragna.core import Message, Source
+from ragna.core import Message, MessageRole, Source
 
 from ._http_api import HttpApiAssistant
 
@@ -14,7 +14,7 @@ class Ai21LabsAssistant(HttpApiAssistant):
     def display_name(cls) -> str:
         return f"AI21Labs/jurassic-2-{cls._MODEL_TYPE}"
 
-    def _make_system_content(self, sources: list[Source]) -> str:
+    def _make_rag_system_content(self, sources: list[Source]) -> str:
         instruction = (
             "You are a helpful assistant that answers user questions given the context below. "
             "If you don't know the answer, just say so. Don't try to make up an answer. "
@@ -22,13 +22,46 @@ def _make_system_content(self, sources: list[Source]) -> str:
         )
         return instruction + "\n\n".join(source.content for source in sources)
 
-    async def answer(
-        self, messages: list[Message], *, max_new_tokens: int = 256
-    ) -> AsyncIterator[str]:
+    def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]:
+        """
+        Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.
+
+        Returns:
+            ordered list of dicts with 'text' and 'role' keys
+        """
+        if isinstance(prompt, str):
+            messages = [Message(content=prompt, role=MessageRole.USER)]
+        else:
+            messages = prompt
+        return [
+            {"text": message.content, "role": message.role.value}
+            for message in messages
+            if message.role is not MessageRole.SYSTEM
+        ]
+
+    async def generate(
+        self,
+        prompt: Union[str, list[Message]],
+        *,
+        system_prompt: str = "You are a helpful assistant.",
+        max_new_tokens: int = 256,
+    ) -> AsyncIterator[dict[str, Any]]:
+        """
+        Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer()
+        This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.
+
+        Args:
+            prompt: Either a single prompt string or a list of ragna messages
+            system_prompt: System prompt string
+            max_new_tokens: Max number of completion tokens (default 256)
+
+        Returns:
+            async streamed inference response string chunks
+        """
         # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
         # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
         # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
-        prompt, sources = (message := messages[-1]).content, message.sources
+
         async with self._call_api(
             "POST",
             f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
@@ -41,17 +74,23 @@ async def answer(
                 "numResults": 1,
                 "temperature": 0.0,
                 "maxTokens": max_new_tokens,
-                "messages": [
-                    {
-                        "text": prompt,
-                        "role": "user",
-                    }
-                ],
-                "system": self._make_system_content(sources),
+                "messages": self._render_prompt(prompt),
+                "system": system_prompt,
             },
         ) as stream:
             async for data in stream:
-                yield cast(str, data["outputs"][0]["text"])
+                yield data
+
+    async def answer(
+        self, messages: list[Message], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[str]:
+        message = messages[-1]
+        async for data in self.generate(
+            [message],
+            system_prompt=self._make_rag_system_content(message.sources),
+            max_new_tokens=max_new_tokens,
+        ):
+            yield cast(str, data["outputs"][0]["text"])
 
 
 # The Jurassic2Mid assistant receives a 500 internal service error from the remote
diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py
index 11d08e07..06183675 100644
--- a/ragna/assistants/_anthropic.py
+++ b/ragna/assistants/_anthropic.py
@@ -1,6 +1,6 @@
-from typing import AsyncIterator, cast
+from typing import Any, AsyncIterator, Union, cast
 
-from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source
+from ragna.core import Message, MessageRole, RagnaException, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -10,15 +10,11 @@ class AnthropicAssistant(HttpApiAssistant):
     _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
     _MODEL: str
 
-    @classmethod
-    def _extra_requirements(cls) -> list[Requirement]:
-        return [PackageRequirement("httpx_sse")]
-
     @classmethod
     def display_name(cls) -> str:
         return f"Anthropic/{cls._MODEL}"
 
-    def _instructize_system_prompt(self, sources: list[Source]) -> str:
+    def _make_rag_system_prompt(self, sources: list[Source]) -> str:
         # See https://docs.anthropic.com/claude/docs/system-prompts
         # See https://docs.anthropic.com/claude/docs/long-context-window-tips#tips-for-document-qa
         instruction = (
@@ -36,12 +32,45 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
             + ""
         )
 
-    async def answer(
-        self, messages: list[Message], *, max_new_tokens: int = 256
-    ) -> AsyncIterator[str]:
+    def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]:
+        """
+        Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.
+
+        Returns:
+            ordered list of dicts with 'content' and 'role' keys
+        """
+        if isinstance(prompt, str):
+            messages = [Message(content=prompt, role=MessageRole.USER)]
+        else:
+            messages = prompt
+        return [
+            {"role": message.role.value, "content": message.content}
+            for message in messages
+            if message.role is not MessageRole.SYSTEM
+        ]
+
+    async def generate(
+        self,
+        prompt: Union[str, list[Message]],
+        *,
+        system_prompt: str = "You are a helpful assistant.",
+        max_new_tokens: int = 256,
+    ) -> AsyncIterator[dict[str, Any]]:
+        """
+        Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer()
+        This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.
+
+        Args:
+            prompt: Either a single prompt string or a list of ragna messages
+            system_prompt: System prompt string
+            max_new_tokens: Max number of completion tokens (default 256)
+
+        Returns:
+            async streamed inference response string chunks
+        """
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
-        prompt, sources = (message := messages[-1]).content, message.sources
+
         async with self._call_api(
             "POST",
             "https://api.anthropic.com/v1/messages",
@@ -53,23 +82,34 @@ async def answer(
             },
             json={
                 "model": self._MODEL,
-                "system": self._instructize_system_prompt(sources),
-                "messages": [{"role": "user", "content": prompt}],
+                "system": system_prompt,
+                "messages": self._render_prompt(prompt),
                 "max_tokens": max_new_tokens,
                 "temperature": 0.0,
                 "stream": True,
             },
         ) as stream:
             async for data in stream:
-                # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
-                if "error" in data:
-                    raise RagnaException(data["error"].pop("message"), **data["error"])
-                elif data["type"] == "message_stop":
-                    break
-                elif data["type"] != "content_block_delta":
-                    continue
-
-                yield cast(str, data["delta"].pop("text"))
+                yield data
+
+    async def answer(
+        self, messages: list[Message], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[str]:
+        message = messages[-1]
+        async for data in self.generate(
+            [message],
+            system_prompt=self._make_rag_system_prompt(message.sources),
+            max_new_tokens=max_new_tokens,
+        ):
+            # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
+            if "error" in data:
+                raise RagnaException(data["error"].pop("message"), **data["error"])
+            elif data["type"] == "message_stop":
+                break
+            elif data["type"] != "content_block_delta":
+                continue
+
+            yield cast(str, data["delta"].pop("text"))
 
 
 class ClaudeOpus(AnthropicAssistant):
diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py
index e32acb82..4da28e69 100644
--- a/ragna/assistants/_cohere.py
+++ b/ragna/assistants/_cohere.py
@@ -1,6 +1,6 @@
-from typing import AsyncIterator, cast
+from typing import Any, AsyncIterator, Union, cast
 
-from ragna.core import Message, RagnaException, Source
+from ragna.core import Message, MessageRole, RagnaException, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -14,23 +14,59 @@ class CohereAssistant(HttpApiAssistant):
     def display_name(cls) -> str:
         return f"Cohere/{cls._MODEL}"
 
-    def _make_preamble(self) -> str:
+    def _make_rag_preamble(self) -> str:
         return (
             "You are a helpful assistant that answers user questions given the included context. "
             "If you don't know the answer, just say so. Don't try to make up an answer. "
             "Only use the included documents below to generate the answer."
         )
 
-    def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
+    def _make_rag_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
         return [{"title": source.id, "snippet": source.content} for source in sources]
 
-    async def answer(
-        self, messages: list[Message], *, max_new_tokens: int = 256
-    ) -> AsyncIterator[str]:
+    def _render_prompt(self, prompt: Union[str, list[Message]]) -> str:
+        """
+        Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.
+
+        Returns:
+            prompt string
+        """
+        if isinstance(prompt, str):
+            messages = [Message(content=prompt, role=MessageRole.USER)]
+        else:
+            messages = prompt
+
+        for message in reversed(messages):
+            if message.role is MessageRole.USER:
+                return message.content
+        else:
+            raise RagnaException
+
+    async def generate(
+        self,
+        prompt: Union[str, list[Message]],
+        source_documents: list[dict[str, str]],
+        *,
+        system_prompt: str = "You are a helpful assistant.",
+        max_new_tokens: int = 256,
+    ) -> AsyncIterator[dict[str, Any]]:
+        """
+        Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer()
+        This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.
+
+        Args:
+            prompt: Either a single prompt string or a list of ragna messages
+            system_prompt: System prompt string
+            source_documents: List of source content dicts with 'title' and 'snippet' keys
+            max_new_tokens: Max number of completion tokens (default 256)
+
+        Returns:
+            async streamed inference response string chunks
+        """
        # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
-        prompt, sources = (message := messages[-1]).content, message.sources
+
         async with self._call_api(
             "POST",
             "https://api.cohere.ai/v1/chat",
@@ -40,23 +76,35 @@ async def answer(
                 "authorization": f"Bearer {self._api_key}",
             },
             json={
-                "preamble_override": self._make_preamble(),
-                "message": prompt,
+                "preamble_override": system_prompt,
+                "message": self._render_prompt(prompt),
                 "model": self._MODEL,
                 "stream": True,
                 "temperature": 0.0,
                 "max_tokens": max_new_tokens,
-                "documents": self._make_source_documents(sources),
+                "documents": source_documents,
             },
         ) as stream:
-            async for event in stream:
-                if event["event_type"] == "stream-end":
-                    if event["event_type"] == "COMPLETE":
-                        break
-
-                    raise RagnaException(event["error_message"])
-                if "text" in event:
-                    yield cast(str, event["text"])
+            async for data in stream:
+                yield data
+
+    async def answer(
+        self, messages: list[Message], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[str]:
+        message = messages[-1]
+        async for data in self.generate(
+            prompt=message.content,
+            system_prompt=self._make_rag_preamble(),
+            source_documents=self._make_rag_source_documents(message.sources),
+            max_new_tokens=max_new_tokens,
+        ):
+            if data["event_type"] == "stream-end":
+                if data["event_type"] == "COMPLETE":
+                    break
+
+                raise RagnaException(data["error_message"])
+            if "text" in data:
+                yield cast(str, data["text"])
 
 
 class Command(CohereAssistant):
diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py
index 627afd80..95b094c7 100644
--- a/ragna/assistants/_google.py
+++ b/ragna/assistants/_google.py
@@ -1,6 +1,6 @@
-from typing import AsyncIterator
+from typing import Any, AsyncIterator, Union
 
-from ragna.core import Message, Source
+from ragna.core import Message, MessageRole, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -25,19 +25,39 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
             ]
         )
 
-    async def answer(
-        self, messages: list[Message], *, max_new_tokens: int = 256
-    ) -> AsyncIterator[str]:
-        prompt, sources = (message := messages[-1]).content, message.sources
+    def _render_prompt(self, prompt: Union[str, list[Message]]) -> list[dict]:
+        if isinstance(prompt, str):
+            messages = [Message(content=prompt, role=MessageRole.USER)]
+        else:
+            messages = prompt
+        return [
+            {"parts": [{"text": message.content, "role": message.role.value}]}
+            for message in messages
+            if message.role is not MessageRole.SYSTEM
+        ]
+
+    async def generate(
+        self, prompt: Union[str, list[Message]], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[dict[str, Any]]:
+        """
+        Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer()
+        This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.
+
+        Args:
+            prompt: Either a single prompt string or a list of ragna messages
+            max_new_tokens: Max number of completion tokens (default 256)
+
+        Returns:
+            async streamed inference response string chunks
+        """
+        # See https://ai.google.dev/api/generate-content#v1beta.models.streamGenerateContent
         async with self._call_api(
             "POST",
             f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
             params={"key": self._api_key},
             headers={"Content-Type": "application/json"},
             json={
-                "contents": [
-                    {"parts": [{"text": self._instructize_prompt(prompt, sources)}]}
-                ],
+                "contents": self._render_prompt(prompt),
                 # https://ai.google.dev/docs/safety_setting_gemini
                 "safetySettings": [
                     {
@@ -57,10 +77,20 @@ async def answer(
                     "maxOutputTokens": max_new_tokens,
                 },
             },
-            parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
+            parse_kwargs=dict(item="item"),
         ) as stream:
-            async for chunk in stream:
-                yield chunk
+            async for data in stream:
+                yield data
+
+    async def answer(
+        self, messages: list[Message], *, max_new_tokens: int = 256
+    ) -> AsyncIterator[str]:
+        message = messages[-1]
+        async for data in self.generate(
+            self._instructize_prompt(message.content, message.sources),
+            max_new_tokens=max_new_tokens,
+        ):
+            yield data["candidates"][0]["content"]["parts"][0]["text"]
 
 
 class GeminiPro(GoogleAssistant):
diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py
index adc794b8..0969ec35 100644
--- a/ragna/assistants/_http_api.py
+++ b/ragna/assistants/_http_api.py
@@ -97,6 +97,8 @@ async def stream() -> AsyncIterator[Any]:
 
         yield stream()
 
+        assert False
+
     @contextlib.asynccontextmanager
     async def _stream_jsonl(
         self,
@@ -115,6 +117,8 @@ async def stream() -> AsyncIterator[Any]:
 
         yield stream()
 
+        assert False
+
     # ijson does not support reading from an (async) iterator, but only from file-like
     # objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
     # See https://github.com/ICRAR/ijson/issues/44 for details.
@@ -158,6 +162,8 @@ async def stream() -> AsyncIterator[Any]:
 
         yield stream()
 
+        assert False
+
     async def _assert_api_call_is_success(self, response: httpx.Response) -> None:
         if response.is_success:
             return
diff --git a/ragna/assistants/_ollama.py b/ragna/assistants/_ollama.py
index 591c7ed1..aaee53af 100644
--- a/ragna/assistants/_ollama.py
+++ b/ragna/assistants/_ollama.py
@@ -32,17 +32,18 @@ def _url(self) -> str:
     async def answer(
         self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
-        prompt, sources = (message := messages[-1]).content, message.sources
-        async with self._call_openai_api(
-            prompt, sources, max_new_tokens=max_new_tokens
-        ) as stream:
-            async for data in stream:
-                # Modeled after
-                # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
-                if "error" in data:
-                    raise RagnaException(data["error"])
-                if not data["done"]:
-                    yield cast(str, data["message"]["content"])
+        message = messages[-1]
+        async for data in self.generate(
+            [message],
+            system_prompt=self._make_rag_system_content(message.sources),
+            max_new_tokens=max_new_tokens,
+        ):
+            # Modeled after
+            # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
+            if "error" in data:
+                raise RagnaException(data["error"])
+            if not data["done"]:
+                yield cast(str, data["message"]["content"])
 
 
 class OllamaGemma2B(OllamaAssistant):
diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py
index 0227beb8..7674393f 100644
--- a/ragna/assistants/_openai.py
+++ b/ragna/assistants/_openai.py
@@ -1,8 +1,8 @@
 import abc
 from functools import cached_property
-from typing import Any, AsyncContextManager, AsyncIterator, Optional, cast
+from typing import Any, AsyncIterator, Optional, Union, cast
 
-from ragna.core import Message, Source
+from ragna.core import Message, MessageRole, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -14,7 +14,7 @@ class OpenaiLikeHttpApiAssistant(HttpApiAssistant):
     @abc.abstractmethod
     def _url(self) -> str: ...
 
-    def _make_system_content(self, sources: list[Source]) -> str:
+    def _make_rag_system_content(self, sources: list[Source]) -> str:
         # See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
         instruction = (
             "You are an helpful assistants that answers user questions given the context below. "
@@ -23,11 +23,47 @@ def _make_system_content(self, sources: list[Source]) -> str:
         )
         return instruction + "\n\n".join(source.content for source in sources)
 
-    def _call_openai_api(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int
-    ) -> AsyncContextManager[AsyncIterator[dict[str, Any]]]:
-        # See https://platform.openai.com/docs/api-reference/chat/create
-        # and https://platform.openai.com/docs/api-reference/chat/streaming
+    def _render_prompt(
+        self, prompt: Union[str, list[Message]], system_prompt: str
+    ) -> list[dict]:
+        """
+        Ingests ragna messages-list or a single string prompt and converts to assistant-appropriate format.
+
+        Returns:
+            ordered list of dicts with 'content' and 'role' keys
+        """
+        if isinstance(prompt, str):
+            messages = [Message(content=prompt, role=MessageRole.USER)]
+        else:
+            messages = prompt
+        return [
+            {"role": "system", "content": system_prompt},
+            *(
+                {"role": message.role.value, "content": message.content}
+                for message in messages
+                if message.role is not MessageRole.SYSTEM
+            ),
+        ]
+
+    async def generate(
+        self,
+        prompt: Union[str, list[Message]],
+        *,
+        system_prompt: str = "You are a helpful assistant.",
+        max_new_tokens: int = 256,
+    ) -> AsyncIterator[dict[str, Any]]:
+        """
+        Primary method for calling assistant inference, either as a one-off request from anywhere in ragna, or as part of self.answer()
+        This method should be called for tasks like pre-processing, agentic tasks, or any other user-defined calls.
+
+        Args:
+            prompt: Either a single prompt string or a list of ragna messages
+            system_prompt: System prompt string
+            max_new_tokens: Max number of completion tokens (default 256)
+
+        Returns:
+            async streamed inference response string chunks
+        """
         headers = {
             "Content-Type": "application/json",
         }
@@ -35,16 +71,7 @@ def _call_openai_api(
             headers["Authorization"] = f"Bearer {self._api_key}"
 
         json_ = {
-            "messages": [
-                {
-                    "role": "system",
-                    "content": self._make_system_content(sources),
-                },
-                {
-                    "role": "user",
-                    "content": prompt,
-                },
-            ],
+            "messages": self._render_prompt(prompt, system_prompt),
             "temperature": 0.0,
             "max_tokens": max_new_tokens,
             "stream": True,
@@ -52,21 +79,25 @@ def _call_openai_api(
         if self._MODEL is not None:
             json_["model"] = self._MODEL
 
-        return self._call_api("POST", self._url, headers=headers, json=json_)
+        async with self._call_api(
+            "POST", self._url, headers=headers, json=json_
+        ) as stream:
+            async for data in stream:
+                yield data
 
     async def answer(
         self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
-        prompt, sources = (message := messages[-1]).content, message.sources
-        async with self._call_openai_api(
-            prompt, sources, max_new_tokens=max_new_tokens
-        ) as stream:
-            async for data in stream:
-                choice = data["choices"][0]
-                if choice["finish_reason"] is not None:
-                    break
-
-                yield cast(str, choice["delta"]["content"])
+        message = messages[-1]
+        async for data in self.generate(
+            [message],
+            system_prompt=self._make_rag_system_content(message.sources),
+            max_new_tokens=max_new_tokens,
+        ):
+            choice = data["choices"][0]
+            if choice["finish_reason"] is not None:
+                break
+            yield cast(str, choice["delta"]["content"])
 
 
 class OpenaiAssistant(OpenaiLikeHttpApiAssistant):
diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py
index f7c9c594..067b7269 100644
--- a/tests/assistants/test_api.py
+++ b/tests/assistants/test_api.py
@@ -82,7 +82,28 @@ def __init__(self, base_url):
         super().__init__()
         self._endpoint = f"{base_url}/{self._STREAMING_PROTOCOL.name.lower()}"
 
-    async def answer(self, messages):
+    # def generate(self, messages):
+    #     if self._STREAMING_PROTOCOL is HttpStreamingProtocol.JSON:
+    #         parse_kwargs = dict(item="item")
+    #     else:
+    #         parse_kwargs = dict()
+    #
+    #     return self._call_api(
+    #         "POST",
+    #         self._endpoint,
+    #         content=messages[-1].content,
+    #         parse_kwargs=parse_kwargs,
+    #     )
+    #
+    # async def answer(self, messages):
+    #     async with self.generate(messages) as stream:
+    #         async for data in stream:
+    #             if data.get("break"):
+    #                 break
+    #
+    #             yield data
+
+    async def generate(self, messages):
         if self._STREAMING_PROTOCOL is HttpStreamingProtocol.JSON:
             parse_kwargs = dict(item="item")
         else:
@@ -94,11 +115,15 @@ async def answer(self, messages):
             content=messages[-1].content,
             parse_kwargs=parse_kwargs,
         ) as stream:
-            async for chunk in stream:
-                if chunk.get("break"):
-                    break
+            async for data in stream:
+                yield data
+
+    async def answer(self, messages):
+        async for data in self.generate(messages):
+            if data.get("break"):
+                break
 
-                yield chunk
+            yield data
 
 
 @skip_on_windows
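
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the diff above): a one-off call
# to the new `generate` method outside of `answer`, as its docstrings describe.
# The concrete assistant class (`Gpt4`), direct construction with the API key
# taken from the environment, and the OpenAI-style chunk shape
# {"choices": [{"delta": {"content": ...}, "finish_reason": ...}]} are
# assumptions based on the diff; `generate` yields the raw provider payloads,
# so this kind of provider-specific parsing is the caller's responsibility.
# ---------------------------------------------------------------------------
import asyncio

from ragna.assistants import Gpt4  # assumed OpenaiLikeHttpApiAssistant subclass


async def main() -> None:
    assistant = Gpt4()  # assumes OPENAI_API_KEY is set in the environment

    # A plain string prompt is wrapped into a single user message by
    # `_render_prompt`; a list[Message] would be passed through instead.
    chunks = []
    async for data in assistant.generate(
        "Summarize the last release in one sentence.",
        system_prompt="You are a terse release-notes writer.",
        max_new_tokens=64,
    ):
        choice = data["choices"][0]
        if choice["finish_reason"] is not None:
            break
        chunks.append(choice["delta"]["content"])

    print("".join(chunks))


if __name__ == "__main__":
    asyncio.run(main())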