Skip to content

Commit

Permalink
removed adding model as a required argument, so that it aligns with p…
Browse files Browse the repository at this point in the history
…rem's workflow
  • Loading branch information
Anindyadeep committed May 16, 2024
1 parent 47f9199 commit 62eaf88
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions dsp/modules/premai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
try:
import premai

premai_error = premai.errors.UnexpectedStatus
premai_api_error = premai.errors.UnexpectedStatus
except ImportError:
premai_api_error = Exception
except AttributeError:
Expand Down Expand Up @@ -49,18 +49,18 @@ class PremAI(LM):

def __init__(
self,
model: str,
project_id: int,
model: Optional[str] = None,
api_key: Optional[str] = None,
session_id: Optional[int] = None,
**kwargs,
) -> None:
"""Parameters
model: str
The name of model name
project_id: int
"The project ID in which the experiments or deployments are carried out. can find all your projects here: https://app.premai.io/projects/"
model: Optional[str]
The name of model deployed on launchpad. When None, it will show 'default'
api_key: Optional[str]
Prem AI API key, to connect with the API. If not provided then it will check from env var by the name
PREMAI_API_KEY
Expand All @@ -69,6 +69,7 @@ def __init__(
**kwargs: dict
Additional arguments to pass to the API provider
"""
model = "default" if model is None else model
super().__init__(model)
if premai_api_error == Exception:
raise ImportError(
Expand All @@ -85,13 +86,18 @@ def __init__(
self.history: list[dict[str, Any]] = []

self.kwargs = {
"model": model,
"temperature": 0.17,
"max_tokens": 150,
**kwargs,
}
if session_id is not None:
kwargs["session_id"] = session_id
self.kwargs["session_id"] = session_id

# However this is not recommended to change the model once
# deployed from launchpad

if model != "default":
self.kwargs["model"] = model

def _get_all_kwargs(self, **kwargs) -> dict:
other_kwargs = {
Expand All @@ -111,7 +117,6 @@ def _get_all_kwargs(self, **kwargs) -> dict:
"frequency_penalty",
"presence_penalty",
"tools",
"model",
]

for key in _keys_that_cannot_be_none:
Expand All @@ -122,15 +127,15 @@ def _get_all_kwargs(self, **kwargs) -> dict:
def basic_request(self, prompt, **kwargs) -> str:
"""Handles retrieval of completions from Prem AI whilst handling API errors."""
all_kwargs = self._get_all_kwargs(**kwargs)
message = []
messages = []

if "system_prompt" in all_kwargs:
message.append({"role": "system", "content": all_kwargs["system_prompt"]})
message.append({"role": "user", "content": prompt})
messages.append({"role": "system", "content": all_kwargs["system_prompt"]})
messages.append({"role": "user", "content": prompt})

response = self.client.chat.completions.create(
project_id=self.project_id,
messages=message,
messages=messages,
**all_kwargs,
)
if not response.choices:
Expand Down

0 comments on commit 62eaf88

Please sign in to comment.