Skip to content

Commit

Permalink
Update retry logic in BasePromptDriver
Browse files Browse the repository at this point in the history
  • Loading branch information
parthvshah committed Jun 14, 2023
1 parent b07daa4 commit 2f695b4
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from tenacity import Retrying, wait_exponential, after_log
from attr import define, field
from griptape.tokenizers import BaseTokenizer

Expand All @@ -12,23 +13,27 @@

@define
class BasePromptDriver(ABC):
max_retries: int = field(default=8, kw_only=True)
retry_delay: float = field(default=1, kw_only=True)
min_retry_delay: float = field(default=2, kw_only=True)
max_retry_delay: float = field(default=10, kw_only=True)

temperature: float = field(default=0.1, kw_only=True)
model: str
tokenizer: BaseTokenizer

def run(self, **kwargs) -> TextArtifact:
for attempt in range(0, self.max_retries + 1):
try:
for attempt in Retrying(
wait=wait_exponential(
min=self.min_retry_delay,
max=self.max_retry_delay
),
reraise=True,
after=after_log(
logger=logging.getLogger(__name__),
log_level=logging.ERROR
),
):
with attempt:
return self.try_run(**kwargs)
except Exception as e:
logging.error(f"PromptDriver.run attempt {attempt} failed: {e}\nRetrying in {self.retry_delay} seconds")

if attempt < self.max_retries:
time.sleep(self.retry_delay)
else:
raise e

@abstractmethod
def try_run(self, **kwargs) -> TextArtifact:
Expand Down

0 comments on commit 2f695b4

Please sign in to comment.