mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
Client: Support endpoints that return layout components (#4871)
* Add code * CHANGELOG * Add code * Use set * Modify _setup_serializers instead * Push up code * Remove from serializing * Add gradio_client changes * Update requirements.txt
This commit is contained in:
parent
f2fd37ee59
commit
7b18891aae
@ -18,6 +18,7 @@ with gr.Blocks() as demo:
|
||||
gr.Markdown("سلام", rtl=True)
|
||||
demo.launch()
|
||||
```
|
||||
- The `get_api_info` method of `Blocks` now supports layout output components [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4871](https://github.com/gradio-app/gradio/pull/4871)
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
|
@ -2,7 +2,8 @@
|
||||
|
||||
## New Features:
|
||||
|
||||
No changes to highlight
|
||||
- Endpoints that return layout components are now properly handled in the `submit` and `view_api` methods. Output layout components are not returned by the API but all other components are (excluding `gr.State`). By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4871](https://github.com/gradio-app/gradio/pull/4871)
|
||||
|
||||
|
||||
## Bug Fixes:
|
||||
|
||||
|
@ -427,7 +427,7 @@ class Client:
|
||||
|
||||
# Versions of Gradio older than 3.29.0 returned format of the API info
|
||||
# from the /info endpoint
|
||||
if version.parse(self.config.get("version", "2.0")) > version.Version("3.29.0"):
|
||||
if version.parse(self.config.get("version", "2.0")) > version.Version("3.36.1"):
|
||||
r = requests.get(api_info_url, headers=self.headers)
|
||||
if r.ok:
|
||||
info = r.json()
|
||||
@ -775,11 +775,11 @@ class Endpoint:
|
||||
data.insert(i, None)
|
||||
return tuple(data)
|
||||
|
||||
def remove_state(self, *data) -> tuple:
|
||||
def remove_skipped_components(self, *data) -> tuple:
|
||||
data = [
|
||||
d
|
||||
for d, oct in zip(data, self.output_component_types)
|
||||
if oct != utils.STATE_COMPONENT
|
||||
if oct not in utils.SKIP_COMPONENTS
|
||||
]
|
||||
return tuple(data)
|
||||
|
||||
@ -789,7 +789,7 @@ class Endpoint:
|
||||
[
|
||||
oct
|
||||
for oct in self.output_component_types
|
||||
if oct != utils.STATE_COMPONENT
|
||||
if oct not in utils.SKIP_COMPONENTS
|
||||
]
|
||||
)
|
||||
== 1
|
||||
@ -834,7 +834,7 @@ class Endpoint:
|
||||
def process_predictions(self, *predictions):
|
||||
if self.client.serialize:
|
||||
predictions = self.deserialize(*predictions)
|
||||
predictions = self.remove_state(*predictions)
|
||||
predictions = self.remove_skipped_components(*predictions)
|
||||
predictions = self.reduce_singleton_output(*predictions)
|
||||
return predictions
|
||||
|
||||
@ -873,6 +873,8 @@ class Endpoint:
|
||||
serializer_name in serializing.SERIALIZER_MAPPING
|
||||
), f"Unknown serializer: {serializer_name}, you may need to update your gradio_client version."
|
||||
deserializer = serializing.SERIALIZER_MAPPING[serializer_name]
|
||||
elif component_name in utils.SKIP_COMPONENTS:
|
||||
deserializer = serializing.SimpleSerializable
|
||||
else:
|
||||
assert (
|
||||
component_name in serializing.COMPONENT_MAPPING
|
||||
|
@ -543,8 +543,6 @@ COMPONENT_MAPPING: dict[str, type] = {
|
||||
"lineplot": JSONSerializable,
|
||||
"scatterplot": JSONSerializable,
|
||||
"markdown": StringSerializable,
|
||||
"dataset": StringSerializable,
|
||||
"code": StringSerializable,
|
||||
"interpretation": SimpleSerializable,
|
||||
"annotatedimage": JSONSerializable,
|
||||
}
|
||||
|
@ -35,6 +35,20 @@ SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api"
|
||||
RESET_URL = "reset"
|
||||
SPACE_URL = "https://hf.space/{}"
|
||||
|
||||
SKIP_COMPONENTS = {
|
||||
"state",
|
||||
"row",
|
||||
"column",
|
||||
"tabs",
|
||||
"tab",
|
||||
"tabitem",
|
||||
"box",
|
||||
"form",
|
||||
"accordion",
|
||||
"group",
|
||||
"interpretation",
|
||||
"dataset",
|
||||
}
|
||||
STATE_COMPONENT = "state"
|
||||
INVALID_RUNTIME = [
|
||||
SpaceStage.NO_APP_FILE,
|
||||
|
@ -1 +1 @@
|
||||
0.2.9
|
||||
0.2.10
|
||||
|
@ -209,6 +209,68 @@ def stateful_chatbot():
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_with_group():
|
||||
with gr.Blocks() as demo:
|
||||
name = gr.Textbox(label="name")
|
||||
output = gr.Textbox(label="greeting")
|
||||
greet = gr.Button("Greet")
|
||||
show_group = gr.Button("Show group")
|
||||
with gr.Group(visible=False) as group:
|
||||
gr.Textbox("Hello!")
|
||||
|
||||
def greeting(name):
|
||||
return f"Hello {name}", gr.Group.update(visible=True)
|
||||
|
||||
greet.click(
|
||||
greeting, inputs=[name], outputs=[output, group], api_name="greeting"
|
||||
)
|
||||
show_group.click(
|
||||
lambda: gr.Group.update(visible=False), None, group, api_name="show_group"
|
||||
)
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_with_state_and_accordion():
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
name = gr.Textbox(label="name")
|
||||
output = gr.Textbox(label="greeting")
|
||||
num = gr.Number(label="count")
|
||||
with gr.Row():
|
||||
n_counts = gr.State(value=0)
|
||||
greet = gr.Button("Greet")
|
||||
open_acc = gr.Button("Open acc")
|
||||
close_acc = gr.Button("Close acc")
|
||||
with gr.Accordion(label="Extra stuff", open=False) as accordion:
|
||||
gr.Textbox("Hello!")
|
||||
|
||||
def greeting(name, state):
|
||||
state += 1
|
||||
return state, f"Hello {name}", state, gr.Accordion.update(open=False)
|
||||
|
||||
greet.click(
|
||||
greeting,
|
||||
inputs=[name, n_counts],
|
||||
outputs=[n_counts, output, num, accordion],
|
||||
api_name="greeting",
|
||||
)
|
||||
open_acc.click(
|
||||
lambda state: (state + 1, state + 1, gr.Accordion.update(open=True)),
|
||||
[n_counts],
|
||||
[n_counts, num, accordion],
|
||||
api_name="open",
|
||||
)
|
||||
close_acc.click(
|
||||
lambda state: (state + 1, state + 1, gr.Accordion.update(open=False)),
|
||||
[n_counts],
|
||||
[n_counts, num, accordion],
|
||||
api_name="close",
|
||||
)
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def all_components():
|
||||
classes_to_check = gr.components.Component.__subclasses__()
|
||||
|
@ -368,6 +368,24 @@ class TestClientPredictions:
|
||||
assert client.predict("Hello!", api_name="/run") == "Hello!"
|
||||
assert client.predict("Freddy", api_name="/say_hello") == "hello"
|
||||
|
||||
def test_return_layout_component(self, hello_world_with_group):
|
||||
with connect(hello_world_with_group) as demo:
|
||||
assert demo.predict("Freddy", api_name="/greeting") == "Hello Freddy"
|
||||
assert demo.predict(api_name="/show_group") == ()
|
||||
|
||||
def test_return_layout_and_state_components(
|
||||
self, hello_world_with_state_and_accordion
|
||||
):
|
||||
with connect(hello_world_with_state_and_accordion) as demo:
|
||||
assert demo.predict("Freddy", api_name="/greeting") == ("Hello Freddy", 1)
|
||||
assert demo.predict("Abubakar", api_name="/greeting") == (
|
||||
"Hello Abubakar",
|
||||
2,
|
||||
)
|
||||
assert demo.predict(api_name="/open") == 3
|
||||
assert demo.predict(api_name="/close") == 4
|
||||
assert demo.predict("Ali", api_name="/greeting") == ("Hello Ali", 5)
|
||||
|
||||
|
||||
class TestStatusUpdates:
|
||||
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
|
||||
@ -820,6 +838,113 @@ class TestAPIInfo:
|
||||
"description": "filepath or URL to file",
|
||||
}
|
||||
|
||||
def test_layout_components_in_output(self, hello_world_with_group):
|
||||
with connect(hello_world_with_group) as client:
|
||||
info = client.view_api(return_format="dict")
|
||||
assert info == {
|
||||
"named_endpoints": {
|
||||
"/greeting": {
|
||||
"parameters": [
|
||||
{
|
||||
"label": "name",
|
||||
"type": {"type": "string"},
|
||||
"python_type": {"type": "str", "description": ""},
|
||||
"component": "Textbox",
|
||||
"example_input": "Howdy!",
|
||||
"serializer": "StringSerializable",
|
||||
}
|
||||
],
|
||||
"returns": [
|
||||
{
|
||||
"label": "greeting",
|
||||
"type": {"type": "string"},
|
||||
"python_type": {"type": "str", "description": ""},
|
||||
"component": "Textbox",
|
||||
"serializer": "StringSerializable",
|
||||
}
|
||||
],
|
||||
},
|
||||
"/show_group": {"parameters": [], "returns": []},
|
||||
},
|
||||
"unnamed_endpoints": {},
|
||||
}
|
||||
assert info["named_endpoints"]["/show_group"] == {
|
||||
"parameters": [],
|
||||
"returns": [],
|
||||
}
|
||||
|
||||
def test_layout_and_state_components_in_output(
|
||||
self, hello_world_with_state_and_accordion
|
||||
):
|
||||
with connect(hello_world_with_state_and_accordion) as client:
|
||||
info = client.view_api(return_format="dict")
|
||||
assert info == {
|
||||
"named_endpoints": {
|
||||
"/greeting": {
|
||||
"parameters": [
|
||||
{
|
||||
"label": "name",
|
||||
"type": {"type": "string"},
|
||||
"python_type": {"type": "str", "description": ""},
|
||||
"component": "Textbox",
|
||||
"example_input": "Howdy!",
|
||||
"serializer": "StringSerializable",
|
||||
}
|
||||
],
|
||||
"returns": [
|
||||
{
|
||||
"label": "greeting",
|
||||
"type": {"type": "string"},
|
||||
"python_type": {"type": "str", "description": ""},
|
||||
"component": "Textbox",
|
||||
"serializer": "StringSerializable",
|
||||
},
|
||||
{
|
||||
"label": "count",
|
||||
"type": {"type": "number"},
|
||||
"python_type": {
|
||||
"type": "int | float",
|
||||
"description": "",
|
||||
},
|
||||
"component": "Number",
|
||||
"serializer": "NumberSerializable",
|
||||
},
|
||||
],
|
||||
},
|
||||
"/open": {
|
||||
"parameters": [],
|
||||
"returns": [
|
||||
{
|
||||
"label": "count",
|
||||
"type": {"type": "number"},
|
||||
"python_type": {
|
||||
"type": "int | float",
|
||||
"description": "",
|
||||
},
|
||||
"component": "Number",
|
||||
"serializer": "NumberSerializable",
|
||||
}
|
||||
],
|
||||
},
|
||||
"/close": {
|
||||
"parameters": [],
|
||||
"returns": [
|
||||
{
|
||||
"label": "count",
|
||||
"type": {"type": "number"},
|
||||
"python_type": {
|
||||
"type": "int | float",
|
||||
"description": "",
|
||||
},
|
||||
"component": "Number",
|
||||
"serializer": "NumberSerializable",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
"unnamed_endpoints": {},
|
||||
}
|
||||
|
||||
|
||||
class TestEndpoints:
|
||||
def test_upload(self):
|
||||
|
@ -5,7 +5,7 @@ import pytest
|
||||
from gradio import components
|
||||
|
||||
from gradio_client.serializing import COMPONENT_MAPPING, FileSerializable, Serializable
|
||||
from gradio_client.utils import encode_url_or_file_to_base64
|
||||
from gradio_client.utils import SKIP_COMPONENTS, encode_url_or_file_to_base64
|
||||
|
||||
|
||||
@pytest.mark.parametrize("serializer_class", Serializable.__subclasses__())
|
||||
@ -22,7 +22,7 @@ def test_duplicate(serializer_class):
|
||||
def test_check_component_fallback_serializers():
|
||||
for component_name, class_type in COMPONENT_MAPPING.items():
|
||||
# skip components that cannot be instantiated without parameters
|
||||
if component_name in ["dataset", "interpretation"]:
|
||||
if component_name in SKIP_COMPONENTS:
|
||||
continue
|
||||
component = components.get_component_instance(component_name)
|
||||
assert isinstance(component, class_type)
|
||||
|
@ -503,7 +503,6 @@ def get_api_info(config: dict, serialize: bool = True):
|
||||
for d, dependency in enumerate(config["dependencies"]):
|
||||
dependency_info = {"parameters": [], "returns": []}
|
||||
skip_endpoint = False
|
||||
skip_components = ["state"]
|
||||
|
||||
inputs = dependency["inputs"]
|
||||
for i in inputs:
|
||||
@ -514,13 +513,15 @@ def get_api_info(config: dict, serialize: bool = True):
|
||||
skip_endpoint = True # if component not found, skip endpoint
|
||||
break
|
||||
type = component["type"]
|
||||
if type in client_utils.SKIP_COMPONENTS:
|
||||
continue
|
||||
if (
|
||||
not component.get("serializer")
|
||||
and type not in serializing.COMPONENT_MAPPING
|
||||
):
|
||||
skip_endpoint = True # if component not serializable, skip endpoint
|
||||
break
|
||||
if type in skip_components:
|
||||
if type in client_utils.SKIP_COMPONENTS:
|
||||
continue
|
||||
label = component["props"].get("label", f"parameter_{i}")
|
||||
# The config has the most specific API info (taking into account the parameters
|
||||
@ -568,14 +569,14 @@ def get_api_info(config: dict, serialize: bool = True):
|
||||
skip_endpoint = True # if component not found, skip endpoint
|
||||
break
|
||||
type = component["type"]
|
||||
if type in client_utils.SKIP_COMPONENTS:
|
||||
continue
|
||||
if (
|
||||
not component.get("serializer")
|
||||
and type not in serializing.COMPONENT_MAPPING
|
||||
):
|
||||
skip_endpoint = True # if component not serializable, skip endpoint
|
||||
break
|
||||
if type in skip_components:
|
||||
continue
|
||||
label = component["props"].get("label", f"value_{o}")
|
||||
serializer = serializing.COMPONENT_MAPPING[type]()
|
||||
if component.get("api_info") and after_new_format:
|
||||
|
@ -3,7 +3,7 @@ aiohttp~=3.0
|
||||
altair>=4.2.0,<6.0
|
||||
fastapi
|
||||
ffmpy
|
||||
gradio_client>=0.2.9
|
||||
gradio_client>=0.2.10
|
||||
httpx
|
||||
huggingface_hub>=0.14.0
|
||||
Jinja2<4.0
|
||||
|
Loading…
Reference in New Issue
Block a user