Skip to content

Commit

Permalink
feature(new-LM): GPT-3.5-turbo (ChatCompletion) integration
Browse files Browse the repository at this point in the history
  • Loading branch information
lawliet19189 committed Mar 12, 2023
1 parent e1d01b0 commit 64d679b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
46 changes: 38 additions & 8 deletions dsp/modules/gpt3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
from typing import Optional, Any, cast
import json
from typing import Optional, Any
import openai
import openai.error
from openai.openai_object import OpenAIObject
Expand All @@ -12,7 +13,7 @@ def backoff_hdlr(details):
"""Handler from https://pypi.org/project/backoff/"""
print(
"Backing off {wait:0.1f} seconds after {tries} tries "
"calling function {target} with args {args} and kwargs "
"calling function {target} with kwargs "
"{kwargs}".format(**details)
)

Expand Down Expand Up @@ -41,8 +42,14 @@ def __init__(

def _basic_request(self, prompt: str, **kwargs) -> OpenAIObject:
raw_kwargs = kwargs
kwargs = {**self.kwargs, "prompt": prompt, **kwargs}
response = cached_gpt3_request(**kwargs)

kwargs = {**self.kwargs, **kwargs}
if kwargs["model"] in ("gpt-3.5-turbo", "gpt-3.5-turbo-0301"):
kwargs["messages"] = json.dumps([{"role": "user", "content": prompt}])
response = cached_gpt3_turbo_request(**kwargs)
else:
kwargs["prompt"] = prompt
response = cached_gpt3_request(**kwargs)

history = {
"prompt": prompt,
Expand Down Expand Up @@ -91,11 +98,16 @@ def inspect_history(self, n: int = 1):
for prompt, choices in reversed(printed):
print("\n\n\n")
print(prompt, end="")
self.print_green(choices[0]["text"], end="")
self.print_green(self._get_choice_text(choices[0]), end="")
if len(choices) > 1:
self.print_red(f" \t (and {len(choices)-1} other completions)", end="")
print("\n\n\n")

def _get_choice_text(self, choice: dict[str, Any]) -> str:
if self.kwargs["model"] in ("gpt-3.5-turbo", "gpt-3.5-turbo-0301"):
return choice["message"]["content"]
return choice["text"]

def __call__(
self,
prompt: str,
Expand All @@ -117,7 +129,10 @@ def __call__(
assert return_sorted is False, "for now"

if kwargs.get("n", 1) > 1:
kwargs = {**kwargs, "logprobs": 5}
if self.kwargs["model"] in ("gpt-3.5-turbo", "gpt-3.5-turbo-0301"):
kwargs = {**kwargs}
else:
kwargs = {**kwargs, "logprobs": 5}

response = self.request(prompt, **kwargs)
choices = response["choices"]
Expand All @@ -127,7 +142,7 @@ def __call__(
if only_completed and len(completed_choices):
choices = completed_choices

completions = [c["text"] for c in choices]
completions = [self._get_choice_text(c) for c in choices]

if return_sorted and kwargs.get("n", 1) > 1:
scored_completions = []
Expand All @@ -143,7 +158,7 @@ def __call__(
tokens, logprobs = tokens[:index], logprobs[:index]

avglog = sum(logprobs) / len(logprobs)
scored_completions.append((avglog, c["text"]))
scored_completions.append((avglog, self._get_choice_text(c)))

scored_completions = sorted(scored_completions, reverse=True)
completions = [c for _, c in scored_completions]
Expand All @@ -163,3 +178,18 @@ def cached_gpt3_request_v2_wrapped(**kwargs):


# Public name called by _basic_request for classic completion models.
cached_gpt3_request = cached_gpt3_request_v2_wrapped


@CacheMemory.cache
def cached_gpt3_turbo_request_v2(**kwargs):
    """Issue an ``openai.ChatCompletion`` call for chat (turbo) models.

    The ``messages`` payload arrives as a JSON string (serialized by the
    caller so the cache decorators only ever see hashable arguments) and is
    decoded back into its list-of-dicts form right before hitting the API.
    """
    decoded = dict(kwargs)
    decoded["messages"] = json.loads(decoded["messages"])
    return openai.ChatCompletion.create(**decoded)


# Chat-model caching stack, mirroring cached_gpt3_request_v2_wrapped above:
# an in-process LRU layer on top of NotebookCacheMemory, which in turn wraps
# the CacheMemory-backed cached_gpt3_turbo_request_v2. Decorator order
# matters: the lru_cache must be outermost so a memory hit short-circuits
# the persistent-cache lookups entirely.
@functools.lru_cache(maxsize=None if cache_turn_on else 0)
@NotebookCacheMemory.cache
def cached_gpt3_turbo_request_v2_wrapped(**kwargs):
    # maxsize=0 effectively disables the in-memory layer when caching is
    # turned off. kwargs must stay hashable for lru_cache, which is why
    # "messages" is passed through as a JSON string (see _basic_request).
    return cached_gpt3_turbo_request_v2(**kwargs)


# Public name called by _basic_request for chat (turbo) models.
cached_gpt3_turbo_request = cached_gpt3_turbo_request_v2_wrapped
3 changes: 2 additions & 1 deletion dsp/primitives/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def generate_sc(
):
if not dsp.settings.lm:
raise AssertionError("No LM is loaded.")
kwargs = {"temperature": 0.7, "max_tokens": 150, "n": 20, **kwargs}
kwargs = {"temperature": 0.7, "n": 20, "max_tokens": 150, **kwargs}

completions = dsp.settings.lm(prompt, **kwargs)
completions = extract_final_answer(example, completions, extract=extract)
return majority_vote_(
Expand Down

0 comments on commit 64d679b

Please sign in to comment.