Skip to content

Commit

Permalink
PaLM + Azure integrations
Browse files Browse the repository at this point in the history
Co-authored-by: Iryna Kondrashchenko <[email protected]>
  • Loading branch information
OKUA1 and iryna-kondr committed Jul 2, 2023
1 parent 54c70f4 commit 78e6eb0
Show file tree
Hide file tree
Showing 19 changed files with 829 additions and 417 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ cython_debug/
test.py
tmp.ipynb
tmp.py
*.pickle
6 changes: 3 additions & 3 deletions skllm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ordering is important here to prevent circular imports
from skllm.models.gpt_zero_shot_clf import (
from skllm.models.gpt.gpt_zero_shot_clf import (
MultiLabelZeroShotGPTClassifier,
ZeroShotGPTClassifier,
)
from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
from skllm.models.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
from skllm.models.gpt.gpt_few_shot_clf import FewShotGPTClassifier
from skllm.models.gpt.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
15 changes: 10 additions & 5 deletions skllm/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@


def get_chat_completion(
messages: dict, openai_key: str=None, openai_org: str=None, model: str="gpt-3.5-turbo", max_retries: int=3
messages: dict,
openai_key: str = None,
openai_org: str = None,
model: str = "gpt-3.5-turbo",
max_retries: int = 3,
):
"""
Gets a chat completion from the OpenAI API.
"""
"""Gets a chat completion from the OpenAI API."""
if model.startswith("gpt4all::"):
return _g4a_get_chat_completion(messages, model[9:])
else:
api = "azure" if model.startswith("azure::") else "openai"
if api == "azure":
model = model[7:]
return _oai_get_chat_completion(
messages, openai_key, openai_org, model, max_retries
messages, openai_key, openai_org, model, max_retries, api=api
)
110 changes: 105 additions & 5 deletions skllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,121 @@

_OPENAI_KEY_VAR = "SKLLM_CONFIG_OPENAI_KEY"
_OPENAI_ORG_VAR = "SKLLM_CONFIG_OPENAI_ORG"
_AZURE_API_BASE_VAR = "SKLLM_CONFIG_AZURE_API_BASE"
_AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
_GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"

class SKLLMConfig():

class SKLLMConfig:
@staticmethod
def set_openai_key(key: str) -> None:
"""Sets the OpenAI key.
Parameters
----------
key : str
OpenAI key.
"""
os.environ[_OPENAI_KEY_VAR] = key

@staticmethod
def get_openai_key() -> Optional[str]:
"""Gets the OpenAI key.
Returns
-------
Optional[str]
OpenAI key.
"""
return os.environ.get(_OPENAI_KEY_VAR, None)

@staticmethod
def set_openai_org(key: str) -> None:
"""Sets OpenAI organization ID.
Parameters
----------
key : str
OpenAI organization ID.
"""
os.environ[_OPENAI_ORG_VAR] = key

@staticmethod
def get_openai_org() -> Optional[str]:
return os.environ.get(_OPENAI_ORG_VAR, None)
def get_openai_org() -> str:
"""Gets the OpenAI organization ID.
Returns
-------
str
OpenAI organization ID.
"""
return os.environ.get(_OPENAI_ORG_VAR, "")

@staticmethod
def get_azure_api_base() -> str:
"""Gets the API base for Azure.
Returns
-------
str
URL to be used as the base for the Azure API.
"""
base = os.environ.get(_AZURE_API_BASE_VAR, None)
if base is None:
raise RuntimeError("Azure API base is not set")
return base

@staticmethod
def set_azure_api_base(base: str) -> None:
"""Set the API base for Azure.
Parameters
----------
base : str
URL to be used as the base for the Azure API.
"""
os.environ[_AZURE_API_BASE_VAR] = base

@staticmethod
def set_azure_api_version(ver: str) -> None:
"""Set the API version for Azure.
Parameters
----------
ver : str
Azure API version.
"""
os.environ[_AZURE_API_VERSION_VAR] = ver

@staticmethod
def get_azure_api_version() -> str:
"""Gets the API version for Azure.
Returns
-------
str
Azure API version.
"""
return os.environ.get(_AZURE_API_VERSION_VAR, "2023-05-15")

@staticmethod
def get_google_project() -> Optional[str]:
"""Gets the Google Cloud project ID.
Returns
-------
Optional[str]
Google Cloud project ID.
"""
return os.environ.get(_GOOGLE_PROJECT, None)

@staticmethod
def set_google_project(project: str) -> None:
"""Sets the Google Cloud project ID.
Parameters
----------
project : str
Google Cloud project ID.
"""
os.environ[_GOOGLE_PROJECT] = project
41 changes: 41 additions & 0 deletions skllm/google/completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from time import sleep

from vertexai.preview.language_models import ChatModel, TextGenerationModel

# TODO reduce code duplication for retrying logic


def get_completion(model: str, text: str, max_retries: int = 3):
    """Gets a text completion from a Vertex AI text generation model.

    Parameters
    ----------
    model : str
        Model name. Names starting with "text-" are loaded with
        ``TextGenerationModel.from_pretrained``; any other name is loaded
        with ``TextGenerationModel.get_tuned_model``.
    text : str
        Prompt passed to the model's ``predict`` call.
    max_retries : int, optional
        Maximum number of attempts, by default 3.

    Returns
    -------
    Optional[str]
        Text of the response, or None when no attempt succeeded. In that
        case a message describing the last error is printed instead of an
        exception being raised.
    """
    # Placeholders so the final message can still be built when the loop
    # body never runs (max_retries <= 0).
    error_type = "NoAttempt"
    error_msg = "max_retries is less than 1"
    for attempt in range(max_retries):
        try:
            # The loaded model gets its own name: rebinding `model` would make
            # the next retry call `.startswith` on a model object and replace
            # the real error with an AttributeError.
            if model.startswith("text-"):
                llm = TextGenerationModel.from_pretrained(model)
            else:
                llm = TextGenerationModel.get_tuned_model(model)
            response = llm.predict(text, temperature=0.0)
            return response.text
        except Exception as e:
            error_msg = str(e)
            error_type = type(e).__name__
            # Back off only when another attempt will follow.
            if attempt < max_retries - 1:
                sleep(3)
    print(
        f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
        f" {error_msg}`"
    )
    return None


def get_completion_chat_mode(model: str, context: str, text: str, max_retries: int = 3):
    """Gets a completion from a Vertex AI chat model.

    Parameters
    ----------
    model : str
        Model name passed to ``ChatModel.from_pretrained``.
    context : str
        Context passed to ``start_chat`` when the chat session is created.
    text : str
        Message sent to the chat session.
    max_retries : int, optional
        Maximum number of attempts, by default 3.

    Returns
    -------
    Optional[str]
        Text of the response, or None when no attempt succeeded. In that
        case a message describing the last error is printed instead of an
        exception being raised.
    """
    # Placeholders so the final message can still be built when the loop
    # body never runs (max_retries <= 0).
    error_type = "NoAttempt"
    error_msg = "max_retries is less than 1"
    for attempt in range(max_retries):
        try:
            # The loaded model gets its own name: rebinding `model` would make
            # the next retry pass a model object to `from_pretrained` and hide
            # the real error behind a secondary one.
            chat_model = ChatModel.from_pretrained(model)
            chat = chat_model.start_chat(context=context)
            response = chat.send_message(text, temperature=0.0)
            return response.text
        except Exception as e:
            error_msg = str(e)
            error_type = type(e).__name__
            # Back off only when another attempt will follow.
            if attempt < max_retries - 1:
                sleep(3)
    print(
        f"Could not obtain the completion after {max_retries} retries: `{error_type} ::"
        f" {error_msg}`"
    )
    return None
13 changes: 13 additions & 0 deletions skllm/google/tuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pandas import DataFrame
from vertexai.preview.language_models import TextGenerationModel


def tune(model: str, data: DataFrame, train_steps: int = 100):
    """Starts tuning of a pretrained Vertex AI text generation model.

    Parameters
    ----------
    model : str
        Name passed to ``TextGenerationModel.from_pretrained``.
    data : DataFrame
        Training data forwarded to ``tune_model`` as ``training_data``.
    train_steps : int, optional
        Number of training steps, by default 100.

    Returns
    -------
    The model object on which ``tune_model`` was called.
    """
    tuning_options = {
        "training_data": data,
        "train_steps": train_steps,
        # the only supported training location atm
        "tuning_job_location": "europe-west4",
        # the only supported deployment location atm
        "tuned_model_location": "us-central1",
    }
    base_model = TextGenerationModel.from_pretrained(model)
    base_model.tune_model(**tuning_options)
    return base_model  # ._job
Loading

0 comments on commit 78e6eb0

Please sign in to comment.