Skip to content

Commit

Permalink
Merge pull request stanfordnlp#37 from stanfordnlp/enhancement/azure_…
Browse files Browse the repository at this point in the history
…support_for_openai_models

enhancement(GPT-LLM): azure API support
  • Loading branch information
okhat authored Apr 8, 2023
2 parents a55faad + c898dff commit 1c5e734
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions dsp/modules/gpt3.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import functools
import json
from typing import Optional, Any, cast
from typing import Any, Literal, Optional, cast

import backoff
import openai
import openai.error
from openai.openai_object import OpenAIObject
import backoff

from dsp.modules.lm import LM
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
from dsp.modules.lm import LM


def backoff_hdlr(details):
Expand All @@ -20,21 +21,43 @@ def backoff_hdlr(details):


class GPT3(LM):
"""Wrapper around OpenAI's GPT-3 API.
Currently supported models include `davinci`, `curie`, `babbage`, `ada`, `gpt-3.5-turbo`, and `gpt-3.5-turbo-0301`.
"""Wrapper around OpenAI's GPT API. Supports both the OpenAI and Azure APIs.
Args:
model (str, optional): OpenAI or Azure supported LLM model to use. Defaults to "text-davinci-002".
api_key (Optional[str], optional): API provider Authentication token. use Defaults to None.
api_provider (Literal["openai", "azure"], optional): The API provider to use. Defaults to "openai".
model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text".
**kwargs: Additional arguments to pass to the API provider.
"""

def __init__(
self, model: str = "text-davinci-002", api_key: Optional[str] = None, **kwargs
self,
model: str = "text-davinci-002",
api_key: Optional[str] = None,
api_provider: Literal["openai", "azure"] = "openai",
model_type: Literal["chat", "text"] = "text",
**kwargs,
):
super().__init__(model)
self.provider = "openai"
self.model_type = model_type

if api_provider == "azure":
assert (
"engine" in kwargs or "deployment_id" in kwargs
), "Must specify engine or deployment_id for Azure API instead of model."
assert "api_version" in kwargs, "Must specify api_version for Azure API"
assert "api_base" in kwargs, "Must specify api_base for Azure API"
openai.api_type = "azure"
openai.api_base = kwargs["api_base"]
if kwargs.get("api_version"):
openai.api_version = kwargs["api_version"]

if api_key:
openai.api_key = api_key

self.kwargs = {
"model": model,
"temperature": 0.0,
"max_tokens": 150,
"top_p": 1,
Expand All @@ -43,14 +66,15 @@ def __init__(
"n": 1,
**kwargs,
} # TODO: add kwargs above for </s>

if self.provider == "openai":
self.kwargs["model"] = model
self.history: list[dict[str, Any]] = []

def basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
raw_kwargs = kwargs

kwargs = {**self.kwargs, **kwargs}
if kwargs["model"] in ("gpt-3.5-turbo", "gpt-3.5-turbo-0301"):
if self.model_type == "chat":
# caching mechanism requires hashable kwargs
kwargs["messages"] = json.dumps([{"role": "user", "content": prompt}])
response = cached_gpt3_turbo_request(**kwargs)
Expand Down Expand Up @@ -79,7 +103,7 @@ def request(self, prompt: str, **kwargs) -> OpenAIObject:
return self.basic_request(prompt, **kwargs)

def _get_choice_text(self, choice: dict[str, Any]) -> str:
if self.kwargs["model"] in ("gpt-3.5-turbo", "gpt-3.5-turbo-0301"):
if self.model_type == "chat":
return choice["message"]["content"]
return choice["text"]

Expand All @@ -105,7 +129,7 @@ def __call__(
assert return_sorted is False, "for now"

if kwargs.get("n", 1) > 1:
if self.kwargs["model"] in ("gpt-3.5-turbo", "gpt-3.5-turbo-0301"):
if self.model_type == "chat":
kwargs = {**kwargs}
else:
kwargs = {**kwargs, "logprobs": 5}
Expand Down

0 comments on commit 1c5e734

Please sign in to comment.