Skip to content

Commit

Permalink
Merge pull request stanfordnlp#638 from umarbutler/HFClientVLLM-endpoint-round-robin
Browse files Browse the repository at this point in the history

Added support for multiple API endpoints to `HFClientVLLM`
  • Loading branch information
arnavsinghvi11 authored Mar 20, 2024
2 parents 96c9128 + 8c6d9d6 commit 71dbf30
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions dsp/modules/hf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,16 @@ def send_hftgi_request_v00(arg, **kwargs):
class HFClientVLLM(HFModel):
def __init__(self, model, port, url="http://localhost", **kwargs):
    """Create a client for one or more vLLM server endpoints.

    Args:
        model: Model identifier, forwarded to the parent initializer.
        port: Port appended to ``url`` when ``url`` is a single string.
            It is not applied when ``url`` is a list.
        url: Either a base URL string (combined with ``port``) or a
            non-empty list of complete endpoint URL strings, which are
            used exactly as given.
        **kwargs: Accepted for call-site compatibility; not used here.

    Raises:
        ValueError: If ``url`` is neither a string nor a list, if the
            list is empty, or if the list contains a non-string item.
    """
    super().__init__(model=model, is_client=True)

    if isinstance(url, list):
        # Fail at construction time instead of with an IndexError on the
        # first request, when the endpoint rotation pops from the list.
        if not url:
            raise ValueError("The url list provided to `HFClientVLLM` is empty; at least one endpoint is required.")

        non_strings = [item for item in url if not isinstance(item, str)]
        if non_strings:
            raise ValueError(
                "The url list provided to `HFClientVLLM` must contain only strings. "
                f"Found an item of type {type(non_strings[0])}."
            )

        # Copy the list: `_generate` rotates `self.urls` in place, and that
        # must not reorder the list object owned by the caller.
        self.urls = list(url)

    elif isinstance(url, str):
        self.urls = [f'{url}:{port}']

    else:
        raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")

    # Legacy single-endpoint attribute, kept for code that still reads it.
    # It is the first configured endpoint and is not updated by rotation.
    self.url = self.urls[0]

    self.headers = {"Content-Type": "application/json"}

def _generate(self, prompt, **kwargs):
Expand All @@ -128,9 +137,13 @@ def _generate(self, prompt, **kwargs):
"max_tokens": kwargs["max_tokens"],
"temperature": kwargs["temperature"],
}

# Round robin the urls.
url = self.urls.pop(0)
self.urls.append(url)

response = send_hfvllm_request_v00(
f"{self.url}/v1/completions",
f"{url}/v1/completions",
json=payload,
headers=self.headers,
)
Expand Down

0 comments on commit 71dbf30

Please sign in to comment.