From 11379b92f1b0f6629c85edd560ca34f8c5904174 Mon Sep 17 00:00:00 2001
From: Freddy Boulton
Date: Wed, 21 Sep 2022 13:18:40 -0400
Subject: [PATCH] Respect Upstream Queue when loading interfaces/blocks from
 Spaces (#2294)

* Fix fn used in load when queue is enabled

* Respect upstream queue

* Fix test

* Skip in 3.7

* Update logic to respect if fn does not have queue

* Fix impl + test
---
 gradio/external.py    | 71 ++++++++++++++++++++++++++++++++---------
 test/test_external.py | 73 ++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 129 insertions(+), 15 deletions(-)

diff --git a/gradio/external.py b/gradio/external.py
index 73d26851a5..13d073a3ac 100644
--- a/gradio/external.py
+++ b/gradio/external.py
@@ -11,13 +11,15 @@ import operator
 import re
 import warnings
 from copy import deepcopy
-from typing import TYPE_CHECKING, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple
 
 import requests
+import websockets
 import yaml
+from packaging import version
 
 import gradio
-from gradio import components, utils
+from gradio import components, exceptions, utils
 
 if TYPE_CHECKING:
     from gradio.components import DataframeData
@@ -417,6 +419,38 @@ def get_spaces(model_name, api_key, alias, **kwargs):
         return get_spaces_blocks(model_name, config)
 
 
+async def get_pred_from_ws(
+    websocket: websockets.WebSocketClientProtocol, data: str
+) -> Dict[str, Any]:
+    completed = False
+    while not completed:
+        msg = await websocket.recv()
+        resp = json.loads(msg)
+        if resp["msg"] == "queue_full":
+            raise exceptions.Error("Queue is full! Please try again.")
+        elif resp["msg"] == "send_data":
+            await websocket.send(data)
+        completed = resp["msg"] == "process_completed"
+    return resp["output"]
+
+
+def get_ws_fn(ws_url):
+    async def ws_fn(data):
+        async with websockets.connect(ws_url, open_timeout=10) as websocket:
+            return await get_pred_from_ws(websocket, data)
+
+    return ws_fn
+
+
+def use_websocket(config, dependency):
+    queue_enabled = config.get("enable_queue", False)
+    queue_uses_websocket = version.parse(
+        config.get("version", "2.0")
+    ) >= version.Version("3.2")
+    dependency_uses_queue = dependency.get("queue", False) is not False
+    return queue_enabled and queue_uses_websocket and dependency_uses_queue
+
+
 def get_spaces_blocks(model_name, config):
     def streamline_config(config: dict) -> dict:
         """Streamlines the blocks config dictionary to fix components that don't render correctly."""
@@ -429,33 +463,42 @@ def get_spaces_blocks(model_name, config):
     config = streamline_config(config)
     api_url = "https://hf.space/embed/{}/api/predict/".format(model_name)
     headers = {"Content-Type": "application/json"}
+    ws_url = "wss://spaces.huggingface.tech/{}/queue/join".format(model_name)
+
+    ws_fn = get_ws_fn(ws_url)
 
     fns = []
     for d, dependency in enumerate(config["dependencies"]):
         if dependency["backend_fn"]:
 
-            def get_fn(outputs, fn_index):
+            def get_fn(outputs, fn_index, use_ws):
                 def fn(*data):
                     data = json.dumps({"data": data, "fn_index": fn_index})
-                    response = requests.post(api_url, headers=headers, data=data)
-                    result = json.loads(response.content.decode("utf-8"))
-                    try:
+                    if use_ws:
+                        result = utils.synchronize_async(ws_fn, data)
                         output = result["data"]
-                    except KeyError:
-                        if "error" in result and "429" in result["error"]:
-                            raise TooManyRequestsError(
-                                "Too many requests to the Hugging Face API"
+                    else:
+                        response = requests.post(api_url, headers=headers, data=data)
+                        result = json.loads(response.content.decode("utf-8"))
+                        try:
+                            output = result["data"]
+                        except KeyError:
+                            if "error" in result and "429" in result["error"]:
+                                raise TooManyRequestsError(
+                                    "Too many requests to the Hugging Face API"
+                                )
+                            raise KeyError(
+                                f"Could not find 'data' key in response from external Space. Response received: {result}"
                             )
-                        raise KeyError(
-                            f"Could not find 'data' key in response from external Space. Response received: {result}"
-                        )
                     if len(outputs) == 1:
                         output = output[0]
                     return output
 
                 return fn
 
-            fn = get_fn(deepcopy(dependency["outputs"]), d)
+            fn = get_fn(
+                deepcopy(dependency["outputs"]), d, use_websocket(config, dependency)
+            )
             fns.append(fn)
         else:
             fns.append(None)
diff --git a/test/test_external.py b/test/test_external.py
index 2837b89e24..5f4bd1fc75 100644
--- a/test/test_external.py
+++ b/test/test_external.py
@@ -1,6 +1,7 @@
 import json
 import os
 import pathlib
+import sys
 import textwrap
 import unittest
 from unittest.mock import MagicMock, patch
@@ -8,9 +9,16 @@ from unittest.mock import MagicMock, patch
 import pytest
 import transformers
 
+import gradio
 import gradio as gr
 from gradio import utils
-from gradio.external import TooManyRequestsError, cols_to_rows, get_tabular_examples
+from gradio.external import (
+    TooManyRequestsError,
+    cols_to_rows,
+    get_pred_from_ws,
+    get_tabular_examples,
+    use_websocket,
+)
 
 """
 WARNING: These tests have an external dependency: namely that Hugging Face's
@@ -331,5 +339,68 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme)
         check_dataset(io.config, hypothetical_readme)
 
 
+@pytest.mark.parametrize(
+    "config, dependency, answer",
+    [
+        ({"version": "3.3", "enable_queue": True}, {"queue": True}, True),
+        ({"version": "3.3", "enable_queue": False}, {"queue": None}, False),
+        ({"version": "3.3", "enable_queue": True}, {"queue": None}, True),
+        ({"version": "3.3", "enable_queue": True}, {"queue": False}, False),
+        ({"enable_queue": True}, {"queue": False}, False),
+        ({"version": "3.2", "enable_queue": False}, {"queue": None}, False),
+        ({"version": "3.2", "enable_queue": True}, {"queue": None}, True),
+        ({"version": "3.2", "enable_queue": True}, {"queue": False}, False),
+        ({"version": "3.1.3", "enable_queue": True}, {"queue": None}, False),
+        ({"version": "3.1.3", "enable_queue": False}, {"queue": True}, False),
+    ],
+)
+def test_use_websocket_after_315(config, dependency, answer):
+    assert use_websocket(config, dependency) == answer
+
+
+class AsyncMock(MagicMock):
+    async def __call__(self, *args, **kwargs):
+        return super(AsyncMock, self).__call__(*args, **kwargs)
+
+
+@pytest.mark.asyncio
+async def test_get_pred_from_ws():
+    mock_ws = AsyncMock(name="ws")
+    messages = [
+        json.dumps({"msg": "estimation"}),
+        json.dumps({"msg": "send_data"}),
+        json.dumps({"msg": "process_generating"}),
+        json.dumps({"msg": "process_completed", "output": {"data": ["result!"]}}),
+    ]
+    mock_ws.recv.side_effect = messages
+    data = json.dumps({"data": ["foo"], "fn_index": "foo"})
+    output = await get_pred_from_ws(mock_ws, data)
+    assert output == {"data": ["result!"]}
+    mock_ws.send.assert_called_once_with(data)
+
+
+@pytest.mark.asyncio
+async def test_get_pred_from_ws_raises_if_queue_full():
+    mock_ws = AsyncMock(name="ws")
+    messages = [json.dumps({"msg": "queue_full"})]
+    mock_ws.recv.side_effect = messages
+    data = json.dumps({"data": ["foo"], "fn_index": "foo"})
+    with pytest.raises(gradio.Error, match="Queue is full!"):
+        await get_pred_from_ws(mock_ws, data)
+
+
+@pytest.mark.skipif(
+    sys.version_info < (3, 8),
+    reason="Mocks of async context manager don't work for 3.7",
+)
+def test_respect_queue_when_load_from_config():
+    with unittest.mock.patch("websockets.connect"):
+        with unittest.mock.patch(
+            "gradio.external.get_pred_from_ws", return_value={"data": ["foo"]}
+        ):
+            interface = gr.Interface.load("spaces/freddyaboulton/saymyname")
+            assert interface("bob") == "foo"
+
+
 if __name__ == "__main__":
     unittest.main()
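Below is a minimal usage sketch (illustrative only, not part of the patch) of how the new use_websocket() helper gates the websocket queue path when a Space is loaded; the config/dependency values mirror two rows of the parametrized test above:

    from gradio.external import use_websocket

    # Upstream Space config with the queue enabled on a gradio >= 3.2 Space
    space_config = {"version": "3.3", "enable_queue": True}

    dependency = {"queue": None}  # event fn inherits the Space-level queue setting
    assert use_websocket(space_config, dependency)  # predictions go over the /queue/join websocket

    dependency = {"queue": False}  # event fn explicitly opts out of the queue
    assert not use_websocket(space_config, dependency)  # falls back to the HTTP /api/predict/ endpoint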