Skip to content

Commit

Permalink
fix(dspy) lint and ruff fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Anindyadeep committed May 13, 2024
1 parent 7a605d6 commit 0df0730
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 22 deletions.
70 changes: 70 additions & 0 deletions docs/api/language_model_clients/PremAI.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
---
sidebar_position: 5
---

# dsp.PremAI

[PremAI](https://app.premai.io) is a unified platform that lets you build powerful production-ready GenAI-powered applications with the least effort so that you can focus more on user experience and overall growth. With dspy, you can connect with several [best-in-class LLMs](https://models.premai.io/) of your choice with a single interface.

### Prerequisites

Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform: create your first project and grab your API key.

### Usage

Please make sure you have the `premai` Python SDK installed. If not, you can install it with the following command:

```bash
pip install -U premai
```

Here is a quick example of how to use the `premai` Python SDK with DSPy:

```python
from dspy import PremAI

llm = PremAI(model='mistral-tiny', project_id=123, api_key="your-premai-api-key")
print(llm("what is a large language model"))
```

> Please note: project ID 123 is just an example. You can find your actual project ID in the PremAI platform, under the project you created.
### Constructor

The constructor initializes the base class `LM` and verifies the `api_key` provided or defined through the `PREMAI_API_KEY` environment variable.

```python
class PremAI(LM):
    def __init__(
        self,
        model: str,
        project_id: int,
        api_key: Optional[str] = None,
        session_id: Optional[int] = None,
        **kwargs,
    ) -> None:
```

**Parameters:**

- `model` (_str_): Models supported by PremAI. Example: `mistral-tiny`. We recommend using the model selected in [project launchpad](https://docs.premai.io/get-started/launchpad).
- `project_id` (_int_): The [project id](https://docs.premai.io/get-started/projects) which contains the model of choice.
- `api_key` (_Optional[str]_, _optional_): API key issued by PremAI. Defaults to None; if not provided, it is read from the `PREMAI_API_KEY` environment variable.
- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps to track the chat history.
- `**kwargs`: Additional language model arguments will be passed to the API provider.

### Methods

#### `__call__(self, prompt: str, **kwargs) -> List[Dict[str, Any]]`

Retrieves completions from PremAI by calling `request`.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.

After generation, the completions are post-processed based on the `model_type` parameter.

**Parameters:**

- `prompt` (_str_): Prompt to send to PremAI.
- `**kwargs`: Additional keyword arguments for completion request. Example: parameters like `temperature`, `max_tokens` etc. You can find all the additional kwargs [here](https://docs.premai.io/get-started/sdk#optional-parameters).
5 changes: 4 additions & 1 deletion dsp/modules/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
or provider == "groq"
or provider == "Bedrock"
or provider == "Sagemaker"
or provider == "premai"
):
printed.append((prompt, x["response"]))
elif provider == "anthropic":
Expand Down Expand Up @@ -85,7 +86,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
if provider == "cohere" or provider == "Bedrock" or provider == "Sagemaker":
text = choices
elif provider == "openai" or provider == "ollama":
text = ' ' + self._get_choice_text(choices[0]).strip()
text = " " + self._get_choice_text(choices[0]).strip()
elif provider == "clarifai" or provider == "claude":
text = choices
elif provider == "groq":
Expand All @@ -98,6 +99,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
text = choices[0]
elif provider == "ibm":
text = choices
elif provider == "premai":
text = choices
else:
text = choices[0]["text"]
printing_value += self.print_green(text, end="")
Expand Down
50 changes: 29 additions & 21 deletions dsp/modules/premai.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
import logging
from typing import Optional
import os
from typing import Any, Optional

import backoff
import premai.errors

from dsp.modules.lm import LM

try:
import premai
except ImportError as err:
raise ImportError(
"Not loading Mistral AI because it is not installed. Install it with `pip install premai`.",
"Not loading Prem AI because it is not installed. Install it with `pip install -U premai`.",
) from err


# Set up logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class ChatPremAPIError(Exception):
"""Error with the `PremAI` API."""

Expand All @@ -31,7 +25,7 @@ def backoff_hdlr(details) -> None:
See more at: https://pypi.org/project/backoff/
"""
logger.info(
print(
"Backing off {wait:0.1f} seconds after {tries} tries calling function {target} with kwargs {kwargs}".format(
**details,
),
Expand All @@ -45,15 +39,24 @@ def giveup_hdlr(details) -> bool:
return True


def get_premai_api_key(api_key: Optional[str] = None) -> str:
    """Return the PremAI API key, preferring the explicit argument.

    When *api_key* is falsy, falls back to the ``PREMAI_API_KEY``
    environment variable. Raises ``RuntimeError`` when neither source
    provides a key.
    """
    if not api_key:
        # No usable key passed in; consult the environment instead.
        api_key = os.environ.get("PREMAI_API_KEY")
    if api_key is None:
        raise RuntimeError(
            "No API key found. See the quick start guide at https://docs.premai.io/introduction to get your API key.",
        )
    return api_key


class PremAI(LM):
"""Wrapper around Prem AI's API."""

def __init__(
self,
model: str,
project_id: int,
api_key: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
session_id: Optional[int] = None,
**kwargs,
) -> None:
Expand All @@ -63,9 +66,10 @@ def __init__(
The name of model name
project_id: int
"The project ID in which the experiments or deployments are carried out. can find all your projects here: https://app.premai.io/projects/"
api_key: str
Prem AI API key, to connect with the API.
session_id: int
api_key: Optional[str]
Prem AI API key, to connect with the API. If not provided then it will check from env var by the name
PREMAI_API_KEY
session_id: Optional[int]
The ID of the session to use. It helps to track the chat history.
**kwargs: dict
Additional arguments to pass to the API provider
Expand All @@ -76,11 +80,10 @@ def __init__(
self.project_id = project_id
self.session_id = session_id

if base_url is not None:
self.client = premai.Prem(api_key=api_key, base_url=base_url)
else:
self.client = premai.Prem(api_key=api_key)
api_key = get_premai_api_key(api_key=api_key)
self.client = premai.Prem(api_key=api_key)
self.provider = "premai"
self.history: list[dict[str, Any]] = []

self.kwargs = {
"model": model,
Expand All @@ -89,7 +92,7 @@ def __init__(
**kwargs,
}
if session_id is not None:
kwargs["sesion_id"] = session_id
kwargs["session_id"] = session_id

def _get_all_kwargs(self, **kwargs) -> dict:
other_kwargs = {
Expand Down Expand Up @@ -134,7 +137,12 @@ def basic_request(self, prompt, **kwargs) -> str:
if not response.choices:
raise ChatPremAPIError("ChatResponse must have at least one candidate")

return response.choices[0].message.content or ""
content = response.choices[0].message.content
output_text = content or ""

self.history.append({"prompt": prompt, "response": content, "kwargs": all_kwargs, "raw_kwargs": kwargs})

return output_text

@backoff.on_exception(
backoff.expo,
Expand Down

0 comments on commit 0df0730

Please sign in to comment.