Update OpenAITextGenerator based on new openai package
vignesh-arivazhagan committed Jan 11, 2024
1 parent c7ee01e commit 4181e17
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions llmx/generators/text/openai_textgen.py
@@ -4,6 +4,7 @@
 from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages
 import os
 import openai
+from openai import OpenAI
 from dataclasses import asdict
 
 
@@ -14,7 +15,6 @@ def __init__(
         provider: str = "openai",
         organization: str = None,
         api_type: str = None,
-        api_base: str = None,
         api_version: str = None,
         model: str = None,
         models: Dict = None,
@@ -34,6 +34,8 @@ def __init__(
         if api_type:
             openai.api_type = api_type
 
+        self.client = OpenAI()
+
         self.model_name = model or "gpt-3.5-turbo"
 
         self.model_max_token_dict = get_models_maxtoken_dict(models)
@@ -48,7 +50,9 @@ def generate(
         use_cache = config.use_cache
         model = config.model or self.model_name
         prompt_tokens = num_tokens_from_messages(messages)
-        max_tokens = max(self.model_max_token_dict.get(model, 4096) - prompt_tokens - 10, 200)
+        max_tokens = max(
+            self.model_max_token_dict.get(model, 4096) - prompt_tokens - 10, 200
+        )
 
         oai_config = {
             "model": model,
@@ -71,10 +75,10 @@ def generate(
         if response:
            return TextGenerationResponse(**response)
 
-        oai_response = openai.ChatCompletion.create(**oai_config)
+        oai_response = self.client.chat.completions.create(**oai_config)
 
         response = TextGenerationResponse(
-            text=[Message(**x.message) for x in oai_response.choices],
+            text=[Message(**x.message.model_dump()) for x in oai_response.choices],
             logprobs=[],
             config=oai_config,
             usage=dict(oai_response.usage),

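For context on the change above: the pre-1.0 openai package exposed module-level call sites such as openai.ChatCompletion.create, while openai 1.x uses an explicit client object whose responses are pydantic models. Below is a minimal standalone sketch of that newer calling convention, separate from the commit itself; the model name and prompt are placeholders, and OPENAI_API_KEY is assumed to be set in the environment.

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=50,
)

# Responses are pydantic objects rather than plain dicts, which is why the
# diff converts each choice's message with .model_dump() before unpacking it.
print(completion.choices[0].message.model_dump()["content"])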