mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Respect Upstream Queue when loading interfaces/blocks from Spaces (#2294)
* Fix fn used in load when queue is enabled * Respect upstream queue * Fix test * Skip in 3.7 * Update logic to respect if fn does not have queue * Fix impl + test
This commit is contained in:
parent
c977ef1fa8
commit
11379b92f1
@ -11,13 +11,15 @@ import operator
|
||||
import re
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
|
||||
|
||||
import requests
|
||||
import websockets
|
||||
import yaml
|
||||
from packaging import version
|
||||
|
||||
import gradio
|
||||
from gradio import components, utils
|
||||
from gradio import components, exceptions, utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import DataframeData
|
||||
@ -417,6 +419,38 @@ def get_spaces(model_name, api_key, alias, **kwargs):
|
||||
return get_spaces_blocks(model_name, config)
|
||||
|
||||
|
||||
async def get_pred_from_ws(
|
||||
websocket: websockets.WebSocketClientProtocol, data: str
|
||||
) -> Dict[str, Any]:
|
||||
completed = False
|
||||
while not completed:
|
||||
msg = await websocket.recv()
|
||||
resp = json.loads(msg)
|
||||
if resp["msg"] == "queue_full":
|
||||
raise exceptions.Error("Queue is full! Please try again.")
|
||||
elif resp["msg"] == "send_data":
|
||||
await websocket.send(data)
|
||||
completed = resp["msg"] == "process_completed"
|
||||
return resp["output"]
|
||||
|
||||
|
||||
def get_ws_fn(ws_url):
|
||||
async def ws_fn(data):
|
||||
async with websockets.connect(ws_url, open_timeout=10) as websocket:
|
||||
return await get_pred_from_ws(websocket, data)
|
||||
|
||||
return ws_fn
|
||||
|
||||
|
||||
def use_websocket(config, dependency):
|
||||
queue_enabled = config.get("enable_queue", False)
|
||||
queue_uses_websocket = version.parse(
|
||||
config.get("version", "2.0")
|
||||
) >= version.Version("3.2")
|
||||
dependency_uses_queue = dependency.get("queue", False) is not False
|
||||
return queue_enabled and queue_uses_websocket and dependency_uses_queue
|
||||
|
||||
|
||||
def get_spaces_blocks(model_name, config):
|
||||
def streamline_config(config: dict) -> dict:
|
||||
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
|
||||
@ -429,33 +463,42 @@ def get_spaces_blocks(model_name, config):
|
||||
config = streamline_config(config)
|
||||
api_url = "https://hf.space/embed/{}/api/predict/".format(model_name)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
ws_url = "wss://spaces.huggingface.tech/{}/queue/join".format(model_name)
|
||||
|
||||
ws_fn = get_ws_fn(ws_url)
|
||||
|
||||
fns = []
|
||||
for d, dependency in enumerate(config["dependencies"]):
|
||||
if dependency["backend_fn"]:
|
||||
|
||||
def get_fn(outputs, fn_index):
|
||||
def get_fn(outputs, fn_index, use_ws):
|
||||
def fn(*data):
|
||||
data = json.dumps({"data": data, "fn_index": fn_index})
|
||||
response = requests.post(api_url, headers=headers, data=data)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
try:
|
||||
if use_ws:
|
||||
result = utils.synchronize_async(ws_fn, data)
|
||||
output = result["data"]
|
||||
except KeyError:
|
||||
if "error" in result and "429" in result["error"]:
|
||||
raise TooManyRequestsError(
|
||||
"Too many requests to the Hugging Face API"
|
||||
else:
|
||||
response = requests.post(api_url, headers=headers, data=data)
|
||||
result = json.loads(response.content.decode("utf-8"))
|
||||
try:
|
||||
output = result["data"]
|
||||
except KeyError:
|
||||
if "error" in result and "429" in result["error"]:
|
||||
raise TooManyRequestsError(
|
||||
"Too many requests to the Hugging Face API"
|
||||
)
|
||||
raise KeyError(
|
||||
f"Could not find 'data' key in response from external Space. Response received: {result}"
|
||||
)
|
||||
raise KeyError(
|
||||
f"Could not find 'data' key in response from external Space. Response received: {result}"
|
||||
)
|
||||
if len(outputs) == 1:
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
return fn
|
||||
|
||||
fn = get_fn(deepcopy(dependency["outputs"]), d)
|
||||
fn = get_fn(
|
||||
deepcopy(dependency["outputs"]), d, use_websocket(config, dependency)
|
||||
)
|
||||
fns.append(fn)
|
||||
else:
|
||||
fns.append(None)
|
||||
|
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import textwrap
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -8,9 +9,16 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import transformers
|
||||
|
||||
import gradio
|
||||
import gradio as gr
|
||||
from gradio import utils
|
||||
from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples
|
||||
from gradio.external import (
|
||||
TooManyRequestsError,
|
||||
cols_to_rows,
|
||||
get_pred_from_ws,
|
||||
get_tabular_examples,
|
||||
use_websocket,
|
||||
)
|
||||
|
||||
"""
|
||||
WARNING: These tests have an external dependency: namely that Hugging Face's
|
||||
@ -331,5 +339,68 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):
|
||||
check_dataset(io.config, hypothetical_readme)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config, dependency, answer",
|
||||
[
|
||||
({"version": "3.3", "enable_queue": True}, {"queue": True}, True),
|
||||
({"version": "3.3", "enable_queue": False}, {"queue": None}, False),
|
||||
({"version": "3.3", "enable_queue": True}, {"queue": None}, True),
|
||||
({"version": "3.3", "enable_queue": True}, {"queue": False}, False),
|
||||
({"enable_queue": True}, {"queue": False}, False),
|
||||
({"version": "3.2", "enable_queue": False}, {"queue": None}, False),
|
||||
({"version": "3.2", "enable_queue": True}, {"queue": None}, True),
|
||||
({"version": "3.2", "enable_queue": True}, {"queue": False}, False),
|
||||
({"version": "3.1.3", "enable_queue": True}, {"queue": None}, False),
|
||||
({"version": "3.1.3", "enable_queue": False}, {"queue": True}, False),
|
||||
],
|
||||
)
|
||||
def test_use_websocket_after_315(config, dependency, answer):
|
||||
assert use_websocket(config, dependency) == answer
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super(AsyncMock, self).__call__(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pred_from_ws():
|
||||
mock_ws = AsyncMock(name="ws")
|
||||
messages = [
|
||||
json.dumps({"msg": "estimation"}),
|
||||
json.dumps({"msg": "send_data"}),
|
||||
json.dumps({"msg": "process_generating"}),
|
||||
json.dumps({"msg": "process_completed", "output": {"data": ["result!"]}}),
|
||||
]
|
||||
mock_ws.recv.side_effect = messages
|
||||
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
|
||||
output = await get_pred_from_ws(mock_ws, data)
|
||||
assert output == {"data": ["result!"]}
|
||||
mock_ws.send.assert_called_once_with(data)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pred_from_ws_raises_if_queue_full():
|
||||
mock_ws = AsyncMock(name="ws")
|
||||
messages = [json.dumps({"msg": "queue_full"})]
|
||||
mock_ws.recv.side_effect = messages
|
||||
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
|
||||
with pytest.raises(gradio.Error, match="Queue is full!"):
|
||||
await get_pred_from_ws(mock_ws, data)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 8),
|
||||
reason="Mocks of async context manager don't work for 3.7",
|
||||
)
|
||||
def test_respect_queue_when_load_from_config():
|
||||
with unittest.mock.patch("websockets.connect"):
|
||||
with unittest.mock.patch(
|
||||
"gradio.external.get_pred_from_ws", return_value={"data": ["foo"]}
|
||||
):
|
||||
interface = gr.Interface.load("spaces/freddyaboulton/saymyname")
|
||||
assert interface("bob") == "foo"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user