Skip to content

Commit

Permalink
Merge branch 'stanfordnlp:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle authored Mar 20, 2024
2 parents 6ba9600 + 71dbf30 commit 36a72a7
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 116 deletions.
10 changes: 8 additions & 2 deletions dsp/modules/aws_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import logging
from abc import abstractmethod
from typing import Any, Literal
from typing import Any, Literal, Optional

from dsp.modules.lm import LM

Expand All @@ -27,6 +27,7 @@ def __init__(
region_name: str,
service_name: str,
max_new_tokens: int,
profile_name: Optional[str] = None,
truncate_long_prompts: bool = False,
input_output_ratio: int = 3,
batch_n: bool = True,
Expand Down Expand Up @@ -55,7 +56,12 @@ def __init__(

import boto3

self.predictor = boto3.client(service_name, region_name=region_name)
if profile_name is None:
self.predictor = boto3.client(service_name, region_name=region_name)
else:
self.predictor = boto3.Session(profile_name=profile_name).client(
service_name, region_name=region_name,
)

@abstractmethod
def _create_body(self, prompt: str, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion dsp/modules/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import Any
from typing import Any, Optional

from dsp.modules.aws_lm import AWSLM

Expand All @@ -11,6 +11,7 @@ def __init__(
self,
region_name: str,
model: str,
profile_name: Optional[str] = None,
input_output_ratio: int = 3,
max_new_tokens: int = 1500,
) -> None:
Expand All @@ -28,6 +29,7 @@ def __init__(
model=model,
service_name="bedrock-runtime",
region_name=region_name,
profile_name=profile_name,
truncate_long_prompts=False,
input_output_ratio=input_output_ratio,
max_new_tokens=max_new_tokens,
Expand Down
17 changes: 15 additions & 2 deletions dsp/modules/hf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,16 @@ def send_hftgi_request_v00(arg, **kwargs):
class HFClientVLLM(HFModel):
def __init__(self, model, port, url="http://localhost", **kwargs):
    """Client for a vLLM server exposing an OpenAI-compatible completions API.

    Args:
        model: Model identifier passed through to ``HFModel``.
        port: Port appended to ``url`` when ``url`` is a single string.
        url: Either a base-url string (``port`` is appended to it) or a
            pre-built list of full server urls used round-robin by callers.
        **kwargs: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If ``url`` is neither a ``str`` nor a ``list``.
    """
    super().__init__(model=model, is_client=True)

    # Normalize to a list of server urls so request dispatch can
    # round-robin across replicas. A plain string gets the port appended;
    # a list is assumed to already contain complete urls.
    if isinstance(url, list):
        self.urls = url
    elif isinstance(url, str):
        self.urls = [f'{url}:{port}']
    else:
        raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")

    self.headers = {"Content-Type": "application/json"}

def _generate(self, prompt, **kwargs):
Expand All @@ -128,9 +137,13 @@ def _generate(self, prompt, **kwargs):
"max_tokens": kwargs["max_tokens"],
"temperature": kwargs["temperature"],
}

# Round robin the urls.
url = self.urls.pop(0)
self.urls.append(url)

response = send_hfvllm_request_v00(
f"{self.url}/v1/completions",
f"{url}/v1/completions",
json=payload,
headers=self.headers,
)
Expand Down
Loading

0 comments on commit 36a72a7

Please sign in to comment.