Skip to content

Commit

Permalink
Update Google module to handle multiple generations with temperature 0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
nbqu committed Feb 26, 2024
1 parent e6a31ab commit b519d9e
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions dsp/modules/google.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Optional
from typing import Any, Iterable, Optional
import backoff

from dsp.modules.lm import LM
Expand Down Expand Up @@ -80,6 +80,9 @@ def __init__(
# Google API uses "candidate_count" instead of "n" or "num_generations"
# For now, google API only supports 1 generation at a time. Raises an error if candidate_count > 1
num_generations = kwargs.pop("n", kwargs.pop("num_generations", 1))
if num_generations > 1 and kwargs['temperature'] == 0.0:
kwargs['temperature'] = 0.7

self.provider = "google"
kwargs = {
"candidate_count": 1,
Expand Down Expand Up @@ -110,7 +113,9 @@ def basic_request(self, prompt: str, **kwargs):
}

# Google disallows "n" arguments
kwargs.pop("n", None)
n = kwargs.pop("n", None)
if n is not None and n > 1 and kwargs['temperature'] == 0.0:
kwargs['temperature'] = 0.7

response = self.llm.generate_content(prompt, generation_config=kwargs)

Expand All @@ -128,6 +133,7 @@ def basic_request(self, prompt: str, **kwargs):
backoff.expo,
(google_api_error),
max_time=1000,
max_tries=5,
on_backoff=backoff_hdlr,
giveup=giveup_hdlr,
)
Expand All @@ -150,6 +156,6 @@ def __call__(
completions = []
for i in range(n):
response = self.request(prompt, **kwargs)
completions.append(response.text)
completions.append(response.parts[0].text)

return completions

0 comments on commit b519d9e

Please sign in to comment.