Respect Upstream Queue when loading interfaces/blocks from Spaces (#2294)

* Fix fn used in load when queue is enabled

* Respect upstream queue

* Fix test

* Skip in 3.7

* Update logic to respect if fn does not have queue

* Fix impl + test
This commit is contained in:
Freddy Boulton 2022-09-21 13:18:40 -04:00 committed by GitHub
parent c977ef1fa8
commit 11379b92f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 129 additions and 15 deletions

View File

@ -11,13 +11,15 @@ import operator
import re
import warnings
from copy import deepcopy
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
import requests
import websockets
import yaml
from packaging import version
import gradio
from gradio import components, utils
from gradio import components, exceptions, utils
if TYPE_CHECKING:
from gradio.components import DataframeData
@ -417,6 +419,38 @@ def get_spaces(model_name, api_key, alias, **kwargs):
return get_spaces_blocks(model_name, config)
async def get_pred_from_ws(
websocket: websockets.WebSocketClientProtocol, data: str
) -> Dict[str, Any]:
completed = False
while not completed:
msg = await websocket.recv()
resp = json.loads(msg)
if resp["msg"] == "queue_full":
raise exceptions.Error("Queue is full! Please try again.")
elif resp["msg"] == "send_data":
await websocket.send(data)
completed = resp["msg"] == "process_completed"
return resp["output"]
def get_ws_fn(ws_url):
async def ws_fn(data):
async with websockets.connect(ws_url, open_timeout=10) as websocket:
return await get_pred_from_ws(websocket, data)
return ws_fn
def use_websocket(config, dependency):
queue_enabled = config.get("enable_queue", False)
queue_uses_websocket = version.parse(
config.get("version", "2.0")
) >= version.Version("3.2")
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue
def get_spaces_blocks(model_name, config):
def streamline_config(config: dict) -> dict:
"""Streamlines the blocks config dictionary to fix components that don't render correctly."""
@ -429,33 +463,42 @@ def get_spaces_blocks(model_name, config):
config = streamline_config(config)
api_url = "https://hf.space/embed/{}/api/predict/".format(model_name)
headers = {"Content-Type": "application/json"}
ws_url = "wss://spaces.huggingface.tech/{}/queue/join".format(model_name)
ws_fn = get_ws_fn(ws_url)
fns = []
for d, dependency in enumerate(config["dependencies"]):
if dependency["backend_fn"]:
def get_fn(outputs, fn_index):
def get_fn(outputs, fn_index, use_ws):
def fn(*data):
data = json.dumps({"data": data, "fn_index": fn_index})
response = requests.post(api_url, headers=headers, data=data)
result = json.loads(response.content.decode("utf-8"))
try:
if use_ws:
result = utils.synchronize_async(ws_fn, data)
output = result["data"]
except KeyError:
if "error" in result and "429" in result["error"]:
raise TooManyRequestsError(
"Too many requests to the Hugging Face API"
else:
response = requests.post(api_url, headers=headers, data=data)
result = json.loads(response.content.decode("utf-8"))
try:
output = result["data"]
except KeyError:
if "error" in result and "429" in result["error"]:
raise TooManyRequestsError(
"Too many requests to the Hugging Face API"
)
raise KeyError(
f"Could not find 'data' key in response from external Space. Response received: {result}"
)
raise KeyError(
f"Could not find 'data' key in response from external Space. Response received: {result}"
)
if len(outputs) == 1:
output = output[0]
return output
return fn
fn = get_fn(deepcopy(dependency["outputs"]), d)
fn = get_fn(
deepcopy(dependency["outputs"]), d, use_websocket(config, dependency)
)
fns.append(fn)
else:
fns.append(None)

View File

@ -1,6 +1,7 @@
import json
import os
import pathlib
import sys
import textwrap
import unittest
from unittest.mock import MagicMock, patch
@ -8,9 +9,16 @@ from unittest.mock import MagicMock, patch
import pytest
import transformers
import gradio
import gradio as gr
from gradio import utils
from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples
from gradio.external import (
TooManyRequestsError,
cols_to_rows,
get_pred_from_ws,
get_tabular_examples,
use_websocket,
)
"""
WARNING: These tests have an external dependency: namely that Hugging Face's
@ -331,5 +339,68 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):
check_dataset(io.config, hypothetical_readme)
@pytest.mark.parametrize(
"config, dependency, answer",
[
({"version": "3.3", "enable_queue": True}, {"queue": True}, True),
({"version": "3.3", "enable_queue": False}, {"queue": None}, False),
({"version": "3.3", "enable_queue": True}, {"queue": None}, True),
({"version": "3.3", "enable_queue": True}, {"queue": False}, False),
({"enable_queue": True}, {"queue": False}, False),
({"version": "3.2", "enable_queue": False}, {"queue": None}, False),
({"version": "3.2", "enable_queue": True}, {"queue": None}, True),
({"version": "3.2", "enable_queue": True}, {"queue": False}, False),
({"version": "3.1.3", "enable_queue": True}, {"queue": None}, False),
({"version": "3.1.3", "enable_queue": False}, {"queue": True}, False),
],
)
def test_use_websocket_after_315(config, dependency, answer):
assert use_websocket(config, dependency) == answer
class AsyncMock(MagicMock):
async def __call__(self, *args, **kwargs):
return super(AsyncMock, self).__call__(*args, **kwargs)
@pytest.mark.asyncio
async def test_get_pred_from_ws():
mock_ws = AsyncMock(name="ws")
messages = [
json.dumps({"msg": "estimation"}),
json.dumps({"msg": "send_data"}),
json.dumps({"msg": "process_generating"}),
json.dumps({"msg": "process_completed", "output": {"data": ["result!"]}}),
]
mock_ws.recv.side_effect = messages
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
output = await get_pred_from_ws(mock_ws, data)
assert output == {"data": ["result!"]}
mock_ws.send.assert_called_once_with(data)
@pytest.mark.asyncio
async def test_get_pred_from_ws_raises_if_queue_full():
mock_ws = AsyncMock(name="ws")
messages = [json.dumps({"msg": "queue_full"})]
mock_ws.recv.side_effect = messages
data = json.dumps({"data": ["foo"], "fn_index": "foo"})
with pytest.raises(gradio.Error, match="Queue is full!"):
await get_pred_from_ws(mock_ws, data)
@pytest.mark.skipif(
sys.version_info < (3, 8),
reason="Mocks of async context manager don't work for 3.7",
)
def test_respect_queue_when_load_from_config():
with unittest.mock.patch("websockets.connect"):
with unittest.mock.patch(
"gradio.external.get_pred_from_ws", return_value={"data": ["foo"]}
):
interface = gr.Interface.load("spaces/freddyaboulton/saymyname")
assert interface("bob") == "foo"
if __name__ == "__main__":
unittest.main()