Skip to content

Commit

Permalink
updated backends
Browse files Browse the repository at this point in the history
  • Loading branch information
OKUA1 committed May 20, 2024
1 parent eaf7622 commit da78c9e
Show file tree
Hide file tree
Showing 13 changed files with 86 additions and 24 deletions.
Binary file removed logo.png
Binary file not shown.
28 changes: 28 additions & 0 deletions skllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
_AZURE_API_BASE_VAR = "SKLLM_CONFIG_AZURE_API_BASE"
_AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
_GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
_GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"


class SKLLMConfig:
Expand Down Expand Up @@ -142,3 +143,30 @@ def set_google_project(project: str) -> None:
Google Cloud project ID.
"""
os.environ[_GOOGLE_PROJECT] = project

@staticmethod
def set_gpt_url(url: str) -> None:
    """Sets the GPT URL.

    The value is written to the process environment under the
    ``_GPT_URL_VAR`` key.

    Parameters
    ----------
    url : str
        GPT URL.
    """
    os.environ[_GPT_URL_VAR] = url

@staticmethod
def get_gpt_url() -> Optional[str]:
    """Looks up the GPT URL in the process environment.

    Returns
    -------
    Optional[str]
        The value stored under ``_GPT_URL_VAR``, or ``None`` when that
        environment variable is not set.
    """
    # ``Mapping.get`` already falls back to ``None`` for a missing key.
    configured_url = os.environ.get(_GPT_URL_VAR)
    return configured_url

@staticmethod
def reset_gpt_url() -> None:
    """Resets the GPT URL.

    Removes the ``_GPT_URL_VAR`` entry from the process environment.
    Nothing happens (and no error is raised) when it is not set.
    """
    # The ``None`` default makes ``pop`` a no-op for a missing key.
    os.environ.pop(_GPT_URL_VAR, None)
4 changes: 2 additions & 2 deletions skllm/llm/gpt/clients/openai/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ def get_chat_completion(
-------
completion : dict
"""
if api == "openai":
if api in ("openai", "custom_url"):
client = set_credentials(key, org)
elif api == "azure":
client = set_azure_credentials(key, org)
else:
raise ValueError("Invalid API")
model_dict = {"model": model}
if json_response and model in ["gpt-4-1106-preview", "gpt-3.5-turbo-1106"]:
if json_response and api == "openai":
model_dict["response_format"] = {"type": "json_object"}
completion = client.chat.completions.create(
temperature=0.0, messages=messages, **model_dict
Expand Down
4 changes: 3 additions & 1 deletion skllm/llm/gpt/clients/openai/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from skllm.config import SKLLMConfig as _Config
from time import sleep
from openai import OpenAI, AzureOpenAI
from skllm.config import SKLLMConfig as _Config


def set_credentials(key: str, org: str) -> None:
Expand All @@ -14,7 +15,8 @@ def set_credentials(key: str, org: str) -> None:
org : str
The OPEN AI organization ID to use.
"""
client = OpenAI(api_key=key, organization=org)
url = _Config.get_gpt_url()
client = OpenAI(api_key=key, organization=org, base_url=url)
return client


Expand Down
10 changes: 8 additions & 2 deletions skllm/llm/gpt/clients/openai/embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from skllm.llm.gpt.clients.openai.credentials import set_credentials
from skllm.llm.gpt.clients.openai.credentials import set_credentials, set_azure_credentials
from skllm.utils import retry
import openai
from openai import OpenAI
Expand All @@ -10,6 +10,7 @@ def get_embedding(
key: str,
org: str,
model: str = "text-embedding-ada-002",
api: str = "openai"
):
"""
Encodes a string and return the embedding for a string.
Expand All @@ -26,13 +27,18 @@ def get_embedding(
The model to use. Defaults to "text-embedding-ada-002".
max_retries : int, optional
The maximum number of retries to use. Defaults to 3.
api: str, optional
The API to use. Must be one of "openai", "custom_url" or "azure". Defaults to "openai".
Returns
-------
emb : list
The GPT embedding for the string.
"""
client = set_credentials(key, org)
if api in ("openai", "custom_url"):
client = set_credentials(key, org)
elif api == "azure":
client = set_azure_credentials(key, org)
text = [str(t).replace("\n", " ") for t in text]
embeddings = []
emb = client.embeddings.create(input=text, model=model)
Expand Down
17 changes: 11 additions & 6 deletions skllm/llm/gpt/completion.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import warnings
from skllm.llm.gpt.clients.openai.completion import (
get_chat_completion as _oai_get_chat_completion,
)
from skllm.llm.gpt.clients.gpt4all.completion import (
get_chat_completion as _g4a_get_chat_completion,
)

from skllm.llm.gpt.utils import split_to_api_and_model
from skllm.config import SKLLMConfig as _Config

def get_chat_completion(
messages: dict,
Expand All @@ -14,12 +16,15 @@ def get_chat_completion(
json_response: bool = False,
):
"""Gets a chat completion from the OpenAI compatible API."""
if model.startswith("gpt4all::"):
return _g4a_get_chat_completion(messages, model[9:])
api, model = split_to_api_and_model(model)
if api == "gpt4all":
return _g4a_get_chat_completion(messages, model)
else:
api = "azure" if model.startswith("azure::") else "openai"
if api == "azure":
model = model[7:]
url = _Config.get_gpt_url()
if api == "openai" and url is not None:
warnings.warn(f"You are using the OpenAI backend with a custom URL: {url}; did you mean to use the `custom_url` backend?\nTo use the OpenAI backend, please remove the custom URL using `SKLLMConfig.reset_gpt_url()`.")
elif api == "custom_url" and url is None:
raise ValueError("You are using the `custom_url` backend but no custom URL was provided. Please set it using `SKLLMConfig.set_gpt_url(<url>)`.")
return _oai_get_chat_completion(
messages,
openai_key,
Expand Down
9 changes: 4 additions & 5 deletions skllm/llm/gpt/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from skllm.llm.gpt.clients.openai.embedding import get_embedding as _oai_get_embedding

from skllm.llm.gpt.utils import split_to_api_and_model

def get_embedding(
text: str,
Expand All @@ -26,8 +26,7 @@ def get_embedding(
emb : list
The GPT embedding for the string.
"""
if model.startswith("gpt4all::"):
api, model = split_to_api_and_model(model)
if api == ("gpt4all"):
raise ValueError("GPT4All is not supported for embeddings")
elif model.startswith("azure::"):
raise ValueError("Azure is not supported for embeddings")
return _oai_get_embedding(text, key, org, model)
return _oai_get_embedding(text, key, org, model, api=api)
2 changes: 1 addition & 1 deletion skllm/llm/gpt/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _get_embeddings(self, text: np.ndarray) -> List[List[float]]:

# for now this works only with OpenAI
class GPTTunableMixin(BaseTunableMixin):
_supported_tunable_models = ["gpt-3.5-turbo-0613", "gpt-3.5-turbo"]
_supported_tunable_models = ["gpt-3.5-turbo-0125", "gpt-3.5-turbo"]

def _build_label(self, label: str):
return json.dumps({"label": label})
Expand Down
12 changes: 12 additions & 0 deletions skllm/llm/gpt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Tuple

SUPPORTED_APIS = ["openai", "azure", "gpt4all", "custom_url"]


def split_to_api_and_model(model: str) -> Tuple[str, str]:
    """Split a ``"<api>::<model>"`` identifier into its two parts.

    Parameters
    ----------
    model : str
        Model identifier, optionally prefixed with ``"<api>::"``.

    Returns
    -------
    Tuple[str, str]
        ``(api, model_name)``. An identifier without ``"::"`` is returned
        unchanged with ``"openai"`` as the API.

    Raises
    ------
    ValueError
        If the text before the first ``"::"`` is not in ``SUPPORTED_APIS``.
    """
    # Only the first "::" separates the API; later ones stay in the model name.
    prefix, separator, remainder = model.partition("::")
    if not separator:
        return "openai", model
    if prefix in SUPPORTED_APIS:
        return prefix, remainder
    raise ValueError(f"Unsupported API: {prefix}")
12 changes: 11 additions & 1 deletion skllm/llm/vertex/completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from skllm.utils import retry
from vertexai.preview.language_models import ChatModel, TextGenerationModel
from vertexai.language_models import ChatModel, TextGenerationModel
from vertexai.generative_models import GenerativeModel, GenerationConfig


@retry(max_retries=3)
Expand All @@ -18,3 +19,12 @@ def get_completion_chat_mode(model: str, context: str, text: str):
chat = model_instance.start_chat(context=context)
response = chat.send_message(text, temperature=0.0)
return response.text


@retry(max_retries=3)
def get_completion_chat_gemini(model: str, context: str, text: str):
    """Generate a completion with a ``GenerativeModel`` at temperature 0.0.

    Parameters
    ----------
    model : str
        Model name handed to ``GenerativeModel``.
    context : str
        Passed to ``GenerativeModel`` as ``system_instruction``.
    text : str
        Content passed to ``generate_content``.

    Returns
    -------
    The ``text`` attribute of the response returned by ``generate_content``.
    """
    gemini = GenerativeModel(model, system_instruction=context)
    zero_temperature = GenerationConfig(temperature=0.0)
    reply = gemini.generate_content(text, generation_config=zero_temperature)
    return reply.text
8 changes: 4 additions & 4 deletions skllm/llm/vertex/mixin.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
from typing import Optional, Union, List, Any, Dict, Mapping
from skllm.config import SKLLMConfig as _Config
from typing import Optional, Union, List, Any, Dict
from skllm.llm.base import (
BaseClassifierMixin,
BaseEmbeddingMixin,
BaseTextCompletionMixin,
BaseTunableMixin,
)
from skllm.llm.vertex.tuning import tune
from skllm.llm.vertex.completion import get_completion_chat_mode, get_completion
from skllm.llm.vertex.completion import get_completion_chat_mode, get_completion, get_completion_chat_gemini
from skllm.utils import extract_json_key
import numpy as np
from tqdm import tqdm
import pandas as pd


Expand All @@ -34,6 +32,8 @@ def _get_chat_completion(
raise ValueError("Only messages as strings are supported.")
if model.startswith("chat-"):
completion = get_completion_chat_mode(model, system_message, messages)
elif model.startswith("gemini-"):
completion = get_completion_chat_gemini(model, system_message, messages)
else:
completion = get_completion(model, messages)
return str(completion)
Expand Down
2 changes: 1 addition & 1 deletion skllm/llm/vertex/tuning.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pandas import DataFrame
from vertexai.preview.language_models import TextGenerationModel
from vertexai.language_models import TextGenerationModel


def tune(model: str, data: DataFrame, train_steps: int = 100):
Expand Down
2 changes: 1 addition & 1 deletion skllm/models/gpt/vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class GPTVectorizer(_BaseVectorizer, _GPTEmbeddingMixin):
def __init__(
self,
model: str = "text-embedding-ada-002",
model: str = "text-embedding-3-small",
batch_size: int = 1,
key: Optional[str] = None,
org: Optional[str] = None,
Expand Down

0 comments on commit da78c9e

Please sign in to comment.