Skip to content

Commit

Permalink
Merge pull request stanfordnlp#701 from Arbaaz-Mahmood/anthropic
Browse files Browse the repository at this point in the history
feat: Add support for additional Anthropic models
  • Loading branch information
okhat authored Mar 26, 2024
2 parents 46dfece + 5385301 commit c4e72be
Showing 1 changed file with 6 additions and 20 deletions.
26 changes: 6 additions & 20 deletions dsp/modules/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
except ImportError:
anthropic_rate_limit = Exception


logger = logging.getLogger(__name__)

BASE_URL = "https://api.anthropic.com/v1/messages"


def backoff_hdlr(details):
"""Handler from https://pypi.org/project/backoff/."""
print(
Expand All @@ -26,25 +23,22 @@ def backoff_hdlr(details):
"{kwargs}".format(**details),
)


def giveup_hdlr(details):
    """Wrapper function that decides when to give up on retry.

    Returns False (i.e. keep retrying) when the failure message mentions
    rate limits, and True (give up) for every other kind of failure.
    """
    return "rate limits" not in details.message


class Claude(LM):
"""Wrapper around anthropic's API. Supports both the Anthropic and Azure APIs."""
def __init__(
self,
model: str = "claude-instant-1.2",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
**kwargs,
self,
model: str = "claude-3-opus-20240229",
api_key: Optional[str] = None,
api_base: Optional[str] = None,
**kwargs,
):
super().__init__(model)

try:
from anthropic import Anthropic
except ImportError as err:
Expand All @@ -53,7 +47,6 @@ def __init__(
self.provider = "anthropic"
self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key
self.api_base = BASE_URL if api_base is None else api_base

self.kwargs = {
"temperature": kwargs.get("temperature", 0.0),
"max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
Expand All @@ -75,21 +68,18 @@ def log_usage(self, response):

def basic_request(self, prompt: str, **kwargs):
    """Send one user message to the Anthropic Messages API and log the call.

    Merges per-call kwargs over the instance defaults, wraps the prompt as a
    single user message, issues the request, appends a record to
    ``self.history``, and returns the raw API response.
    """
    raw_kwargs = kwargs

    # Per-call overrides win over the instance-level defaults.
    merged_kwargs = {**self.kwargs, **kwargs}
    # caching mechanism requires hashable kwargs
    merged_kwargs["messages"] = [{"role": "user", "content": prompt}]
    # "n" is a local fan-out knob, not an Anthropic API parameter — strip it.
    # NOTE(review): assumes "n" is always present in the merged kwargs
    # (presumably seeded by __init__) — otherwise this raises KeyError.
    merged_kwargs.pop("n")

    response = self.client.messages.create(**merged_kwargs)

    self.history.append(
        {
            "prompt": prompt,
            "response": response,
            "kwargs": merged_kwargs,
            "raw_kwargs": raw_kwargs,
        },
    )
    return response

@backoff.on_exception(
Expand All @@ -115,15 +105,11 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
Returns:
list[str]: list of completion choices
"""

assert only_completed, "for now"
assert return_sorted is False, "for now"


# per eg here: https://docs.anthropic.com/claude/reference/messages-examples
# max tokens can be used as a proxy to return smaller responses
# so this cannot be a proper indicator for incomplete response unless it isnt the user-intent.

n = kwargs.pop("n", 1)
completions = []
for _ in range(n):
Expand All @@ -134,4 +120,4 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
if only_completed and response.stop_reason == "max_tokens":
continue
completions = [c.text for c in response.content]
return completions
return completions

0 comments on commit c4e72be

Please sign in to comment.