From 7b18891aae567fe98155cc824d745304222d5ed0 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Mon, 17 Jul 2023 12:21:54 -0500 Subject: [PATCH] Client: Support endpoints that return layout components (#4871) * Add code * CHANGELOG * Add code * Use set * Modify _setup_serializers instead * Push up code * Remove from serializing * Add gradio_client changes * Update requirements.txt --- CHANGELOG.md | 1 + client/python/CHANGELOG.md | 3 +- client/python/gradio_client/client.py | 12 +- client/python/gradio_client/serializing.py | 2 - client/python/gradio_client/utils.py | 14 +++ client/python/gradio_client/version.txt | 2 +- client/python/test/conftest.py | 62 ++++++++++ client/python/test/test_client.py | 125 +++++++++++++++++++++ client/python/test/test_serializing.py | 4 +- gradio/blocks.py | 9 +- requirements.txt | 2 +- 11 files changed, 220 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29f05a8823..86e5023c80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ with gr.Blocks() as demo: gr.Markdown("سلام", rtl=True) demo.launch() ``` +- The `get_api_info` method of `Blocks` now supports layout output components [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4871](https://github.com/gradio-app/gradio/pull/4871) ## Bug Fixes: diff --git a/client/python/CHANGELOG.md b/client/python/CHANGELOG.md index 02713aeb46..0aa1623d14 100644 --- a/client/python/CHANGELOG.md +++ b/client/python/CHANGELOG.md @@ -2,7 +2,8 @@ ## New Features: -No changes to highlight +- Endpoints that return layout components are now properly handled in the `submit` and `view_api` methods. Output layout components are not returned by the API but all other components are (excluding `gr.State`). By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4871](https://github.com/gradio-app/gradio/pull/4871) + ## Bug Fixes: diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 85bb33ce40..24b380c9fb 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -427,7 +427,7 @@ class Client: # Versions of Gradio older than 3.29.0 returned format of the API info # from the /info endpoint - if version.parse(self.config.get("version", "2.0")) > version.Version("3.29.0"): + if version.parse(self.config.get("version", "2.0")) > version.Version("3.36.1"): r = requests.get(api_info_url, headers=self.headers) if r.ok: info = r.json() @@ -775,11 +775,11 @@ class Endpoint: data.insert(i, None) return tuple(data) - def remove_state(self, *data) -> tuple: + def remove_skipped_components(self, *data) -> tuple: data = [ d for d, oct in zip(data, self.output_component_types) - if oct != utils.STATE_COMPONENT + if oct not in utils.SKIP_COMPONENTS ] return tuple(data) @@ -789,7 +789,7 @@ class Endpoint: [ oct for oct in self.output_component_types - if oct != utils.STATE_COMPONENT + if oct not in utils.SKIP_COMPONENTS ] ) == 1 @@ -834,7 +834,7 @@ class Endpoint: def process_predictions(self, *predictions): if self.client.serialize: predictions = self.deserialize(*predictions) - predictions = self.remove_state(*predictions) + predictions = self.remove_skipped_components(*predictions) predictions = self.reduce_singleton_output(*predictions) return predictions @@ -873,6 +873,8 @@ class Endpoint: serializer_name in serializing.SERIALIZER_MAPPING ), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version." deserializer = serializing.SERIALIZER_MAPPING[serializer_name] + elif component_name in utils.SKIP_COMPONENTS: + deserializer = serializing.SimpleSerializable else: assert ( component_name in serializing.COMPONENT_MAPPING diff --git a/client/python/gradio_client/serializing.py b/client/python/gradio_client/serializing.py index a74867f555..6a4315f6bf 100644 --- a/client/python/gradio_client/serializing.py +++ b/client/python/gradio_client/serializing.py @@ -543,8 +543,6 @@ COMPONENT_MAPPING: dict[str, type] = { "lineplot": JSONSerializable, "scatterplot": JSONSerializable, "markdown": StringSerializable, - "dataset": StringSerializable, "code": StringSerializable, - "interpretation": SimpleSerializable, "annotatedimage": JSONSerializable, } diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index d9d305a05b..850e6f8882 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -35,6 +35,20 @@ SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api" RESET_URL = "reset" SPACE_URL = "https://hf.space/{}" +SKIP_COMPONENTS = { + "state", + "row", + "column", + "tabs", + "tab", + "tabitem", + "box", + "form", + "accordion", + "group", + "interpretation", + "dataset", +} STATE_COMPONENT = "state" INVALID_RUNTIME = [ SpaceStage.NO_APP_FILE, diff --git a/client/python/gradio_client/version.txt b/client/python/gradio_client/version.txt index 1866a362b7..13dead7ebf 100644 --- a/client/python/gradio_client/version.txt +++ b/client/python/gradio_client/version.txt @@ -1 +1 @@ -0.2.9 +0.2.10 diff --git a/client/python/test/conftest.py b/client/python/test/conftest.py index 66ffbb1b11..16eae65057 100644 --- a/client/python/test/conftest.py +++ b/client/python/test/conftest.py @@ -209,6 +209,68 @@ def stateful_chatbot(): return demo +@pytest.fixture +def hello_world_with_group(): + with gr.Blocks() as demo: + name = gr.Textbox(label="name") + output = gr.Textbox(label="greeting") + greet = gr.Button("Greet") + show_group = gr.Button("Show group") + with gr.Group(visible=False) as group: + gr.Textbox("Hello!") + + def greeting(name): + return f"Hello {name}", gr.Group.update(visible=True) + + greet.click( + greeting, inputs=[name], outputs=[output, group], api_name="greeting" + ) + show_group.click( + lambda: gr.Group.update(visible=False), None, group, api_name="show_group" + ) + return demo + + +@pytest.fixture +def hello_world_with_state_and_accordion(): + with gr.Blocks() as demo: + with gr.Row(): + name = gr.Textbox(label="name") + output = gr.Textbox(label="greeting") + num = gr.Number(label="count") + with gr.Row(): + n_counts = gr.State(value=0) + greet = gr.Button("Greet") + open_acc = gr.Button("Open acc") + close_acc = gr.Button("Close acc") + with gr.Accordion(label="Extra stuff", open=False) as accordion: + gr.Textbox("Hello!") + + def greeting(name, state): + state += 1 + return state, f"Hello {name}", state, gr.Accordion.update(open=False) + + greet.click( + greeting, + inputs=[name, n_counts], + outputs=[n_counts, output, num, accordion], + api_name="greeting", + ) + open_acc.click( + lambda state: (state + 1, state + 1, gr.Accordion.update(open=True)), + [n_counts], + [n_counts, num, accordion], + api_name="open", + ) + close_acc.click( + lambda state: (state + 1, state + 1, gr.Accordion.update(open=False)), + [n_counts], + [n_counts, num, accordion], + api_name="close", + ) + return demo + + @pytest.fixture def all_components(): classes_to_check = gr.components.Component.__subclasses__() diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 76f8ccb327..3d3112af9c 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -368,6 +368,24 @@ class TestClientPredictions: assert client.predict("Hello!", api_name="/run") == "Hello!" assert client.predict("Freddy", api_name="/say_hello") == "hello" + def test_return_layout_component(self, hello_world_with_group): + with connect(hello_world_with_group) as demo: + assert demo.predict("Freddy", api_name="/greeting") == "Hello Freddy" + assert demo.predict(api_name="/show_group") == () + + def test_return_layout_and_state_components( + self, hello_world_with_state_and_accordion + ): + with connect(hello_world_with_state_and_accordion) as demo: + assert demo.predict("Freddy", api_name="/greeting") == ("Hello Freddy", 1) + assert demo.predict("Abubakar", api_name="/greeting") == ( + "Hello Abubakar", + 2, + ) + assert demo.predict(api_name="/open") == 3 + assert demo.predict(api_name="/close") == 4 + assert demo.predict("Ali", api_name="/greeting") == ("Hello Ali", 5) + class TestStatusUpdates: @patch("gradio_client.client.Endpoint.make_end_to_end_fn") @@ -820,6 +838,113 @@ class TestAPIInfo: "description": "filepath or URL to file", } + def test_layout_components_in_output(self, hello_world_with_group): + with connect(hello_world_with_group) as client: + info = client.view_api(return_format="dict") + assert info == { + "named_endpoints": { + "/greeting": { + "parameters": [ + { + "label": "name", + "type": {"type": "string"}, + "python_type": {"type": "str", "description": ""}, + "component": "Textbox", + "example_input": "Howdy!", + "serializer": "StringSerializable", + } + ], + "returns": [ + { + "label": "greeting", + "type": {"type": "string"}, + "python_type": {"type": "str", "description": ""}, + "component": "Textbox", + "serializer": "StringSerializable", + } + ], + }, + "/show_group": {"parameters": [], "returns": []}, + }, + "unnamed_endpoints": {}, + } + assert info["named_endpoints"]["/show_group"] == { + "parameters": [], + "returns": [], + } + + def test_layout_and_state_components_in_output( + self, hello_world_with_state_and_accordion + ): + with connect(hello_world_with_state_and_accordion) as client: + info = client.view_api(return_format="dict") + assert info == { + "named_endpoints": { + "/greeting": { + "parameters": [ + { + "label": "name", + "type": {"type": "string"}, + "python_type": {"type": "str", "description": ""}, + "component": "Textbox", + "example_input": "Howdy!", + "serializer": "StringSerializable", + } + ], + "returns": [ + { + "label": "greeting", + "type": {"type": "string"}, + "python_type": {"type": "str", "description": ""}, + "component": "Textbox", + "serializer": "StringSerializable", + }, + { + "label": "count", + "type": {"type": "number"}, + "python_type": { + "type": "int | float", + "description": "", + }, + "component": "Number", + "serializer": "NumberSerializable", + }, + ], + }, + "/open": { + "parameters": [], + "returns": [ + { + "label": "count", + "type": {"type": "number"}, + "python_type": { + "type": "int | float", + "description": "", + }, + "component": "Number", + "serializer": "NumberSerializable", + } + ], + }, + "/close": { + "parameters": [], + "returns": [ + { + "label": "count", + "type": {"type": "number"}, + "python_type": { + "type": "int | float", + "description": "", + }, + "component": "Number", + "serializer": "NumberSerializable", + } + ], + }, + }, + "unnamed_endpoints": {}, + } + class TestEndpoints: def test_upload(self): diff --git a/client/python/test/test_serializing.py b/client/python/test/test_serializing.py index 521a9b9c48..8bf4fb1044 100644 --- a/client/python/test/test_serializing.py +++ b/client/python/test/test_serializing.py @@ -5,7 +5,7 @@ import pytest from gradio import components from gradio_client.serializing import COMPONENT_MAPPING, FileSerializable, Serializable -from gradio_client.utils import encode_url_or_file_to_base64 +from gradio_client.utils import SKIP_COMPONENTS, encode_url_or_file_to_base64 @pytest.mark.parametrize("serializer_class", Serializable.__subclasses__()) @@ -22,7 +22,7 @@ def test_duplicate(serializer_class): def test_check_component_fallback_serializers(): for component_name, class_type in COMPONENT_MAPPING.items(): # skip components that cannot be instantiated without parameters - if component_name in ["dataset", "interpretation"]: + if component_name in SKIP_COMPONENTS: continue component = components.get_component_instance(component_name) assert isinstance(component, class_type) diff --git a/gradio/blocks.py b/gradio/blocks.py index cbf5547c69..f441d39d9b 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -503,7 +503,6 @@ def get_api_info(config: dict, serialize: bool = True): for d, dependency in enumerate(config["dependencies"]): dependency_info = {"parameters": [], "returns": []} skip_endpoint = False - skip_components = ["state"] inputs = dependency["inputs"] for i in inputs: @@ -514,13 +513,15 @@ def get_api_info(config: dict, serialize: bool = True): skip_endpoint = True # if component not found, skip endpoint break type = component["type"] + if type in client_utils.SKIP_COMPONENTS: + continue if ( not component.get("serializer") and type not in serializing.COMPONENT_MAPPING ): skip_endpoint = True # if component not serializable, skip endpoint break - if type in skip_components: + if type in client_utils.SKIP_COMPONENTS: continue label = component["props"].get("label", f"parameter_{i}") # The config has the most specific API info (taking into account the parameters @@ -568,14 +569,14 @@ def get_api_info(config: dict, serialize: bool = True): skip_endpoint = True # if component not found, skip endpoint break type = component["type"] + if type in client_utils.SKIP_COMPONENTS: + continue if ( not component.get("serializer") and type not in serializing.COMPONENT_MAPPING ): skip_endpoint = True # if component not serializable, skip endpoint break - if type in skip_components: - continue label = component["props"].get("label", f"value_{o}") serializer = serializing.COMPONENT_MAPPING[type]() if component.get("api_info") and after_new_format: diff --git a/requirements.txt b/requirements.txt index faa08fe030..ba3de6c9dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ aiohttp~=3.0 altair>=4.2.0,<6.0 fastapi ffmpy -gradio_client>=0.2.9 +gradio_client>=0.2.10 httpx huggingface_hub>=0.14.0 Jinja2<4.0