Skip to content

Commit

Permalink
fix: small completion typo and updated model type
Browse files Browse the repository at this point in the history
  • Loading branch information
KCaverly committed Feb 21, 2024
1 parent 69f201c commit bf10630
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/language_models_client.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class AzureOpenAI(LM):
- `api_base` (str): Azure Base URL.
- `api_version` (str): Version identifier for Azure OpenAI API.
- `api_key` (_Optional[str]_, _optional_): API provider authentication token. Retrieves from `AZURE_OPENAI_KEY` environment variable if None.
- `model_type` (_Literal["chat", "text"]_): Specified model type to use.
- `model_type` (_Literal["chat", "text"]_): Specified model type to use, defaults to 'chat'.
- `**kwargs`: Additional language model arguments to pass to the API provider.

### Methods
Expand Down
15 changes: 4 additions & 11 deletions dsp/modules/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class AzureOpenAI(LM):
api_version (str): Version identifier for API.
model (str, optional): OpenAI or Azure supported LLM model to use. Defaults to "text-davinci-002".
api_key (Optional[str], optional): API provider Authentication token. use Defaults to None.
model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text".
model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "chat".
**kwargs: Additional arguments to pass to the API provider.
"""

Expand All @@ -64,7 +64,7 @@ def __init__(
api_version: str,
model: str = "gpt-3.5-turbo-instruct",
api_key: Optional[str] = None,
model_type: Literal["chat", "text"] = None,
model_type: Literal["chat", "text"] = "chat",
**kwargs,
):
super().__init__(model)
Expand Down Expand Up @@ -93,14 +93,7 @@ def __init__(

self.client = client

# Define model type
default_model_type = (
"chat"
if ("gpt-3.5" in model or "turbo" in model or "gpt-4" in model)
and ("instruct" not in model)
else "text"
)
self.model_type = model_type if model_type else default_model_type
self.model_type = model_type

if not OPENAI_LEGACY and "model" not in kwargs:
if "deployment_id" in kwargs:
Expand Down Expand Up @@ -268,7 +261,7 @@ def v1_cached_gpt3_turbo_request_v2_wrapped(**kwargs):
def v1_cached_gpt3_turbo_request_v2(**kwargs):
if "stringify_request" in kwargs:
kwargs = json.loads(kwargs["stringify_request"])
return client.chat.completion.create(**kwargs)
return client.chat.completions.create(**kwargs)

return v1_cached_gpt3_turbo_request_v2(**kwargs)

Expand Down

0 comments on commit bf10630

Please sign in to comment.