Client: Support endpoints that return layout components (#4871)

* Add code

* CHANGELOG

* Add code

* Use set

* Modify _setup_serializers instead

* Push up code

* Remove from serializing

* Add gradio_client changes

* Update requirements.txt
This commit is contained in:
Freddy Boulton 2023-07-17 12:21:54 -05:00 committed by GitHub
parent f2fd37ee59
commit 7b18891aae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 220 additions and 16 deletions

View File

@ -18,6 +18,7 @@ with gr.Blocks() as demo:
gr.Markdown("سلام", rtl=True)
demo.launch()
```
- The `get_api_info` method of `Blocks` now supports layout output components [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4871](https://github.com/gradio-app/gradio/pull/4871)
## Bug Fixes:

View File

@ -2,7 +2,8 @@
## New Features:
No changes to highlight
- Endpoints that return layout components are now properly handled in the `submit` and `view_api` methods. Output layout components are not returned by the API but all other components are (excluding `gr.State`). By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4871](https://github.com/gradio-app/gradio/pull/4871)
## Bug Fixes:

View File

@ -427,7 +427,7 @@ class Client:
# Versions of Gradio older than 3.29.0 returned format of the API info
# from the /info endpoint
if version.parse(self.config.get("version", "2.0")) > version.Version("3.29.0"):
if version.parse(self.config.get("version", "2.0")) > version.Version("3.36.1"):
r = requests.get(api_info_url, headers=self.headers)
if r.ok:
info = r.json()
@ -775,11 +775,11 @@ class Endpoint:
data.insert(i, None)
return tuple(data)
def remove_state(self, *data) -> tuple:
def remove_skipped_components(self, *data) -> tuple:
data = [
d
for d, oct in zip(data, self.output_component_types)
if oct != utils.STATE_COMPONENT
if oct not in utils.SKIP_COMPONENTS
]
return tuple(data)
@ -789,7 +789,7 @@ class Endpoint:
[
oct
for oct in self.output_component_types
if oct != utils.STATE_COMPONENT
if oct not in utils.SKIP_COMPONENTS
]
)
== 1
@ -834,7 +834,7 @@ class Endpoint:
def process_predictions(self, *predictions):
if self.client.serialize:
predictions = self.deserialize(*predictions)
predictions = self.remove_state(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
@ -873,6 +873,8 @@ class Endpoint:
serializer_name in serializing.SERIALIZER_MAPPING
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
elif component_name in utils.SKIP_COMPONENTS:
deserializer = serializing.SimpleSerializable
else:
assert (
component_name in serializing.COMPONENT_MAPPING

View File

@ -543,8 +543,6 @@ COMPONENT_MAPPING: dict[str, type] = {
"lineplot": JSONSerializable,
"scatterplot": JSONSerializable,
"markdown": StringSerializable,
"dataset": StringSerializable,
"code": StringSerializable,
"interpretation": SimpleSerializable,
"annotatedimage": JSONSerializable,
}

View File

@ -35,6 +35,20 @@ SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api"
RESET_URL = "reset"
SPACE_URL = "https://hf.space/{}"
SKIP_COMPONENTS = {
"state",
"row",
"column",
"tabs",
"tab",
"tabitem",
"box",
"form",
"accordion",
"group",
"interpretation",
"dataset",
}
STATE_COMPONENT = "state"
INVALID_RUNTIME = [
SpaceStage.NO_APP_FILE,

View File

@ -1 +1 @@
0.2.9
0.2.10

View File

@ -209,6 +209,68 @@ def stateful_chatbot():
return demo
@pytest.fixture
def hello_world_with_group():
with gr.Blocks() as demo:
name = gr.Textbox(label="name")
output = gr.Textbox(label="greeting")
greet = gr.Button("Greet")
show_group = gr.Button("Show group")
with gr.Group(visible=False) as group:
gr.Textbox("Hello!")
def greeting(name):
return f"Hello {name}", gr.Group.update(visible=True)
greet.click(
greeting, inputs=[name], outputs=[output, group], api_name="greeting"
)
show_group.click(
lambda: gr.Group.update(visible=False), None, group, api_name="show_group"
)
return demo
@pytest.fixture
def hello_world_with_state_and_accordion():
with gr.Blocks() as demo:
with gr.Row():
name = gr.Textbox(label="name")
output = gr.Textbox(label="greeting")
num = gr.Number(label="count")
with gr.Row():
n_counts = gr.State(value=0)
greet = gr.Button("Greet")
open_acc = gr.Button("Open acc")
close_acc = gr.Button("Close acc")
with gr.Accordion(label="Extra stuff", open=False) as accordion:
gr.Textbox("Hello!")
def greeting(name, state):
state += 1
return state, f"Hello {name}", state, gr.Accordion.update(open=False)
greet.click(
greeting,
inputs=[name, n_counts],
outputs=[n_counts, output, num, accordion],
api_name="greeting",
)
open_acc.click(
lambda state: (state + 1, state + 1, gr.Accordion.update(open=True)),
[n_counts],
[n_counts, num, accordion],
api_name="open",
)
close_acc.click(
lambda state: (state + 1, state + 1, gr.Accordion.update(open=False)),
[n_counts],
[n_counts, num, accordion],
api_name="close",
)
return demo
@pytest.fixture
def all_components():
classes_to_check = gr.components.Component.__subclasses__()

View File

@ -368,6 +368,24 @@ class TestClientPredictions:
assert client.predict("Hello!", api_name="/run") == "Hello!"
assert client.predict("Freddy", api_name="/say_hello") == "hello"
def test_return_layout_component(self, hello_world_with_group):
with connect(hello_world_with_group) as demo:
assert demo.predict("Freddy", api_name="/greeting") == "Hello Freddy"
assert demo.predict(api_name="/show_group") == ()
def test_return_layout_and_state_components(
self, hello_world_with_state_and_accordion
):
with connect(hello_world_with_state_and_accordion) as demo:
assert demo.predict("Freddy", api_name="/greeting") == ("Hello Freddy", 1)
assert demo.predict("Abubakar", api_name="/greeting") == (
"Hello Abubakar",
2,
)
assert demo.predict(api_name="/open") == 3
assert demo.predict(api_name="/close") == 4
assert demo.predict("Ali", api_name="/greeting") == ("Hello Ali", 5)
class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
@ -820,6 +838,113 @@ class TestAPIInfo:
"description": "filepath or URL to file",
}
def test_layout_components_in_output(self, hello_world_with_group):
with connect(hello_world_with_group) as client:
info = client.view_api(return_format="dict")
assert info == {
"named_endpoints": {
"/greeting": {
"parameters": [
{
"label": "name",
"type": {"type": "string"},
"python_type": {"type": "str", "description": ""},
"component": "Textbox",
"example_input": "Howdy!",
"serializer": "StringSerializable",
}
],
"returns": [
{
"label": "greeting",
"type": {"type": "string"},
"python_type": {"type": "str", "description": ""},
"component": "Textbox",
"serializer": "StringSerializable",
}
],
},
"/show_group": {"parameters": [], "returns": []},
},
"unnamed_endpoints": {},
}
assert info["named_endpoints"]["/show_group"] == {
"parameters": [],
"returns": [],
}
def test_layout_and_state_components_in_output(
self, hello_world_with_state_and_accordion
):
with connect(hello_world_with_state_and_accordion) as client:
info = client.view_api(return_format="dict")
assert info == {
"named_endpoints": {
"/greeting": {
"parameters": [
{
"label": "name",
"type": {"type": "string"},
"python_type": {"type": "str", "description": ""},
"component": "Textbox",
"example_input": "Howdy!",
"serializer": "StringSerializable",
}
],
"returns": [
{
"label": "greeting",
"type": {"type": "string"},
"python_type": {"type": "str", "description": ""},
"component": "Textbox",
"serializer": "StringSerializable",
},
{
"label": "count",
"type": {"type": "number"},
"python_type": {
"type": "int | float",
"description": "",
},
"component": "Number",
"serializer": "NumberSerializable",
},
],
},
"/open": {
"parameters": [],
"returns": [
{
"label": "count",
"type": {"type": "number"},
"python_type": {
"type": "int | float",
"description": "",
},
"component": "Number",
"serializer": "NumberSerializable",
}
],
},
"/close": {
"parameters": [],
"returns": [
{
"label": "count",
"type": {"type": "number"},
"python_type": {
"type": "int | float",
"description": "",
},
"component": "Number",
"serializer": "NumberSerializable",
}
],
},
},
"unnamed_endpoints": {},
}
class TestEndpoints:
def test_upload(self):

View File

@ -5,7 +5,7 @@ import pytest
from gradio import components
from gradio_client.serializing import COMPONENT_MAPPING, FileSerializable, Serializable
from gradio_client.utils import encode_url_or_file_to_base64
from gradio_client.utils import SKIP_COMPONENTS, encode_url_or_file_to_base64
@pytest.mark.parametrize("serializer_class", Serializable.__subclasses__())
@ -22,7 +22,7 @@ def test_duplicate(serializer_class):
def test_check_component_fallback_serializers():
for component_name, class_type in COMPONENT_MAPPING.items():
# skip components that cannot be instantiated without parameters
if component_name in ["dataset", "interpretation"]:
if component_name in SKIP_COMPONENTS:
continue
component = components.get_component_instance(component_name)
assert isinstance(component, class_type)

View File

@ -503,7 +503,6 @@ def get_api_info(config: dict, serialize: bool = True):
for d, dependency in enumerate(config["dependencies"]):
dependency_info = {"parameters": [], "returns": []}
skip_endpoint = False
skip_components = ["state"]
inputs = dependency["inputs"]
for i in inputs:
@ -514,13 +513,15 @@ def get_api_info(config: dict, serialize: bool = True):
skip_endpoint = True # if component not found, skip endpoint
break
type = component["type"]
if type in client_utils.SKIP_COMPONENTS:
continue
if (
not component.get("serializer")
and type not in serializing.COMPONENT_MAPPING
):
skip_endpoint = True # if component not serializable, skip endpoint
break
if type in skip_components:
if type in client_utils.SKIP_COMPONENTS:
continue
label = component["props"].get("label", f"parameter_{i}")
# The config has the most specific API info (taking into account the parameters
@ -568,14 +569,14 @@ def get_api_info(config: dict, serialize: bool = True):
skip_endpoint = True # if component not found, skip endpoint
break
type = component["type"]
if type in client_utils.SKIP_COMPONENTS:
continue
if (
not component.get("serializer")
and type not in serializing.COMPONENT_MAPPING
):
skip_endpoint = True # if component not serializable, skip endpoint
break
if type in skip_components:
continue
label = component["props"].get("label", f"value_{o}")
serializer = serializing.COMPONENT_MAPPING[type]()
if component.get("api_info") and after_new_format:

View File

@ -3,7 +3,7 @@ aiohttp~=3.0
altair>=4.2.0,<6.0
fastapi
ffmpy
gradio_client>=0.2.9
gradio_client>=0.2.10
httpx
huggingface_hub>=0.14.0
Jinja2<4.0