mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
[WIP] Client duplicate fixes (#3843)
* client duplicate fixes * fixes * formatting * changes * fixing tests * formatting
This commit is contained in:
parent
c772c6ae57
commit
40b30a683b
@ -8,6 +8,7 @@ import threading
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
import warnings
|
||||
from concurrent.futures import Future, TimeoutError
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -17,6 +18,7 @@ from typing import Any, Callable, Dict, List, Tuple
|
||||
import huggingface_hub
|
||||
import requests
|
||||
import websockets
|
||||
from huggingface_hub import SpaceStage
|
||||
from huggingface_hub.utils import (
|
||||
RepositoryNotFoundError,
|
||||
build_hf_headers,
|
||||
@ -86,10 +88,10 @@ class Client:
|
||||
self.space_id = src
|
||||
self.src = _src
|
||||
state = self._get_space_state()
|
||||
if state == utils.BUILDING_RUNTIME:
|
||||
if state == SpaceStage.BUILDING:
|
||||
if self.verbose:
|
||||
print("Space is still building. Please wait...")
|
||||
while self._get_space_state() == utils.BUILDING_RUNTIME:
|
||||
while self._get_space_state() == SpaceStage.BUILDING:
|
||||
time.sleep(2) # so we don't get rate limited by the API
|
||||
pass
|
||||
if state in utils.INVALID_RUNTIME:
|
||||
@ -151,13 +153,13 @@ class Client:
|
||||
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
|
||||
private: Whether the new Space should be private (True) or public (False). Defaults to True.
|
||||
hardware: The hardware tier to use for the new Space. Defaults to the same hardware tier as the original Space. Options include "cpu-basic", "cpu-upgrade", "t4-small", "t4-medium", "a10g-small", "a10g-large", "a100-large", subject to availability.
|
||||
secrets: A dictionary of (secret key, secret value) to pass to the new Space. Defaults to None.
|
||||
secrets: A dictionary of (secret key, secret value) to pass to the new Space. Defaults to None. Secrets are only used when the Space is duplicated for the first time, and are not updated if the duplicated Space already exists.
|
||||
sleep_timeout: The number of minutes after which the duplicate Space will be paused if no requests are made to it (to minimize billing charges). Defaults to 5 minutes.
|
||||
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
|
||||
verbose: Whether the client should print statements to the console.
|
||||
"""
|
||||
try:
|
||||
info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
|
||||
original_info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
|
||||
except RepositoryNotFoundError:
|
||||
raise ValueError(
|
||||
f"Could not find Space: {from_id}. If it is a private Space, please provide an `hf_token`."
|
||||
@ -176,6 +178,10 @@ class Client:
|
||||
print(
|
||||
f"Using your existing Space: {utils.SPACE_URL.format(space_id)} 🤗"
|
||||
)
|
||||
if secrets is not None:
|
||||
warnings.warn(
|
||||
"Secrets are only used when the Space is duplicated for the first time, and are not updated if the duplicated Space already exists."
|
||||
)
|
||||
except RepositoryNotFoundError:
|
||||
if verbose:
|
||||
print(f"Creating a duplicate of {from_id} for your own use... 🤗")
|
||||
@ -186,23 +192,26 @@ class Client:
|
||||
exist_ok=True,
|
||||
private=private,
|
||||
)
|
||||
if secrets is not None:
|
||||
for key, value in secrets.items():
|
||||
huggingface_hub.add_space_secret(
|
||||
space_id, key, value, token=hf_token
|
||||
)
|
||||
utils.set_space_timeout(
|
||||
space_id, hf_token=hf_token, timeout_in_seconds=sleep_timeout * 60
|
||||
)
|
||||
if verbose:
|
||||
print(f"Created new Space: {utils.SPACE_URL.format(space_id)}")
|
||||
current_info = huggingface_hub.get_space_runtime(space_id, token=hf_token)
|
||||
current_hardware = current_info.hardware or "cpu-basic"
|
||||
if hardware is None:
|
||||
hardware = info.hardware
|
||||
current_hardware = (
|
||||
current_info.hardware or huggingface_hub.SpaceHardware.CPU_BASIC
|
||||
)
|
||||
hardware = hardware or original_info.hardware
|
||||
if not current_hardware == hardware:
|
||||
huggingface_hub.request_space_hardware(space_id, hardware) # type: ignore
|
||||
print(
|
||||
f"-------\nNOTE: this Space uses upgraded hardware: {hardware}... see billing info at https://huggingface.co/settings/billing\n-------"
|
||||
)
|
||||
if secrets is not None:
|
||||
for key, value in secrets.items():
|
||||
huggingface_hub.add_space_secret(space_id, key, value, token=hf_token)
|
||||
if verbose:
|
||||
print("")
|
||||
client = cls(
|
||||
|
@ -20,6 +20,7 @@ import fsspec.asyn
|
||||
import httpx
|
||||
import huggingface_hub
|
||||
import requests
|
||||
from huggingface_hub import SpaceStage
|
||||
from websockets.legacy.protocol import WebSocketCommonProtocol
|
||||
|
||||
API_URL = "/api/predict/"
|
||||
@ -29,13 +30,12 @@ RESET_URL = "/reset"
|
||||
SPACE_URL = "https://hf.space/{}"
|
||||
STATE_COMPONENT = "state"
|
||||
INVALID_RUNTIME = [
|
||||
"NO_APP_FILE",
|
||||
"CONFIG_ERROR",
|
||||
"BUILD_ERROR",
|
||||
"RUNTIME_ERROR",
|
||||
"PAUSED",
|
||||
SpaceStage.NO_APP_FILE,
|
||||
SpaceStage.CONFIG_ERROR,
|
||||
SpaceStage.BUILD_ERROR,
|
||||
SpaceStage.RUNTIME_ERROR,
|
||||
SpaceStage.PAUSED,
|
||||
]
|
||||
BUILDING_RUNTIME = "BUILDING"
|
||||
|
||||
__version__ = (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()
|
||||
|
||||
@ -58,6 +58,12 @@ class InvalidAPIEndpointError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SpaceDuplicationError(Exception):
    """Raised when something goes wrong with a Space Duplication.

    For example, it is raised when the sleep timeout cannot be set on the
    duplicated Space, so the caller can be told to configure it manually.
    """
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
"""Status codes presented to client users."""
|
||||
|
||||
@ -400,11 +406,19 @@ def set_space_timeout(
|
||||
library_name="gradio_client",
|
||||
library_version=__version__,
|
||||
)
|
||||
requests.post(
|
||||
r = requests.post(
|
||||
f"https://huggingface.co/api/spaces/{space_id}/sleeptime",
|
||||
json={"seconds": timeout_in_seconds},
|
||||
headers=headers,
|
||||
)
|
||||
print("r", r, r.status_code)
|
||||
try:
|
||||
huggingface_hub.utils.hf_raise_for_status(r)
|
||||
except huggingface_hub.utils.HfHubHTTPError:
|
||||
raise SpaceDuplicationError(
|
||||
f"Could not set sleep timeout on duplicated Space. Please visit {SPACE_URL.format(space_id)} "
|
||||
"to set a timeout manually to reduce billing charges."
|
||||
)
|
||||
|
||||
|
||||
########################
|
||||
|
@ -3,11 +3,13 @@ import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import CancelledError, TimeoutError
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
from gradio_client import Client
|
||||
from gradio_client.serializing import SimpleSerializable
|
||||
@ -575,23 +577,27 @@ class TestDuplication:
|
||||
)
|
||||
|
||||
@pytest.mark.flaky
|
||||
@patch("huggingface_hub.get_space_runtime", return_value=MagicMock(hardware="cpu"))
|
||||
@patch("huggingface_hub.add_space_secret")
|
||||
@patch("huggingface_hub.duplicate_space")
|
||||
@patch("gradio_client.client.Client.__init__", return_value=None)
|
||||
def test_add_secrets(self, mock_init, mock_add_secret, mock_runtime):
|
||||
@patch("gradio_client.utils.set_space_timeout")
|
||||
def test_add_secrets(self, mock_time, mock_init, mock_duplicate, mock_add_secret):
|
||||
with pytest.raises(RepositoryNotFoundError):
|
||||
name = str(uuid.uuid4())
|
||||
Client.duplicate(
|
||||
"gradio/calculator",
|
||||
name,
|
||||
hf_token=HF_TOKEN,
|
||||
secrets={"test_key": "test_value", "test_key2": "test_value2"},
|
||||
)
|
||||
mock_add_secret.assert_any_call(
|
||||
"gradio-tests/calculator",
|
||||
mock_add_secret.assert_called_with(
|
||||
f"gradio-tests/{name}",
|
||||
"test_key",
|
||||
"test_value",
|
||||
token=HF_TOKEN,
|
||||
)
|
||||
mock_add_secret.assert_any_call(
|
||||
"gradio-tests/calculator",
|
||||
f"gradio-tests/{name}",
|
||||
"test_key2",
|
||||
"test_value2",
|
||||
token=HF_TOKEN,
|
||||
|
@ -2,10 +2,11 @@ import json
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from gradio import media_data
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from gradio_client import utils
|
||||
|
||||
@ -98,3 +99,17 @@ async def test_get_pred_from_ws_raises_if_queue_full():
|
||||
hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
|
||||
with pytest.raises(utils.QueueError, match="Queue is full!"):
|
||||
await utils.get_pred_from_ws(mock_ws, data, hash_data)
|
||||
|
||||
|
||||
@patch("requests.post")
def test_sleep_successful(mock_post):
    """Setting the sleep timeout succeeds when the HTTP POST reports no error.

    ``requests.post`` is patched with a default mock, so no network request
    is made and the mocked response does not raise.
    """
    utils.set_space_timeout("gradio/calculator")
    # Without this check the test would also pass if the request were never sent.
    mock_post.assert_called_once()
|
||||
|
||||
|
||||
@patch(
    "requests.post",
    return_value=MagicMock(raise_for_status=MagicMock(side_effect=HTTPError)),
)
def test_sleep_unsuccessful(mock_post):
    """A failed sleep-timeout request surfaces as ``SpaceDuplicationError``.

    The mocked response raises ``HTTPError`` from ``raise_for_status``, which
    simulates the Hub rejecting the request.
    """
    # Matching on the message ensures the error comes from the timeout step.
    with pytest.raises(
        utils.SpaceDuplicationError, match="Could not set sleep timeout"
    ):
        utils.set_space_timeout("gradio/calculator")
    # The failure must come from the attempted request, not from an earlier step.
    mock_post.assert_called_once()
|
||||
|
Loading…
Reference in New Issue
Block a user