Skip to content

Commit

Permalink
[WIP] Client duplicate fixes (gradio-app#3843)
Browse files Browse the repository at this point in the history
* client duplicate fixes

* fixes

* formatting

* changes

* fixing tests

* formatting
  • Loading branch information
abidlabs authored Apr 13, 2023
1 parent c772c6a commit 40b30a6
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 37 deletions.
29 changes: 19 additions & 10 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
import urllib.parse
import uuid
import warnings
from concurrent.futures import Future, TimeoutError
from datetime import datetime
from pathlib import Path
Expand All @@ -17,6 +18,7 @@
import huggingface_hub
import requests
import websockets
from huggingface_hub import SpaceStage
from huggingface_hub.utils import (
RepositoryNotFoundError,
build_hf_headers,
Expand Down Expand Up @@ -86,10 +88,10 @@ def __init__(
self.space_id = src
self.src = _src
state = self._get_space_state()
if state == utils.BUILDING_RUNTIME:
if state == SpaceStage.BUILDING:
if self.verbose:
print("Space is still building. Please wait...")
while self._get_space_state() == utils.BUILDING_RUNTIME:
while self._get_space_state() == SpaceStage.BUILDING:
time.sleep(2) # so we don't get rate limited by the API
pass
if state in utils.INVALID_RUNTIME:
Expand Down Expand Up @@ -151,13 +153,13 @@ def duplicate(
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
private: Whether the new Space should be private (True) or public (False). Defaults to True.
hardware: The hardware tier to use for the new Space. Defaults to the same hardware tier as the original Space. Options include "cpu-basic", "cpu-upgrade", "t4-small", "t4-medium", "a10g-small", "a10g-large", "a100-large", subject to availability.
secrets: A dictionary of (secret key, secret value) to pass to the new Space. Defaults to None.
secrets: A dictionary of (secret key, secret value) to pass to the new Space. Defaults to None. Secrets are only used when the Space is duplicated for the first time, and are not updated if the duplicated Space already exists.
sleep_timeout: The number of minutes after which the duplicate Space will be paused if no requests are made to it (to minimize billing charges). Defaults to 5 minutes.
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
verbose: Whether the client should print statements to the console.
"""
try:
info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
original_info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
except RepositoryNotFoundError:
raise ValueError(
f"Could not find Space: {from_id}. If it is a private Space, please provide an `hf_token`."
Expand All @@ -176,6 +178,10 @@ def duplicate(
print(
f"Using your existing Space: {utils.SPACE_URL.format(space_id)} 🤗"
)
if secrets is not None:
warnings.warn(
"Secrets are only used when the Space is duplicated for the first time, and are not updated if the duplicated Space already exists."
)
except RepositoryNotFoundError:
if verbose:
print(f"Creating a duplicate of {from_id} for your own use... 🤗")
Expand All @@ -186,23 +192,26 @@ def duplicate(
exist_ok=True,
private=private,
)
if secrets is not None:
for key, value in secrets.items():
huggingface_hub.add_space_secret(
space_id, key, value, token=hf_token
)
utils.set_space_timeout(
space_id, hf_token=hf_token, timeout_in_seconds=sleep_timeout * 60
)
if verbose:
print(f"Created new Space: {utils.SPACE_URL.format(space_id)}")
current_info = huggingface_hub.get_space_runtime(space_id, token=hf_token)
current_hardware = current_info.hardware or "cpu-basic"
if hardware is None:
hardware = info.hardware
current_hardware = (
current_info.hardware or huggingface_hub.SpaceHardware.CPU_BASIC
)
hardware = hardware or original_info.hardware
if not current_hardware == hardware:
huggingface_hub.request_space_hardware(space_id, hardware) # type: ignore
print(
f"-------\nNOTE: this Space uses upgraded hardware: {hardware}... see billing info at https://huggingface.co/settings/billing\n-------"
)
if secrets is not None:
for key, value in secrets.items():
huggingface_hub.add_space_secret(space_id, key, value, token=hf_token)
if verbose:
print("")
client = cls(
Expand Down
28 changes: 21 additions & 7 deletions client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import httpx
import huggingface_hub
import requests
from huggingface_hub import SpaceStage
from websockets.legacy.protocol import WebSocketCommonProtocol

API_URL = "/api/predict/"
Expand All @@ -29,13 +30,12 @@
SPACE_URL = "https://hf.space/{}"
STATE_COMPONENT = "state"
INVALID_RUNTIME = [
"NO_APP_FILE",
"CONFIG_ERROR",
"BUILD_ERROR",
"RUNTIME_ERROR",
"PAUSED",
SpaceStage.NO_APP_FILE,
SpaceStage.CONFIG_ERROR,
SpaceStage.BUILD_ERROR,
SpaceStage.RUNTIME_ERROR,
SpaceStage.PAUSED,
]
BUILDING_RUNTIME = "BUILDING"

__version__ = (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()

Expand All @@ -58,6 +58,12 @@ class InvalidAPIEndpointError(Exception):
pass


class SpaceDuplicationError(Exception):
    """Error raised when a step of duplicating a Space fails."""


class Status(Enum):
"""Status codes presented to client users."""

Expand Down Expand Up @@ -400,11 +406,19 @@ def set_space_timeout(
library_name="gradio_client",
library_version=__version__,
)
requests.post(
r = requests.post(
f"https://huggingface.co/api/spaces/{space_id}/sleeptime",
json={"seconds": timeout_in_seconds},
headers=headers,
)
print("r", r, r.status_code)
try:
huggingface_hub.utils.hf_raise_for_status(r)
except huggingface_hub.utils.HfHubHTTPError:
raise SpaceDuplicationError(
f"Could not set sleep timeout on duplicated Space. Please visit {SPACE_URL.format(space_id)} "
"to set a timeout manually to reduce billing charges."
)


########################
Expand Down
44 changes: 25 additions & 19 deletions client/python/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import pathlib
import tempfile
import time
import uuid
from concurrent.futures import CancelledError, TimeoutError
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

import pytest
from huggingface_hub.utils import RepositoryNotFoundError

from gradio_client import Client
from gradio_client.serializing import SimpleSerializable
Expand Down Expand Up @@ -575,24 +577,28 @@ def test_default_space_id(self, mock_init, mock_runtime):
)

@pytest.mark.flaky
@patch("huggingface_hub.get_space_runtime", return_value=MagicMock(hardware="cpu"))
@patch("huggingface_hub.add_space_secret")
@patch("huggingface_hub.duplicate_space")
@patch("gradio_client.client.Client.__init__", return_value=None)
def test_add_secrets(self, mock_init, mock_add_secret, mock_runtime):
Client.duplicate(
"gradio/calculator",
hf_token=HF_TOKEN,
secrets={"test_key": "test_value", "test_key2": "test_value2"},
)
mock_add_secret.assert_any_call(
"gradio-tests/calculator",
"test_key",
"test_value",
token=HF_TOKEN,
)
mock_add_secret.assert_any_call(
"gradio-tests/calculator",
"test_key2",
"test_value2",
token=HF_TOKEN,
)
@patch("gradio_client.utils.set_space_timeout")
def test_add_secrets(self, mock_time, mock_init, mock_duplicate, mock_add_secret):
with pytest.raises(RepositoryNotFoundError):
name = str(uuid.uuid4())
Client.duplicate(
"gradio/calculator",
name,
hf_token=HF_TOKEN,
secrets={"test_key": "test_value", "test_key2": "test_value2"},
)
mock_add_secret.assert_called_with(
f"gradio-tests/{name}",
"test_key",
"test_value",
token=HF_TOKEN,
)
mock_add_secret.assert_any_call(
f"gradio-tests/{name}",
"test_key2",
"test_value2",
token=HF_TOKEN,
)
17 changes: 16 additions & 1 deletion client/python/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from gradio import media_data
from requests.exceptions import HTTPError

from gradio_client import utils

Expand Down Expand Up @@ -98,3 +99,17 @@ async def test_get_pred_from_ws_raises_if_queue_full():
hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
with pytest.raises(utils.QueueError, match="Queue is full!"):
await utils.get_pred_from_ws(mock_ws, data, hash_data)


@patch("requests.post")
def test_sleep_successful(mock_post):
    """Check that set_space_timeout sends one POST to the Space's sleeptime endpoint.

    ``requests.post`` is mocked, so no network traffic occurs; the mocked
    response's ``raise_for_status`` does not raise, which models a
    successful request.
    """
    utils.set_space_timeout("gradio/calculator")
    # Without these checks the test would pass even if no request were issued.
    mock_post.assert_called_once()
    # The URL is the first positional argument passed to requests.post.
    url = mock_post.call_args[0][0]
    assert url.endswith("/spaces/gradio/calculator/sleeptime")


@patch("requests.post")
def test_sleep_unsuccessful(mock_post):
    """Check that a failed sleeptime request surfaces as SpaceDuplicationError."""
    # Make the mocked response's raise_for_status() raise, as a failed HTTP call would.
    mock_post.return_value.raise_for_status.side_effect = HTTPError
    with pytest.raises(utils.SpaceDuplicationError):
        utils.set_space_timeout("gradio/calculator")

0 comments on commit 40b30a6

Please sign in to comment.