[WIP] Client duplicate fixes (#3843)

* client duplicate fixes

* fixes

* formatting

* changes

* fixing tests

* formatting
This commit is contained in:
Abubakar Abid 2023-04-13 12:30:11 -07:00 committed by GitHub
parent c772c6ae57
commit 40b30a683b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 81 additions and 37 deletions

View File

@ -8,6 +8,7 @@ import threading
import time
import urllib.parse
import uuid
import warnings
from concurrent.futures import Future, TimeoutError
from datetime import datetime
from pathlib import Path
@ -17,6 +18,7 @@ from typing import Any, Callable, Dict, List, Tuple
import huggingface_hub
import requests
import websockets
from huggingface_hub import SpaceStage
from huggingface_hub.utils import (
RepositoryNotFoundError,
build_hf_headers,
@ -86,10 +88,10 @@ class Client:
self.space_id = src
self.src = _src
state = self._get_space_state()
if state == utils.BUILDING_RUNTIME:
if state == SpaceStage.BUILDING:
if self.verbose:
print("Space is still building. Please wait...")
while self._get_space_state() == utils.BUILDING_RUNTIME:
while self._get_space_state() == SpaceStage.BUILDING:
time.sleep(2) # so we don't get rate limited by the API
pass
if state in utils.INVALID_RUNTIME:
@ -151,13 +153,13 @@ class Client:
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
private: Whether the new Space should be private (True) or public (False). Defaults to True.
hardware: The hardware tier to use for the new Space. Defaults to the same hardware tier as the original Space. Options include "cpu-basic", "cpu-upgrade", "t4-small", "t4-medium", "a10g-small", "a10g-large", "a100-large", subject to availability.
secrets: A dictionary of (secret key, secret value) to pass to the new Space. Defaults to None.
secrets: A dictionary of (secret key, secret value) to pass to the new Space. Defaults to None. Secrets are only used when the Space is duplicated for the first time, and are not updated if the duplicated Space already exists.
sleep_timeout: The number of minutes after which the duplicate Space will be puased if no requests are made to it (to minimize billing charges). Defaults to 5 minutes.
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
verbose: Whether the client should print statements to the console.
"""
try:
info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
original_info = huggingface_hub.get_space_runtime(from_id, token=hf_token)
except RepositoryNotFoundError:
raise ValueError(
f"Could not find Space: {from_id}. If it is a private Space, please provide an `hf_token`."
@ -176,6 +178,10 @@ class Client:
print(
f"Using your existing Space: {utils.SPACE_URL.format(space_id)} 🤗"
)
if secrets is not None:
warnings.warn(
"Secrets are only used when the Space is duplicated for the first time, and are not updated if the duplicated Space already exists."
)
except RepositoryNotFoundError:
if verbose:
print(f"Creating a duplicate of {from_id} for your own use... 🤗")
@ -186,23 +192,26 @@ class Client:
exist_ok=True,
private=private,
)
if secrets is not None:
for key, value in secrets.items():
huggingface_hub.add_space_secret(
space_id, key, value, token=hf_token
)
utils.set_space_timeout(
space_id, hf_token=hf_token, timeout_in_seconds=sleep_timeout * 60
)
if verbose:
print(f"Created new Space: {utils.SPACE_URL.format(space_id)}")
current_info = huggingface_hub.get_space_runtime(space_id, token=hf_token)
current_hardware = current_info.hardware or "cpu-basic"
if hardware is None:
hardware = info.hardware
current_hardware = (
current_info.hardware or huggingface_hub.SpaceHardware.CPU_BASIC
)
hardware = hardware or original_info.hardware
if not current_hardware == hardware:
huggingface_hub.request_space_hardware(space_id, hardware) # type: ignore
print(
f"-------\nNOTE: this Space uses upgraded hardware: {hardware}... see billing info at https://huggingface.co/settings/billing\n-------"
)
if secrets is not None:
for key, value in secrets.items():
huggingface_hub.add_space_secret(space_id, key, value, token=hf_token)
if verbose:
print("")
client = cls(

View File

@ -20,6 +20,7 @@ import fsspec.asyn
import httpx
import huggingface_hub
import requests
from huggingface_hub import SpaceStage
from websockets.legacy.protocol import WebSocketCommonProtocol
API_URL = "/api/predict/"
@ -29,13 +30,12 @@ RESET_URL = "/reset"
SPACE_URL = "https://hf.space/{}"
STATE_COMPONENT = "state"
INVALID_RUNTIME = [
"NO_APP_FILE",
"CONFIG_ERROR",
"BUILD_ERROR",
"RUNTIME_ERROR",
"PAUSED",
SpaceStage.NO_APP_FILE,
SpaceStage.CONFIG_ERROR,
SpaceStage.BUILD_ERROR,
SpaceStage.RUNTIME_ERROR,
SpaceStage.PAUSED,
]
BUILDING_RUNTIME = "BUILDING"
__version__ = (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip()
@ -58,6 +58,12 @@ class InvalidAPIEndpointError(Exception):
pass
class SpaceDuplicationError(Exception):
"""Raised when something goes wrong with a Space Duplication."""
pass
class Status(Enum):
"""Status codes presented to client users."""
@ -400,11 +406,19 @@ def set_space_timeout(
library_name="gradio_client",
library_version=__version__,
)
requests.post(
r = requests.post(
f"https://huggingface.co/api/spaces/{space_id}/sleeptime",
json={"seconds": timeout_in_seconds},
headers=headers,
)
print("r", r, r.status_code)
try:
huggingface_hub.utils.hf_raise_for_status(r)
except huggingface_hub.utils.HfHubHTTPError:
raise SpaceDuplicationError(
f"Could not set sleep timeout on duplicated Space. Please visit {SPACE_URL.format(space_id)} "
"to set a timeout manually to reduce billing charges."
)
########################

View File

@ -3,11 +3,13 @@ import os
import pathlib
import tempfile
import time
import uuid
from concurrent.futures import CancelledError, TimeoutError
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from gradio_client import Client
from gradio_client.serializing import SimpleSerializable
@ -575,23 +577,27 @@ class TestDuplication:
)
@pytest.mark.flaky
@patch("huggingface_hub.get_space_runtime", return_value=MagicMock(hardware="cpu"))
@patch("huggingface_hub.add_space_secret")
@patch("huggingface_hub.duplicate_space")
@patch("gradio_client.client.Client.__init__", return_value=None)
def test_add_secrets(self, mock_init, mock_add_secret, mock_runtime):
@patch("gradio_client.utils.set_space_timeout")
def test_add_secrets(self, mock_time, mock_init, mock_duplicate, mock_add_secret):
with pytest.raises(RepositoryNotFoundError):
name = str(uuid.uuid4())
Client.duplicate(
"gradio/calculator",
name,
hf_token=HF_TOKEN,
secrets={"test_key": "test_value", "test_key2": "test_value2"},
)
mock_add_secret.assert_any_call(
"gradio-tests/calculator",
mock_add_secret.assert_called_with(
f"gradio-tests/{name}",
"test_key",
"test_value",
token=HF_TOKEN,
)
mock_add_secret.assert_any_call(
"gradio-tests/calculator",
f"gradio-tests/{name}",
"test_key2",
"test_value2",
token=HF_TOKEN,

View File

@ -2,10 +2,11 @@ import json
import tempfile
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
from gradio import media_data
from requests.exceptions import HTTPError
from gradio_client import utils
@ -98,3 +99,17 @@ async def test_get_pred_from_ws_raises_if_queue_full():
hash_data = json.dumps({"session_hash": "daslskdf", "fn_index": "foo"})
with pytest.raises(utils.QueueError, match="Queue is full!"):
await utils.get_pred_from_ws(mock_ws, data, hash_data)
@patch("requests.post")
def test_sleep_successful(mock_post):
utils.set_space_timeout("gradio/calculator")
@patch(
"requests.post",
return_value=MagicMock(raise_for_status=MagicMock(side_effect=HTTPError)),
)
def test_sleep_unsuccessful(mock_post):
with pytest.raises(utils.SpaceDuplicationError):
utils.set_space_timeout("gradio/calculator")