mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Documentation-related fixes to the python client (#3663)
* docstring * add documentation * added more serialization classes * format * info * is valid * formatting * changes * fixups * fix tests * machine readable * formatting * client * format * tweaks on printing * version * linting * fix tests * update pypi requirements * updates * type ignore' * fixes * formatting
This commit is contained in:
parent
2f160e2b90
commit
2a8c82de01
@ -9,7 +9,7 @@ Here's the entire code to do it:
|
||||
```python
|
||||
import gradio_client as grc
|
||||
|
||||
client = grc.Client(space="stability-ai/stable-diffusion")
|
||||
client = grc.Client("stability-ai/stable-diffusion")
|
||||
job = client.predict("a hyperrealistic portrait of a cat wearing cyberpunk armor")
|
||||
job.result()
|
||||
|
||||
@ -39,7 +39,7 @@ that is running on Spaces (or anywhere else)!
|
||||
```python
|
||||
import gradio_client as grc
|
||||
|
||||
client = grc.Client(space="abidlabs/en2fr")
|
||||
client = grc.Client("abidlabs/en2fr")
|
||||
```
|
||||
|
||||
**Connecting a general Gradio app**
|
||||
|
@ -22,11 +22,16 @@ from gradio_client.serializing import Serializable
|
||||
class Client:
|
||||
def __init__(
|
||||
self,
|
||||
space: str | None = None,
|
||||
src: str | None = None,
|
||||
src: str,
|
||||
hf_token: str | None = None,
|
||||
max_workers: int = 40,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/pictionary") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
|
||||
hf_token: The Hugging Face token to use to access private Spaces. If not provided, only public Spaces can be loaded.
|
||||
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
|
||||
"""
|
||||
self.hf_token = hf_token
|
||||
self.headers = build_hf_headers(
|
||||
token=hf_token,
|
||||
@ -34,17 +39,15 @@ class Client:
|
||||
library_version=utils.__version__,
|
||||
)
|
||||
|
||||
if space is None and src is None:
|
||||
raise ValueError("Either `space` or `src` must be provided")
|
||||
elif space and src:
|
||||
raise ValueError("Only one of `space` or `src` should be provided")
|
||||
self.src = src or self._space_name_to_src(space)
|
||||
if self.src is None:
|
||||
raise ValueError(
|
||||
f"Could not find Space: {space}. If it is a private Space, please provide an hf_token."
|
||||
)
|
||||
if src.startswith("http://") or src.startswith("https://"):
|
||||
self.src = src
|
||||
else:
|
||||
print(f"Loaded as API: {self.src} ✔")
|
||||
self.src = self._space_name_to_src(src)
|
||||
if self.src is None:
|
||||
raise ValueError(
|
||||
f"Could not find Space: {src}. If it is a private Space, please provide an hf_token."
|
||||
)
|
||||
print(f"Loaded as API: {self.src} ✔")
|
||||
|
||||
self.api_url = utils.API_URL.format(self.src)
|
||||
self.ws_url = utils.WS_URL.format(self.src).replace("http", "ws", 1)
|
||||
@ -68,6 +71,15 @@ class Client:
|
||||
fn_index: int = 0,
|
||||
result_callbacks: Callable | List[Callable] | None = None,
|
||||
) -> Future:
|
||||
"""
|
||||
Parameters:
|
||||
*args: The arguments to pass to the remote API. The order of the arguments must match the order of the inputs in the Gradio app.
|
||||
api_name: The name of the API endpoint to call. If not provided, the first API will be called. Takes precedence over fn_index.
|
||||
fn_index: The index of the API endpoint to call. If not provided, the first API will be called.
|
||||
result_callbacks: A callback function, or list of callback functions, to be called when the result is ready. If a list of functions is provided, they will be called in order. The return values from the remote API are provided as separate parameters into the callback. If None, no callback will be called.
|
||||
Returns:
|
||||
A Job object that can be used to retrieve the status and result of the remote API call.
|
||||
"""
|
||||
if api_name:
|
||||
fn_index = self._infer_fn_index(api_name)
|
||||
|
||||
@ -93,6 +105,110 @@ class Client:
|
||||
|
||||
return job
|
||||
|
||||
def view_api(
|
||||
self,
|
||||
all_endpoints: bool | None = None,
|
||||
return_info: bool = False,
|
||||
) -> Dict | None:
|
||||
"""
|
||||
Prints the usage info for the API. If the Gradio app has multiple API endpoints, the usage info for each endpoint will be printed separately.
|
||||
Parameters:
|
||||
all_endpoints: If True, prints information for both named and unnamed endpoints in the Gradio app. If False, will only print info about named endpoints. If None (default), will only print info about unnamed endpoints if there are no named endpoints.
|
||||
return_info: If False (default), prints the usage info to the console. If True, returns the usage info as a dictionary that can be programmatically parsed (does not print), and *all endpoints are returned in the dictionary* regardless of the value of `all_endpoints`. The format of the dictionary is in the docstring of this method.
|
||||
Dictionary format:
|
||||
{
|
||||
"named_endpoints": {
|
||||
"endpoint_1_name": {
|
||||
"parameters": {
|
||||
"parameter_1_name": ["python type", "description", "component_type"],
|
||||
"parameter_2_name": ["python type", "description", "component_type"],
|
||||
},
|
||||
"returns": {
|
||||
"value_1_name": ["python type", "description", "component_type"],
|
||||
}
|
||||
...
|
||||
"unnamed_endpoints": {
|
||||
"fn_index_1": {
|
||||
...
|
||||
}
|
||||
...
|
||||
}
|
||||
"""
|
||||
info: Dict[str, Dict[str | int, Dict[str, Dict[str, List[str]]]]] = {
|
||||
"named_endpoints": {},
|
||||
"unnamed_endpoints": {},
|
||||
}
|
||||
|
||||
for endpoint in self.endpoints:
|
||||
if endpoint.is_valid:
|
||||
if endpoint.api_name:
|
||||
info["named_endpoints"][endpoint.api_name] = endpoint.get_info()
|
||||
else:
|
||||
info["unnamed_endpoints"][endpoint.fn_index] = endpoint.get_info()
|
||||
|
||||
if return_info:
|
||||
return info
|
||||
|
||||
num_named_endpoints = len(info["named_endpoints"])
|
||||
num_unnamed_endpoints = len(info["unnamed_endpoints"])
|
||||
if num_named_endpoints == 0 and all_endpoints is None:
|
||||
all_endpoints = True
|
||||
|
||||
human_info = "Client.predict() Usage Info\n---------------------------\n"
|
||||
human_info += f"Named API endpoints: {num_named_endpoints}\n"
|
||||
|
||||
for api_name, endpoint_info in info["named_endpoints"].items():
|
||||
human_info += self._render_endpoints_info(api_name, endpoint_info)
|
||||
|
||||
if all_endpoints:
|
||||
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}\n"
|
||||
for fn_index, endpoint_info in info["unnamed_endpoints"].items():
|
||||
human_info += self._render_endpoints_info(fn_index, endpoint_info)
|
||||
else:
|
||||
if num_unnamed_endpoints > 0:
|
||||
human_info += f"\nUnnamed API endpoints: {num_unnamed_endpoints}, to view, run Client.view_api(`all_endpoints=True`)\n"
|
||||
|
||||
print(human_info)
|
||||
|
||||
def _render_endpoints_info(
|
||||
self,
|
||||
name_or_index: str | int,
|
||||
endpoints_info: Dict[str, Dict[str, List[str]]],
|
||||
) -> str:
|
||||
parameter_names = list(endpoints_info["parameters"].keys())
|
||||
rendered_parameters = ", ".join(parameter_names)
|
||||
if rendered_parameters:
|
||||
rendered_parameters = rendered_parameters + ", "
|
||||
return_value_names = list(endpoints_info["returns"].keys())
|
||||
rendered_return_values = ", ".join(return_value_names)
|
||||
if len(return_value_names) > 1:
|
||||
rendered_return_values = f"({rendered_return_values})"
|
||||
|
||||
if isinstance(name_or_index, str):
|
||||
final_param = f'api_name="{name_or_index}"'
|
||||
elif isinstance(name_or_index, int):
|
||||
final_param = f"fn_index={name_or_index}"
|
||||
else:
|
||||
raise ValueError("name_or_index must be a string or integer")
|
||||
|
||||
human_info = f"\n - predict({rendered_parameters}{final_param}) -> {rendered_return_values}\n"
|
||||
if endpoints_info["parameters"]:
|
||||
human_info += " Parameters:\n"
|
||||
for label, info in endpoints_info["parameters"].items():
|
||||
human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n"
|
||||
if endpoints_info["returns"]:
|
||||
human_info += " Returns:\n"
|
||||
for label, info in endpoints_info["returns"].items():
|
||||
human_info += f" - [{info[2]}] {label}: {info[0]} ({info[1]})\n"
|
||||
|
||||
return human_info
|
||||
|
||||
def __repr__(self):
|
||||
return self.view_api()
|
||||
|
||||
def __str__(self):
|
||||
return self.view_api()
|
||||
|
||||
def _telemetry_thread(self) -> None:
|
||||
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
|
||||
data = {
|
||||
@ -132,7 +248,7 @@ class Client:
|
||||
raise ValueError(f"Could not get Gradio config from: {self.src}")
|
||||
if "allow_flagging" in config:
|
||||
raise ValueError(
|
||||
"Gradio 2.x is not supported by this client. Please upgrade this app to Gradio 3.x."
|
||||
"Gradio 2.x is not supported by this client. Please upgrade your Gradio app to Gradio 3.x or higher."
|
||||
)
|
||||
return config
|
||||
|
||||
@ -145,6 +261,7 @@ class Endpoint:
|
||||
self.ws_url = client.ws_url
|
||||
self.fn_index = fn_index
|
||||
self.dependency = dependency
|
||||
self.api_name: str | None = dependency.get("api_name")
|
||||
self.headers = client.headers
|
||||
self.config = client.config
|
||||
self.use_ws = self._use_websocket(self.dependency)
|
||||
@ -153,10 +270,62 @@ class Endpoint:
|
||||
self.serializers, self.deserializers = self._setup_serializers()
|
||||
self.is_valid = self.dependency[
|
||||
"backend_fn"
|
||||
] # Only a real API endpoint if backend_fn is True
|
||||
] # Only a real API endpoint if backend_fn is True and serializers are valid
|
||||
except AssertionError:
|
||||
self.is_valid = False
|
||||
|
||||
def get_info(self) -> Dict[str, Dict[str, List[str]]]:
|
||||
"""
|
||||
Dictionary format:
|
||||
{
|
||||
"parameters": {
|
||||
"parameter_1_name": ["type", "description", "component_type"],
|
||||
"parameter_2_name": ["type", "description", "component_type"],
|
||||
...
|
||||
},
|
||||
"returns": {
|
||||
"value_1_name": ["type", "description", "component_type"],
|
||||
...
|
||||
}
|
||||
}
|
||||
"""
|
||||
parameters = {}
|
||||
for i, input in enumerate(self.dependency["inputs"]):
|
||||
for component in self.config["components"]:
|
||||
if component["id"] == input:
|
||||
label = (
|
||||
component["props"]
|
||||
.get("label", f"parameter_{i}")
|
||||
.lower()
|
||||
.replace(" ", "_")
|
||||
)
|
||||
if "info" in component:
|
||||
info = component["info"]["input"]
|
||||
else:
|
||||
info = self.serializers[i].input_api_info()
|
||||
info = list(info)
|
||||
info.append(component.get("type", "component").capitalize())
|
||||
parameters[label] = info
|
||||
returns = {}
|
||||
for o, output in enumerate(self.dependency["outputs"]):
|
||||
for component in self.config["components"]:
|
||||
if component["id"] == output:
|
||||
label = (
|
||||
component["props"]
|
||||
.get("label", f"value_{o}")
|
||||
.lower()
|
||||
.replace(" ", "_")
|
||||
)
|
||||
if "info" in component:
|
||||
info = component["info"]["output"]
|
||||
else:
|
||||
info = self.deserializers[o].output_api_info()
|
||||
info = list(info)
|
||||
info.append(component.get("type", "component").capitalize())
|
||||
returns[label] = list(info)
|
||||
|
||||
return {"parameters": parameters, "returns": returns}
|
||||
|
||||
def end_to_end_fn(self, *data):
|
||||
if not self.is_valid:
|
||||
raise utils.InvalidAPIEndpointError()
|
||||
|
@ -5,20 +5,32 @@ import os
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from gradio_client import utils
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
@abstractmethod
|
||||
def serialize(self, x: Any, load_dir: str | Path = ""):
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Convert data from human-readable format to serialized format for a browser.
|
||||
Get the type of input that should be provided via API, and a human-readable description of the input as a tuple (for documentation generation).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
"""
|
||||
Get the type of output that should be returned via API, and a human-readable description of the output as a tuple (for documentation generation).
|
||||
"""
|
||||
pass
|
||||
|
||||
def serialize(self, x: Any, load_dir: str | Path = ""):
|
||||
"""
|
||||
Convert data from human-readable format to serialized format for a browser.
|
||||
"""
|
||||
return x
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
x: Any,
|
||||
@ -29,38 +41,68 @@ class Serializable(ABC):
|
||||
"""
|
||||
Convert data from serialized format for a browser to human-readable format.
|
||||
"""
|
||||
pass
|
||||
return x
|
||||
|
||||
|
||||
class SimpleSerializable(Serializable):
|
||||
def serialize(self, x: Any, load_dir: str | Path = "") -> Any:
|
||||
"""
|
||||
Convert data from human-readable format to serialized format. For SimpleSerializable components, this is a no-op.
|
||||
Parameters:
|
||||
x: Input data to serialize
|
||||
load_dir: Ignored
|
||||
"""
|
||||
return x
|
||||
"""General class that does not perform any serialization or deserialization."""
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
x: Any,
|
||||
save_dir: str | Path | None = None,
|
||||
root_url: str | None = None,
|
||||
hf_token: str | None = None,
|
||||
):
|
||||
"""
|
||||
Convert data from serialized format to human-readable format. For SimpleSerializable components, this is a no-op.
|
||||
Parameters:
|
||||
x: Input data to deserialize
|
||||
save_dir: Ignored
|
||||
root_url: Ignored
|
||||
hf_token: Ignored
|
||||
"""
|
||||
return x
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "Any", ""
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "Any", ""
|
||||
|
||||
|
||||
class StringSerializable(Serializable):
|
||||
"""Expects a string as input/output but performs no serialization."""
|
||||
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "value"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "value"
|
||||
|
||||
|
||||
class ListStringSerializable(Serializable):
|
||||
"""Expects a list of strings as input/output but performs no serialization."""
|
||||
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "List[str]", "values"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "List[str]", "values"
|
||||
|
||||
|
||||
class BooleanSerializable(Serializable):
|
||||
"""Expects a boolean as input/output but performs no serialization."""
|
||||
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "bool", "value"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "bool", "value"
|
||||
|
||||
|
||||
class NumberSerializable(Serializable):
|
||||
"""Expects a number (int/float) as input/output but performs no serialization."""
|
||||
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "int | float", "value"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "int | float", "value"
|
||||
|
||||
|
||||
class ImgSerializable(Serializable):
|
||||
"""Expects a base64 string as input/output which is ."""
|
||||
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "filepath or URL"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "filepath or URL"
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
x: str | None,
|
||||
@ -102,6 +144,12 @@ class ImgSerializable(Serializable):
|
||||
|
||||
|
||||
class FileSerializable(Serializable):
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "filepath or URL"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "filepath or URL"
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
x: str | None,
|
||||
@ -168,6 +216,12 @@ class FileSerializable(Serializable):
|
||||
|
||||
|
||||
class JSONSerializable(Serializable):
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "filepath to json file"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "filepath to json file"
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
x: str | None,
|
||||
@ -206,6 +260,12 @@ class JSONSerializable(Serializable):
|
||||
|
||||
|
||||
class GallerySerializable(Serializable):
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "path to directory with images and captions.json"
|
||||
|
||||
def output_api_info(self) -> Tuple[str, str]:
|
||||
return "str", "path to directory with images and captions.json"
|
||||
|
||||
def serialize(
|
||||
self, x: str | None, load_dir: str | Path = ""
|
||||
) -> List[List[str]] | None:
|
||||
@ -249,13 +309,13 @@ class GallerySerializable(Serializable):
|
||||
|
||||
SERIALIZER_MAPPING = {cls.__name__: cls for cls in Serializable.__subclasses__()}
|
||||
|
||||
COMPONENT_MAPPING = {
|
||||
"textbox": SimpleSerializable,
|
||||
"number": SimpleSerializable,
|
||||
"slider": SimpleSerializable,
|
||||
"checkbox": SimpleSerializable,
|
||||
"checkboxgroup": SimpleSerializable,
|
||||
"radio": SimpleSerializable,
|
||||
COMPONENT_MAPPING: Dict[str, type] = {
|
||||
"textbox": StringSerializable,
|
||||
"number": NumberSerializable,
|
||||
"slider": NumberSerializable,
|
||||
"checkbox": BooleanSerializable,
|
||||
"checkboxgroup": ListStringSerializable,
|
||||
"radio": StringSerializable,
|
||||
"dropdown": SimpleSerializable,
|
||||
"image": ImgSerializable,
|
||||
"video": FileSerializable,
|
||||
@ -264,18 +324,18 @@ COMPONENT_MAPPING = {
|
||||
"dataframe": JSONSerializable,
|
||||
"timeseries": JSONSerializable,
|
||||
"state": SimpleSerializable,
|
||||
"button": SimpleSerializable,
|
||||
"button": StringSerializable,
|
||||
"uploadbutton": FileSerializable,
|
||||
"colorpicker": SimpleSerializable,
|
||||
"colorpicker": StringSerializable,
|
||||
"label": JSONSerializable,
|
||||
"highlightedtext": JSONSerializable,
|
||||
"json": JSONSerializable,
|
||||
"html": SimpleSerializable,
|
||||
"html": StringSerializable,
|
||||
"gallery": GallerySerializable,
|
||||
"chatbot": JSONSerializable,
|
||||
"model3d": FileSerializable,
|
||||
"plot": JSONSerializable,
|
||||
"markdown": SimpleSerializable,
|
||||
"dataset": SimpleSerializable,
|
||||
"code": SimpleSerializable,
|
||||
"markdown": StringSerializable,
|
||||
"dataset": StringSerializable,
|
||||
"code": StringSerializable,
|
||||
}
|
||||
|
@ -1 +1 @@
|
||||
0.0.4
|
||||
0.0.5
|
@ -9,4 +9,5 @@ python -m isort --profile=black --check-only test gradio_client
|
||||
python -m flake8 --ignore=E731,E501,E722,W503,E126,E203,F403,F541 test gradio_client --exclude gradio_client/__init__.py
|
||||
|
||||
echo "Testing..."
|
||||
python -m pip install -e ../../. # Install gradio from local source (as the latest version may not yet be published to PyPI)
|
||||
python -m pytest test
|
||||
|
@ -2,5 +2,4 @@ black==22.6.0
|
||||
flake8==4.0.1
|
||||
isort==5.10.1
|
||||
pytest==7.1.2
|
||||
gradio>=3.23.0
|
||||
pytest-asyncio
|
@ -4,19 +4,89 @@ import pytest
|
||||
|
||||
from gradio_client import Client
|
||||
|
||||
HF_TOKEN = "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" # Intentionally revealing this key for testing purposes
|
||||
|
||||
|
||||
class TestPredictionsFromSpaces:
|
||||
@pytest.mark.flaky
|
||||
def test_numerical_to_label_space(self):
|
||||
client = Client(space="abidlabs/titanic-survival")
|
||||
client = Client("gradio-tests/titanic-survival")
|
||||
output = client.predict("male", 77, 10).result()
|
||||
assert json.load(open(output))["label"] == "Perishes"
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_private_space(self):
|
||||
hf_token = "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" # Intentionally revealing this key for testing purposes
|
||||
client = Client(
|
||||
space="gradio-tests/not-actually-private-space", hf_token=hf_token
|
||||
)
|
||||
client = Client("gradio-tests/not-actually-private-space", hf_token=HF_TOKEN)
|
||||
output = client.predict("abc").result()
|
||||
assert output == "abc"
|
||||
|
||||
|
||||
class TestEndpoints:
|
||||
@pytest.mark.flaky
|
||||
def test_numerical_to_label_space(self):
|
||||
client = Client("gradio-tests/titanic-survival")
|
||||
assert client.endpoints[0].get_info() == {
|
||||
"parameters": {
|
||||
"sex": ["Any", "", "Radio"],
|
||||
"age": ["Any", "", "Slider"],
|
||||
"fare_(british_pounds)": ["Any", "", "Slider"],
|
||||
},
|
||||
"returns": {"output": ["str", "filepath to json file", "Label"]},
|
||||
}
|
||||
assert client.view_api(return_info=True) == {
|
||||
"named_endpoints": {
|
||||
"predict": {
|
||||
"parameters": {
|
||||
"sex": ["Any", "", "Radio"],
|
||||
"age": ["Any", "", "Slider"],
|
||||
"fare_(british_pounds)": ["Any", "", "Slider"],
|
||||
},
|
||||
"returns": {"output": ["str", "filepath to json file", "Label"]},
|
||||
},
|
||||
"predict_1": {
|
||||
"parameters": {
|
||||
"sex": ["Any", "", "Radio"],
|
||||
"age": ["Any", "", "Slider"],
|
||||
"fare_(british_pounds)": ["Any", "", "Slider"],
|
||||
},
|
||||
"returns": {"output": ["str", "filepath to json file", "Label"]},
|
||||
},
|
||||
"predict_2": {
|
||||
"parameters": {
|
||||
"sex": ["Any", "", "Radio"],
|
||||
"age": ["Any", "", "Slider"],
|
||||
"fare_(british_pounds)": ["Any", "", "Slider"],
|
||||
},
|
||||
"returns": {"output": ["str", "filepath to json file", "Label"]},
|
||||
},
|
||||
},
|
||||
"unnamed_endpoints": {},
|
||||
}
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_private_space(self):
|
||||
client = Client("gradio-tests/not-actually-private-space", hf_token=HF_TOKEN)
|
||||
assert len(client.endpoints) == 3
|
||||
assert len([e for e in client.endpoints if e.is_valid]) == 2
|
||||
assert len([e for e in client.endpoints if e.is_valid and e.api_name]) == 1
|
||||
assert client.endpoints[0].get_info() == {
|
||||
"parameters": {"x": ["Any", "", "Textbox"]},
|
||||
"returns": {"output": ["Any", "", "Textbox"]},
|
||||
}
|
||||
assert client.view_api(return_info=True) == {
|
||||
"named_endpoints": {
|
||||
"predict": {
|
||||
"parameters": {"x": ["Any", "", "Textbox"]},
|
||||
"returns": {"output": ["Any", "", "Textbox"]},
|
||||
}
|
||||
},
|
||||
"unnamed_endpoints": {
|
||||
2: {
|
||||
"parameters": {"parameter_0": ["Any", "", "Dataset"]},
|
||||
"returns": {
|
||||
"x": ["Any", "", "Textbox"],
|
||||
"output": ["Any", "", "Textbox"],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
11
client/python/test/test_serializing.py
Normal file
11
client/python/test/test_serializing.py
Normal file
@ -0,0 +1,11 @@
|
||||
from gradio import components
|
||||
|
||||
from gradio_client.serializing import COMPONENT_MAPPING
|
||||
|
||||
|
||||
def test_check_component_fallback_serializers():
|
||||
for component_name, class_type in COMPONENT_MAPPING.items():
|
||||
if component_name == "dataset": # cannot be instantiated without parameters
|
||||
continue
|
||||
component = components.get_component_instance(component_name)
|
||||
assert isinstance(component, class_type)
|
@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Set, Tupl
|
||||
import anyio
|
||||
import requests
|
||||
from anyio import CapacityLimiter
|
||||
from gradio_client import serializing
|
||||
from gradio_client import utils as client_utils
|
||||
from typing_extensions import Literal
|
||||
|
||||
@ -1169,7 +1170,12 @@ class Blocks(BlockContext):
|
||||
}
|
||||
serializer = utils.get_serializer_name(block)
|
||||
if serializer:
|
||||
assert isinstance(block, serializing.Serializable)
|
||||
block_config["serializer"] = serializer
|
||||
block_config["info"] = {
|
||||
"input": list(block.input_api_info()), # type: ignore
|
||||
"output": list(block.output_api_info()), # type: ignore
|
||||
}
|
||||
config["components"].append(block_config)
|
||||
config["dependencies"] = self.dependencies
|
||||
return config
|
||||
|
@ -33,12 +33,16 @@ from fastapi import UploadFile
|
||||
from ffmpy import FFmpeg
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.serializing import (
|
||||
BooleanSerializable,
|
||||
FileSerializable,
|
||||
GallerySerializable,
|
||||
ImgSerializable,
|
||||
JSONSerializable,
|
||||
ListStringSerializable,
|
||||
NumberSerializable,
|
||||
Serializable,
|
||||
SimpleSerializable,
|
||||
StringSerializable,
|
||||
)
|
||||
from pandas.api.types import is_numeric_dtype
|
||||
from PIL import Image as _Image # using _ to minimize namespace pollution
|
||||
@ -367,7 +371,7 @@ class Textbox(
|
||||
Submittable,
|
||||
Blurrable,
|
||||
IOComponent,
|
||||
SimpleSerializable,
|
||||
StringSerializable,
|
||||
TokenInterpretable,
|
||||
):
|
||||
"""
|
||||
@ -584,7 +588,7 @@ class Number(
|
||||
Submittable,
|
||||
Blurrable,
|
||||
IOComponent,
|
||||
SimpleSerializable,
|
||||
NumberSerializable,
|
||||
NeighborInterpretable,
|
||||
):
|
||||
"""
|
||||
@ -764,7 +768,7 @@ class Slider(
|
||||
Changeable,
|
||||
Releaseable,
|
||||
IOComponent,
|
||||
SimpleSerializable,
|
||||
NumberSerializable,
|
||||
NeighborInterpretable,
|
||||
):
|
||||
"""
|
||||
@ -838,6 +842,12 @@ class Slider(
|
||||
self.cleared_value = self.value
|
||||
self.test_input = self.value
|
||||
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
return "int | float", f"value between {self.minimum} and {self.maximum}"
|
||||
|
||||
def get_output_type(self) -> Tuple[str, str]:
|
||||
return "int | float", f"value between {self.minimum} and {self.maximum})"
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"minimum": self.minimum,
|
||||
@ -929,7 +939,7 @@ class Checkbox(
|
||||
Changeable,
|
||||
Selectable,
|
||||
IOComponent,
|
||||
SimpleSerializable,
|
||||
BooleanSerializable,
|
||||
NeighborInterpretable,
|
||||
):
|
||||
"""
|
||||
@ -1033,7 +1043,7 @@ class CheckboxGroup(
|
||||
Changeable,
|
||||
Selectable,
|
||||
IOComponent,
|
||||
SimpleSerializable,
|
||||
ListStringSerializable,
|
||||
NeighborInterpretable,
|
||||
):
|
||||
"""
|
||||
@ -1217,7 +1227,7 @@ class Radio(
|
||||
Selectable,
|
||||
Changeable,
|
||||
IOComponent,
|
||||
SimpleSerializable,
|
||||
StringSerializable,
|
||||
NeighborInterpretable,
|
||||
):
|
||||
"""
|
||||
@ -1456,6 +1466,18 @@ class Dropdown(Changeable, Selectable, IOComponent, SimpleSerializable, FormComp
|
||||
|
||||
self.cleared_value = self.value or ([] if multiselect else "")
|
||||
|
||||
def input_api_info(self) -> Tuple[str, str]:
|
||||
if self.multiselect:
|
||||
return "List[str]", f"List of options from: {self.choices}"
|
||||
else:
|
||||
return "str", f"Option from: {self.choices}"
|
||||
|
||||
def get_output_type(self) -> Tuple[str, str]:
|
||||
if self.multiselect:
|
||||
return "List[str]", f"List of options from: {self.choices}"
|
||||
else:
|
||||
return "str", f"Option from: {self.choices}"
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
"choices": self.choices,
|
||||
@ -3135,7 +3157,7 @@ class Variable(State):
|
||||
|
||||
|
||||
@document("style")
|
||||
class Button(Clickable, IOComponent, SimpleSerializable):
|
||||
class Button(Clickable, IOComponent, StringSerializable):
|
||||
"""
|
||||
Used to create a button, that can be assigned arbitrary click() events. The label (value) of the button can be used as an input or set via the output of a function.
|
||||
|
||||
@ -3385,7 +3407,7 @@ class UploadButton(Clickable, Uploadable, IOComponent, FileSerializable):
|
||||
|
||||
|
||||
@document("style")
|
||||
class ColorPicker(Changeable, Submittable, IOComponent, SimpleSerializable):
|
||||
class ColorPicker(Changeable, Submittable, IOComponent, StringSerializable):
|
||||
"""
|
||||
Creates a color picker for user to select a color as string input.
|
||||
Preprocessing: passes selected color value as a {str} into the function.
|
||||
@ -3902,7 +3924,7 @@ class JSON(Changeable, IOComponent, JSONSerializable):
|
||||
|
||||
|
||||
@document()
|
||||
class HTML(Changeable, IOComponent, SimpleSerializable):
|
||||
class HTML(Changeable, IOComponent, StringSerializable):
|
||||
"""
|
||||
Used to display arbitrary HTML output.
|
||||
Preprocessing: this component does *not* accept input.
|
||||
@ -5564,7 +5586,7 @@ class BarPlot(Plot):
|
||||
|
||||
|
||||
@document()
|
||||
class Markdown(IOComponent, Changeable, SimpleSerializable):
|
||||
class Markdown(IOComponent, Changeable, StringSerializable):
|
||||
"""
|
||||
Used to render arbitrary Markdown output. Can also render latex enclosed by dollar signs.
|
||||
Preprocessing: this component does *not* accept input.
|
||||
@ -5639,7 +5661,7 @@ class Markdown(IOComponent, Changeable, SimpleSerializable):
|
||||
|
||||
|
||||
@document("languages")
|
||||
class Code(Changeable, IOComponent, SimpleSerializable):
|
||||
class Code(Changeable, IOComponent, StringSerializable):
|
||||
"""
|
||||
Creates a Code editor for entering, editing or viewing code.
|
||||
Preprocessing: passes a {str} of code into the function.
|
||||
@ -5748,7 +5770,7 @@ class Code(Changeable, IOComponent, SimpleSerializable):
|
||||
|
||||
|
||||
@document("style")
|
||||
class Dataset(Clickable, Selectable, Component, SimpleSerializable):
|
||||
class Dataset(Clickable, Selectable, Component, StringSerializable):
|
||||
"""
|
||||
Used to create an output widget for showing datasets. Used to render the examples
|
||||
box.
|
||||
|
@ -452,7 +452,7 @@ def from_spaces(
|
||||
|
||||
|
||||
def from_spaces_blocks(space: str, api_key: str | None) -> Blocks:
|
||||
client = Client(space=space, hf_token=api_key)
|
||||
client = Client(space, hf_token=api_key)
|
||||
predict_fns = [endpoint._predict_resolve for endpoint in client.endpoints]
|
||||
return gradio.Blocks.from_config(client.config, predict_fns, client.src)
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
XRAY_CONFIG = {
|
||||
"version": "3.21.0\n",
|
||||
"version": "3.23.1b3",
|
||||
"mode": "blocks",
|
||||
"dev_mode": True,
|
||||
"analytics_enabled": False,
|
||||
"theme": "default",
|
||||
"components": [
|
||||
{
|
||||
"id": 1,
|
||||
@ -14,7 +13,8 @@ XRAY_CONFIG = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
@ -28,7 +28,11 @@ XRAY_CONFIG = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {
|
||||
"input": ["List[str]", "values"],
|
||||
"output": ["List[str]", "values"],
|
||||
},
|
||||
},
|
||||
{"id": 3, "type": "tabs", "props": {"visible": True, "style": {}}},
|
||||
{
|
||||
@ -61,12 +65,20 @@ XRAY_CONFIG = {
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "ImgSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath or URL"],
|
||||
"output": ["str", "filepath or URL"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "json",
|
||||
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
|
||||
"serializer": "JSONSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath to json file"],
|
||||
"output": ["str", "filepath to json file"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
@ -79,7 +91,8 @@ XRAY_CONFIG = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
@ -111,12 +124,20 @@ XRAY_CONFIG = {
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "ImgSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath or URL"],
|
||||
"output": ["str", "filepath or URL"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"type": "json",
|
||||
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
|
||||
"serializer": "JSONSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath to json file"],
|
||||
"output": ["str", "filepath to json file"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
@ -129,7 +150,8 @@ XRAY_CONFIG = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
@ -144,7 +166,8 @@ XRAY_CONFIG = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
@ -164,6 +187,12 @@ XRAY_CONFIG = {
|
||||
"show_error": True,
|
||||
"show_api": True,
|
||||
"is_colab": False,
|
||||
"stylesheets": [
|
||||
"https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap",
|
||||
"https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&display=swap",
|
||||
],
|
||||
"root": "",
|
||||
"theme": "default",
|
||||
"layout": {
|
||||
"id": 0,
|
||||
"children": [
|
||||
@ -257,11 +286,10 @@ XRAY_CONFIG = {
|
||||
|
||||
|
||||
XRAY_CONFIG_DIFF_IDS = {
|
||||
"version": "3.21.0\n",
|
||||
"version": "3.23.1b3",
|
||||
"mode": "blocks",
|
||||
"dev_mode": True,
|
||||
"analytics_enabled": False,
|
||||
"theme": "default",
|
||||
"components": [
|
||||
{
|
||||
"id": 1,
|
||||
@ -272,7 +300,8 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
@ -286,7 +315,11 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {
|
||||
"input": ["List[str]", "values"],
|
||||
"output": ["List[str]", "values"],
|
||||
},
|
||||
},
|
||||
{"id": 3, "type": "tabs", "props": {"visible": True, "style": {}}},
|
||||
{
|
||||
@ -319,12 +352,20 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "ImgSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath or URL"],
|
||||
"output": ["str", "filepath or URL"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "json",
|
||||
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
|
||||
"serializer": "JSONSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath to json file"],
|
||||
"output": ["str", "filepath to json file"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
@ -337,7 +378,8 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
@ -369,12 +411,20 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "ImgSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath or URL"],
|
||||
"output": ["str", "filepath or URL"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 1212,
|
||||
"type": "json",
|
||||
"props": {"show_label": True, "name": "json", "visible": True, "style": {}},
|
||||
"serializer": "JSONSerializable",
|
||||
"info": {
|
||||
"input": ["str", "filepath to json file"],
|
||||
"output": ["str", "filepath to json file"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
@ -387,7 +437,8 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
@ -402,7 +453,8 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"visible": True,
|
||||
"style": {},
|
||||
},
|
||||
"serializer": "SimpleSerializable",
|
||||
"serializer": "Serializable",
|
||||
"info": {"input": ["str", "value"], "output": ["str", "value"]},
|
||||
},
|
||||
{
|
||||
"id": 15,
|
||||
@ -422,6 +474,12 @@ XRAY_CONFIG_DIFF_IDS = {
|
||||
"show_error": True,
|
||||
"show_api": True,
|
||||
"is_colab": False,
|
||||
"stylesheets": [
|
||||
"https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap",
|
||||
"https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;600&display=swap",
|
||||
],
|
||||
"root": "",
|
||||
"theme": "default",
|
||||
"layout": {
|
||||
"id": 0,
|
||||
"children": [
|
||||
|
@ -3,7 +3,7 @@ aiohttp
|
||||
altair>=4.2.0
|
||||
fastapi
|
||||
ffmpy
|
||||
gradio_client==0.0.4
|
||||
gradio_client>=0.0.5
|
||||
httpx
|
||||
huggingface_hub>=0.13.0
|
||||
Jinja2
|
||||
|
@ -126,6 +126,7 @@ class TestBlocksMethods:
|
||||
demo.load(fake_func, [], [textbox])
|
||||
|
||||
config = demo.get_config_file()
|
||||
print(config)
|
||||
assert assert_configs_are_equivalent_besides_ids(XRAY_CONFIG, config)
|
||||
assert config["show_api"] is True
|
||||
_ = demo.launch(prevent_thread_lock=True, show_api=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user