From f46f5f986727c4ace9cf7a96298fef729fbc8de2 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Sun, 2 Apr 2023 17:28:46 -0500 Subject: [PATCH] Fix some bugs related to Python client (#3721) * client format * docs * formatting * fix tests * fixed bug * api endpoint changes * fix tests * fix tests * formatting * Add support for sessions [python client] (#3731) * client * add state and tests * remove session param --- client/python/gradio_client/client.py | 183 ++++++++++++++++++-------- client/python/gradio_client/utils.py | 1 + client/python/test/test_client.py | 56 ++++++-- 3 files changed, 173 insertions(+), 67 deletions(-) diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index e3e5c4013b..87f852e7f3 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -9,7 +9,7 @@ import uuid from concurrent.futures import Future from datetime import datetime from threading import Lock -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Literal, Tuple import huggingface_hub import requests @@ -32,7 +32,7 @@ class Client: """ Parameters: src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/pictionary") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/"). - hf_token: The Hugging Face token to use to access private Spaces. If not provided, only public Spaces can be loaded. + hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously. """ self.hf_token = hf_token @@ -55,6 +55,7 @@ class Client: self.api_url = utils.API_URL.format(self.src) self.ws_url = utils.WS_URL.format(self.src).replace("http", "ws", 1) self.config = self._get_config() + self.session_hash = str(uuid.uuid4()) self.endpoints = [ Endpoint(self, fn_index, dependency) @@ -71,25 +72,24 @@ class Client: self, *args, api_name: str | None = None, - fn_index: int = 0, + fn_index: int | None = None, result_callbacks: Callable | List[Callable] | None = None, ) -> Future: """ Parameters: *args: The arguments to pass to the remote API. The order of the arguments must match the order of the inputs in the Gradio app. - api_name: The name of the API endpoint to call. If not provided, the first API will be called. Takes precedence over fn_index. - fn_index: The index of the API endpoint to call. If not provided, the first API will be called. + api_name: The name of the API endpoint to call starting with a leading slash, e.g. "/predict". Does not need to be provided if the Gradio app has only one named API endpoint. + fn_index: The index of the API endpoint to call, e.g. 0. Both api_name and fn_index can be provided, but if they conflict, api_name will take precedence. result_callbacks: A callback function, or list of callback functions, to be called when the result is ready. If a list of functions is provided, they will be called in order. The return values from the remote API are provided as separate parameters into the callback. If None, no callback will be called. Returns: A Job object that can be used to retrieve the status and result of the remote API call. """ - if api_name: - fn_index = self._infer_fn_index(api_name) + inferred_fn_index = self._infer_fn_index(api_name, fn_index) helper = None - if self.endpoints[fn_index].use_ws: + if self.endpoints[inferred_fn_index].use_ws: helper = Communicator(Lock(), JobStatus()) - end_to_end_fn = self.endpoints[fn_index].make_end_to_end_fn(helper) + end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper) future = self.executor.submit(end_to_end_fn, *args) job = Job(future, communicator=helper) @@ -115,13 +115,15 @@ class Client: def view_api( self, all_endpoints: bool | None = None, - return_info: bool = False, - ) -> Dict | None: + print_info: bool = True, + return_format: Literal["dict", "str"] | None = None, + ) -> Dict | str | None: """ Prints the usage info for the API. If the Gradio app has multiple API endpoints, the usage info for each endpoint will be printed separately. Parameters: all_endpoints: If True, prints information for both named and unnamed endpoints in the Gradio app. If False, will only print info about named endpoints. If None (default), will only print info about unnamed endpoints if there are no named endpoints. - return_info: If False (default), prints the usage info to the console. If True, returns the usage info as a dictionary that can be programmatically parsed (does not print), and *all endpoints are returned in the dictionary* regardless of the value of `all_endpoints`. The format of the dictionary is in the docstring of this method. + print_info: If True, prints the usage info to the console. If False, does not print the usage info. + return_format: If None, nothing is returned. If "str", returns the same string that would be printed to the console. If "dict", returns the usage info as a dictionary that can be programmatically parsed, and *all endpoints are returned in the dictionary* regardless of the value of `all_endpoints`. The format of the dictionary is in the docstring of this method. Dictionary format: { "named_endpoints": { @@ -153,9 +155,6 @@ class Client: else: info["unnamed_endpoints"][endpoint.fn_index] = endpoint.get_info() - if return_info: - return info - num_named_endpoints = len(info["named_endpoints"]) num_unnamed_endpoints = len(info["unnamed_endpoints"]) if num_named_endpoints == 0 and all_endpoints is None: @@ -175,7 +174,15 @@ class Client: if num_unnamed_endpoints > 0: human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(`all_endpoints=True`)\n" - print(human_info) + if print_info: + print(human_info) + if return_format == "str": + return human_info + elif return_format == "dict": + return info + + def reset_session(self) -> None: + self.session_hash = str(uuid.uuid4()) def _render_endpoints_info( self, @@ -199,22 +206,26 @@ class Client: raise ValueError("name_or_index must be a string or integer") human_info = f"\n - predict({rendered_parameters}{final_param}) -> {rendered_return_values}\n" + human_info += " Parameters:\n" if endpoints_info["parameters"]: - human_info += " Parameters:\n" - for label, info in endpoints_info["parameters"].items(): - human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n" + for label, info in endpoints_info["parameters"].items(): + human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n" + else: + human_info += " - None\n" + human_info += " Returns:\n" if endpoints_info["returns"]: - human_info += " Returns:\n" - for label, info in endpoints_info["returns"].items(): - human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n" + for label, info in endpoints_info["returns"].items(): + human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n" + else: + human_info += " - None\n" return human_info def __repr__(self): - return self.view_api() + return self.view_api(print_info=False, return_format="str") def __str__(self): - return self.view_api() + return self.view_api(print_info=False, return_format="str") def _telemetry_thread(self) -> None: # Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1 @@ -231,11 +242,34 @@ class Client: except Exception: pass - def _infer_fn_index(self, api_name: str) -> int: - for i, d in enumerate(self.config["dependencies"]): - if d.get("api_name") == api_name: - return i - raise ValueError(f"Cannot find a function with api_name: {api_name}") + def _infer_fn_index(self, api_name: str | None, fn_index: int | None) -> int: + inferred_fn_index = None + if api_name is not None: + for i, d in enumerate(self.config["dependencies"]): + config_api_name = d.get("api_name") + if config_api_name is None: + continue + if "/" + config_api_name == api_name: + inferred_fn_index = i + break + else: + error_message = f"Cannot find a function with `api_name`: {api_name}." + if not api_name.startswith("/"): + error_message += " Did you mean to use a leading slash?" + raise ValueError(error_message) + elif fn_index is not None: + inferred_fn_index = fn_index + else: + valid_endpoints = [ + e for e in self.endpoints if e.is_valid and e.api_name is not None + ] + if len(valid_endpoints) == 1: + inferred_fn_index = valid_endpoints[0].fn_index + else: + raise ValueError( + "This Gradio app might have multiple endpoints. Please specify an `api_name` or `fn_index`" + ) + return inferred_fn_index def __del__(self): if hasattr(self, "executor"): @@ -264,15 +298,15 @@ class Endpoint: """Helper class for storing all the information about a single API endpoint.""" def __init__(self, client: Client, fn_index: int, dependency: Dict): - self.api_url = client.api_url - self.ws_url = client.ws_url + self.client: Client = client self.fn_index = fn_index self.dependency = dependency self.api_name: str | None = dependency.get("api_name") - self.headers = client.headers - self.config = client.config + if self.api_name: + self.api_name = "/" + self.api_name self.use_ws = self._use_websocket(self.dependency) - self.hf_token = client.hf_token + self.input_component_types = [] + self.output_component_types = [] try: self.serializers, self.deserializers = self._setup_serializers() self.is_valid = self.dependency[ @@ -298,7 +332,7 @@ class Endpoint: """ parameters = {} for i, input in enumerate(self.dependency["inputs"]): - for component in self.config["components"]: + for component in self.client.config["components"]: if component["id"] == input: label = ( component["props"] @@ -311,11 +345,13 @@ class Endpoint: else: info = self.serializers[i].input_api_info() info = list(info) - info.append(component.get("type", "component").capitalize()) - parameters[label] = info + component_type = component.get("type", "component").capitalize() + info.append(component_type) + if not component_type.lower() == utils.STATE_COMPONENT: + parameters[label] = info returns = {} for o, output in enumerate(self.dependency["outputs"]): - for component in self.config["components"]: + for component in self.client.config["components"]: if component["id"] == output: label = ( component["props"] @@ -328,11 +364,19 @@ class Endpoint: else: info = self.deserializers[o].output_api_info() info = list(info) - info.append(component.get("type", "component").capitalize()) - returns[label] = list(info) + component_type = component.get("type", "component").capitalize() + info.append(component_type) + if not component_type.lower() == utils.STATE_COMPONENT: + returns[label] = info return {"parameters": parameters, "returns": returns} + def __repr__(self): + return json.dumps(self.get_info(), indent=4) + + def __str__(self): + return json.dumps(self.get_info(), indent=4) + def make_end_to_end_fn(self, helper: Communicator | None = None): _predict = self.make_predict(helper) @@ -343,7 +387,16 @@ class Endpoint: inputs = self.serialize(*data) predictions = _predict(*inputs) outputs = self.deserialize(*predictions) - if len(self.dependency["outputs"]) == 1: + if ( + len( + [ + oct + for oct in self.output_component_types + if not oct == utils.STATE_COMPONENT + ] + ) + == 1 + ): return outputs[0] return outputs @@ -351,15 +404,27 @@ class Endpoint: def make_predict(self, helper: Communicator | None = None): def _predict(*data) -> Tuple: - data = json.dumps({"data": data, "fn_index": self.fn_index}) - hash_data = json.dumps( - {"fn_index": self.fn_index, "session_hash": str(uuid.uuid4())} + data = json.dumps( + { + "data": data, + "fn_index": self.fn_index, + "session_hash": self.client.session_hash, + } ) + hash_data = json.dumps( + { + "fn_index": self.fn_index, + "session_hash": self.client.session_hash, + } + ) + if self.use_ws: result = utils.synchronize_async(self._ws_fn, data, hash_data, helper) output = result["data"] else: - response = requests.post(self.api_url, headers=self.headers, data=data) + response = requests.post( + self.client.api_url, headers=self.client.headers, data=data + ) result = json.loads(response.content.decode("utf-8")) try: output = result["data"] @@ -383,6 +448,11 @@ class Endpoint: return outputs def serialize(self, *data) -> Tuple: + for i, input_component_type in enumerate(self.input_component_types): + if input_component_type == utils.STATE_COMPONENT: + data = list(data) + data.insert(i, None) + data = tuple(data) assert len(data) == len( self.serializers ), f"Expected {len(self.serializers)} arguments, got {len(data)}" @@ -394,8 +464,11 @@ class Endpoint: ), f"Expected {len(self.deserializers)} outputs, got {len(data)}" return tuple( [ - s.deserialize(d, hf_token=self.hf_token) - for s, d in zip(self.deserializers, data) + s.deserialize(d, hf_token=self.client.hf_token) + for s, d, oct in zip( + self.deserializers, data, self.output_component_types + ) + if not oct == utils.STATE_COMPONENT ] ) @@ -404,8 +477,10 @@ class Endpoint: serializers = [] for i in inputs: - for component in self.config["components"]: + for component in self.client.config["components"]: if component["id"] == i: + component_name = component["type"] + self.input_component_types.append(component_name) if component.get("serializer"): serializer_name = component["serializer"] assert ( @@ -413,7 +488,6 @@ class Endpoint: ), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." serializer = serializing.SERIALIZER_MAPPING[serializer_name] else: - component_name = component["type"] assert ( component_name in serializing.COMPONENT_MAPPING ), f"Unknown component: {component_name}, you may need to update your gradio_client version." @@ -423,8 +497,10 @@ class Endpoint: outputs = self.dependency["outputs"] deserializers = [] for i in outputs: - for component in self.config["components"]: + for component in self.client.config["components"]: if component["id"] == i: + component_name = component["type"] + self.output_component_types.append(component_name) if component.get("serializer"): serializer_name = component["serializer"] assert ( @@ -432,7 +508,6 @@ class Endpoint: ), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." deserializer = serializing.SERIALIZER_MAPPING[serializer_name] else: - component_name = component["type"] assert ( component_name in serializing.COMPONENT_MAPPING ), f"Unknown component: {component_name}, you may need to update your gradio_client version." @@ -442,16 +517,16 @@ class Endpoint: return serializers, deserializers def _use_websocket(self, dependency: Dict) -> bool: - queue_enabled = self.config.get("enable_queue", False) + queue_enabled = self.client.config.get("enable_queue", False) queue_uses_websocket = version.parse( - self.config.get("version", "2.0") + self.client.config.get("version", "2.0") ) >= version.Version("3.2") dependency_uses_queue = dependency.get("queue", False) is not False return queue_enabled and queue_uses_websocket and dependency_uses_queue async def _ws_fn(self, data, hash_data, helper: Communicator): async with websockets.connect( # type: ignore - self.ws_url, open_timeout=10, extra_headers=self.headers + self.client.ws_url, open_timeout=10, extra_headers=self.client.headers ) as websocket: return await utils.get_pred_from_ws(websocket, data, hash_data, helper) diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index a0903e29a4..e2e828c43d 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -20,6 +20,7 @@ from websockets.legacy.protocol import WebSocketCommonProtocol API_URL = "{}/api/predict/" WS_URL = "{}/queue/join" +STATE_COMPONENT = "state" __version__ = (pkgutil.get_data(__name__, "version.txt") or b"").decode("ascii").strip() diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 96788d00d6..352998701a 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -19,15 +19,45 @@ class TestPredictionsFromSpaces: @pytest.mark.flaky def test_numerical_to_label_space(self): client = Client("gradio-tests/titanic-survival") - output = client.predict("male", 77, 10).result() + output = client.predict("male", 77, 10, api_name="/predict").result() assert json.load(open(output))["label"] == "Perishes" + with pytest.raises( + ValueError, + match="This Gradio app might have multiple endpoints. Please specify an `api_name` or `fn_index`", + ): + client.predict("male", 77, 10) + with pytest.raises( + ValueError, + match="Cannot find a function with `api_name`: predict. Did you mean to use a leading slash?", + ): + client.predict("male", 77, 10, api_name="predict") @pytest.mark.flaky def test_private_space(self): client = Client("gradio-tests/not-actually-private-space", hf_token=HF_TOKEN) - output = client.predict("abc").result() + output = client.predict("abc", api_name="/predict").result() assert output == "abc" + @pytest.mark.flaky + def test_state(self): + client = Client("gradio-tests/increment") + output = client.predict(api_name="/increment_without_queue").result() + assert output == 1 + output = client.predict(api_name="/increment_without_queue").result() + assert output == 2 + output = client.predict(api_name="/increment_without_queue").result() + assert output == 3 + client.reset_session() + output = client.predict(api_name="/increment_without_queue").result() + assert output == 1 + output = client.predict(api_name="/increment_with_queue").result() + assert output == 2 + client.reset_session() + output = client.predict(api_name="/increment_with_queue").result() + assert output == 1 + output = client.predict(api_name="/increment_with_queue").result() + assert output == 2 + @pytest.mark.flaky def test_job_status(self): statuses = [] @@ -50,7 +80,7 @@ class TestPredictionsFromSpaces: def test_job_status_queue_disabled(self): statuses = [] client = Client(src="freddyaboulton/sentiment-classification") - job = client.predict("I love the gradio python client") + job = client.predict("I love the gradio python client", fn_index=0) while not job.done(): time.sleep(0.02) statuses.append(job.status()) @@ -128,7 +158,7 @@ class TestStatusUpdates: mock_make_end_to_end_fn.side_effect = MockEndToEndFunction client = Client(src="gradio/calculator") - job = client.predict(5, "add", 6, fn_index=0) + job = client.predict(5, "add", 6) statuses = [] while not job.done(): @@ -200,8 +230,8 @@ class TestStatusUpdates: mock_make_end_to_end_fn.side_effect = MockEndToEndFunction client = Client(src="gradio/calculator") - job_1 = client.predict(5, "add", 6, fn_index=0) - job_2 = client.predict(11, "subtract", 1, fn_index=0) + job_1 = client.predict(5, "add", 6) + job_2 = client.predict(11, "subtract", 1) statuses_1 = [] statuses_2 = [] @@ -213,7 +243,7 @@ class TestStatusUpdates: assert all(s in messages_1 for s in statuses_1) -class TestEndpoints: +class TestAPIInfo: @pytest.mark.flaky def test_numerical_to_label_space(self): client = Client("gradio-tests/titanic-survival") @@ -225,9 +255,9 @@ class TestEndpoints: }, "returns": {"output": ["str", "filepath to json file", "Label"]}, } - assert client.view_api(return_info=True) == { + assert client.view_api(return_format="dict") == { "named_endpoints": { - "predict": { + "/predict": { "parameters": { "sex": ["Any", "", "Radio"], "age": ["Any", "", "Slider"], @@ -235,7 +265,7 @@ class TestEndpoints: }, "returns": {"output": ["str", "filepath to json file", "Label"]}, }, - "predict_1": { + "/predict_1": { "parameters": { "sex": ["Any", "", "Radio"], "age": ["Any", "", "Slider"], @@ -243,7 +273,7 @@ class TestEndpoints: }, "returns": {"output": ["str", "filepath to json file", "Label"]}, }, - "predict_2": { + "/predict_2": { "parameters": { "sex": ["Any", "", "Radio"], "age": ["Any", "", "Slider"], @@ -272,9 +302,9 @@ class TestEndpoints: "parameters": {"x": ["Any", "", "Textbox"]}, "returns": {"output": ["Any", "", "Textbox"]}, } - assert client.view_api(return_info=True) == { + assert client.view_api(return_format="dict") == { "named_endpoints": { - "predict": { + "/predict": { "parameters": {"x": ["Any", "", "Textbox"]}, "returns": {"output": ["Any", "", "Textbox"]}, }