2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-02-17 11:29:58 +08:00

Allow users to add a custom API route ()

* changes

* add changeset

* changes

* add changeset

* changes

* changes

* changes

* chagnes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
aliabid94 2025-01-10 11:46:57 -08:00 committed by GitHub
parent decb594455
commit e742dcccb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 347 additions and 35 deletions

View File

@ -0,0 +1,8 @@
---
"@gradio/client": patch
"@gradio/core": patch
"gradio": patch
"gradio_client": patch
---
fix:Allow users to add a custom API route

View File

@ -181,6 +181,7 @@ export function get_type(
serializer: string,
signature_type: "return" | "parameter"
): string | undefined {
if (component === "Api") return type.type;
switch (type?.type) {
case "string":
return "string";

View File

@ -4,6 +4,7 @@ import asyncio
import base64
import concurrent.futures
import copy
import inspect
import json
import mimetypes
import os
@ -19,7 +20,17 @@ from datetime import datetime
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict
from typing import (
TYPE_CHECKING,
Any,
Literal,
Optional,
TypedDict,
Union,
get_args,
get_origin,
get_type_hints,
)
import fsspec.asyn
import httpx
@ -994,6 +1005,93 @@ def _json_schema_to_python_type(schema: Any, defs) -> str:
raise APIInfoParseError(f"Cannot parse schema {schema}")
def python_type_to_json_schema(type_hint: Any) -> dict:
try:
return _python_type_to_json_schema(type_hint)
except Exception:
return {}
def _python_type_to_json_schema(type_hint: Any) -> dict:
"""Convert a Python type hint to a JSON schema."""
if type_hint is type(None):
return {"type": "null"}
if type_hint is str:
return {"type": "string"}
if type_hint is int:
return {"type": "integer"}
if type_hint is float:
return {"type": "number"}
if type_hint is bool:
return {"type": "boolean"}
origin = get_origin(type_hint)
if origin is Literal:
literal_values = get_args(type_hint)
if len(literal_values) == 1:
return {"const": literal_values[0]}
return {"enum": list(literal_values)}
if origin is Union or str(origin) == "|":
types = get_args(type_hint)
if len(types) == 2 and type(None) in types:
other_type = next(t for t in types if t is not type(None))
schema = _python_type_to_json_schema(other_type)
if "type" in schema:
schema["type"] = [schema["type"], "null"]
else:
schema["oneOf"] = [{"type": "null"}, schema]
return schema
return {"anyOf": [_python_type_to_json_schema(t) for t in types]}
if origin is list:
item_type = get_args(type_hint)[0]
return {"type": "array", "items": _python_type_to_json_schema(item_type)}
if origin is tuple:
types = get_args(type_hint)
return {
"type": "array",
"prefixItems": [_python_type_to_json_schema(t) for t in types],
"minItems": len(types),
"maxItems": len(types),
}
if origin is dict:
key_type, value_type = get_args(type_hint)
if key_type is not str:
raise ValueError("JSON Schema only supports string keys in objects")
schema = {
"type": "object",
"additionalProperties": _python_type_to_json_schema(value_type),
}
return schema
if inspect.isclass(type_hint) and hasattr(type_hint, "__annotations__"):
properties = {}
required = []
hints = get_type_hints(type_hint)
for field_name, field_type in hints.items():
properties[field_name] = _python_type_to_json_schema(field_type)
if hasattr(type_hint, "__total__"):
if type_hint.__total__:
required.append(field_name)
elif (
not hasattr(type_hint, "__dataclass_fields__")
or not type_hint.__dataclass_fields__[field_name].default
):
required.append(field_name)
schema = {"type": "object", "properties": properties}
if required:
schema["required"] = required
return schema
if type_hint is Any:
return {}
def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
"""
Traverse a JSON object and apply a function to each element that satisfies the is_root condition.

View File

@ -75,6 +75,7 @@ from gradio.events import (
RetryData,
SelectData,
UndoData,
api,
on,
)
from gradio.exceptions import Error

View File

@ -779,11 +779,12 @@ class BlocksConfig:
if fn is not None and not cancels:
check_function_inputs_match(fn, inputs, inputs_as_dict)
if _targets[0][1] in ["change", "key_up"] and trigger_mode is None:
trigger_mode = "always_last"
elif _targets[0][1] in ["stream"] and trigger_mode is None:
trigger_mode = "multiple"
elif trigger_mode is None:
if len(_targets) and trigger_mode is None:
if _targets[0][1] in ["change", "key_up"]:
trigger_mode = "always_last"
elif _targets[0][1] in ["stream"]:
trigger_mode = "multiple"
if trigger_mode is None:
trigger_mode = "once"
elif trigger_mode not in ["once", "multiple", "always_last"]:
raise ValueError(

View File

@ -48,7 +48,7 @@ INTERFACE_TEMPLATE = '''
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: list of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: list of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
scroll_to_output: if True, will scroll to output component on completion
show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
queue: if True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.

View File

@ -0,0 +1,46 @@
"""gr.Api() component."""
from __future__ import annotations
from typing import Any
from gradio.components.base import Component
class Api(Component):
"""
A generic component that holds any value. Used for generating APIs with no actual frontend component.
"""
EVENTS = []
def __init__(
self,
value: Any,
_api_info: dict[str, str],
label: str = "API",
):
"""
Parameters:
value: default value.
"""
self._api_info = _api_info
super().__init__(value=value, label=label)
def preprocess(self, payload: Any) -> Any:
return payload
def postprocess(self, value: Any) -> Any:
return value
def api_info(self) -> dict[str, str]:
return self._api_info
def example_payload(self) -> Any:
return self.value if self.value is not None else "..."
def example_value(self) -> Any:
return self.value if self.value is not None else "..."
# def get_block_name(self) -> str:
# return "state" # so that it does not render in the frontend, just like state

View File

@ -23,8 +23,10 @@ if TYPE_CHECKING:
from gradio.blocks import Block, BlockContext, Component
from gradio.components import Timer
from gradio_client.utils import python_type_to_json_schema
from gradio.context import get_blocks_context
from gradio.utils import get_cancelled_fn_indices
from gradio.utils import get_cancelled_fn_indices, get_function_params, get_return_types
def set_cancel_events(
@ -760,7 +762,7 @@ def on(
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
scroll_to_output: If True, will scroll to output component on completion
show_progress: how to show the progress animation while event is running: "full" shows a spinner which covers the output component area as well as a runtime display in the upper right corner, "minimal" only shows the runtime display, "hidden" shows no progress animation at all
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
@ -886,6 +888,126 @@ def on(
return Dependency(None, dep.get_config(), dep_index, fn)
@document()
def api(
fn: Callable | Literal["decorator"] = "decorator",
*,
api_name: str | None | Literal[False] = None,
queue: bool = True,
batch: bool = False,
max_batch_size: int = 4,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None,
show_api: bool = True,
time_limit: int | None = None,
stream_every: float = 0.5,
) -> Dependency:
"""
Sets up an API endpoint for a generic function that can be called via the gradio client. Derives its type from type-hints in the function signature.
Parameters:
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, will use the functions name as the endpoint route. If set to a string, the endpoint will be exposed in the api docs with the given name.
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
show_api: whether to show this event in the "view API" page of the Gradio app, or in the ".view_api()" method of the Gradio clients. Unlike setting api_name to False, setting show_api to False will still allow downstream apps as well as the Clients to use this event. If fn is None, show_api will automatically be set to False.
time_limit: The time limit for the function to run. Parameter only used for the `.stream()` event.
stream_every: The latency (in seconds) at which stream chunks are sent to the backend. Defaults to 0.5 seconds. Parameter only used for the `.stream()` event.
Example:
import gradio as gr
with gr.Blocks() as demo:
with gr.Row():
input = gr.Textbox()
button = gr.Button("Submit")
output = gr.Textbox()
gr.on(
triggers=[button.click, input.submit],
fn=lambda x: x,
inputs=[input],
outputs=[output]
)
demo.launch()
"""
if fn == "decorator":
def wrapper(func):
api(
fn=func,
api_name=api_name,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
show_api=show_api,
time_limit=time_limit,
stream_every=stream_every,
)
@wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs)
return inner
return Dependency(None, {}, None, wrapper)
root_block = get_blocks_context()
if root_block is None:
raise Exception("Cannot call api() outside of a gradio.Blocks context.")
from gradio.components.api_component import Api
fn_params = get_function_params(fn)
return_types = get_return_types(fn)
def ordinal(n):
return f"{n}{'th' if 10 <= n % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')}"
if any(param[3] is None for param in fn_params):
raise ValueError(
"API endpoints must have type hints. Please specify a type hint for all parameters."
)
inputs = [
Api(
default_value if has_default else None,
python_type_to_json_schema(_type),
ordinal(i + 1),
)
for i, (_, has_default, default_value, _type) in enumerate(fn_params)
]
outputs = [
Api(None, python_type_to_json_schema(type), ordinal(i + 1))
for i, type in enumerate(return_types)
]
dep, dep_index = root_block.set_event_trigger(
[],
fn,
inputs,
outputs,
preprocess=False,
postprocess=False,
scroll_to_output=False,
show_progress="hidden",
api_name=api_name,
js=None,
concurrency_limit=concurrency_limit,
concurrency_id=concurrency_id,
queue=queue,
batch=batch,
max_batch_size=max_batch_size,
show_api=show_api,
trigger_mode=None,
time_limit=time_limit,
stream_every=stream_every,
)
return Dependency(None, dep.get_config(), dep_index, fn)
class Events:
change = EventListener(
"change",

View File

@ -42,7 +42,7 @@ from contextlib import contextmanager
from functools import wraps
from io import BytesIO
from pathlib import Path
from types import ModuleType
from types import ModuleType, NoneType
from typing import (
TYPE_CHECKING,
Any,
@ -50,6 +50,8 @@ from typing import (
Literal,
Optional,
TypeVar,
get_args,
get_origin,
)
import anyio
@ -1299,9 +1301,9 @@ def get_upload_folder() -> str:
)
def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]:
def get_function_params(func: Callable) -> list[tuple[str, bool, Any, Any]]:
"""
Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value).
Gets the parameters of a function as a list of tuples of the form (name, has_default, default_value, type_hint).
Excludes *args and **kwargs, as well as args that are Gradio-specific, such as gr.Request, gr.EventData, gr.OAuthProfile, and gr.OAuthToken.
"""
params_info = []
@ -1316,12 +1318,26 @@ def get_function_params(func: Callable) -> list[tuple[str, bool, Any]]:
if is_special_typed_parameter(name, type_hints):
continue
if parameter.default is inspect.Parameter.empty:
params_info.append((name, False, None))
params_info.append((name, False, None, type_hints.get(name, None)))
else:
params_info.append((name, True, parameter.default))
params_info.append(
(name, True, parameter.default, type_hints.get(name, None))
)
return params_info
def get_return_types(func: Callable) -> list:
return_hint = inspect.signature(func).return_annotation
if return_hint in {inspect.Signature.empty, None, NoneType}:
return []
if get_origin(return_hint) == tuple:
return list(get_args(return_hint))
return [return_hint]
def simplify_file_data_in_str(s):
"""
If a FileData dictionary has been dumped as part of a string, this function will replace the dict with just the str filepath

View File

@ -368,7 +368,7 @@
.endpoint-container {
margin-top: var(--size-3);
margin-bottom: var(--size-3);
border: 1px solid var(--body-text-color);
border: 1px solid var(--block-border-color);
border-radius: var(--radius-xl);
padding: var(--size-3);
padding-top: 0;

View File

@ -9,7 +9,7 @@
let bash_install = "curl --version";
</script>
<Block border_mode="contrast">
<Block>
<code>
{#if current_language === "python"}
<div class="copy">

View File

@ -638,6 +638,7 @@ export function get_component(
} {
let example_component_map: Map<ComponentMeta["type"], LoadingComponent> =
new Map();
if (type === "api") type = "state";
if (type === "dataset" && example_components) {
(example_components as string[]).forEach((name: string) => {
if (example_component_map.has(name)) {

View File

@ -12009,7 +12009,7 @@ snapshots:
code-red@1.0.4:
dependencies:
'@jridgewell/sourcemap-codec': 1.4.15
'@jridgewell/sourcemap-codec': 1.5.0
'@types/estree': 1.0.5
acorn: 8.11.3
estree-walker: 3.0.3

View File

@ -694,6 +694,18 @@ class TestRoutes:
assert response.status_code == 403
def test_api_listener(connect):
with gr.Blocks() as demo:
def fn(a: int, b: int, c: str) -> tuple[int, str]:
return a + b, c[a:b]
gr.api(fn, api_name="addition")
with connect(demo) as client:
assert client.predict(a=1, b=3, c="testing", api_name="/addition") == (4, "es")
class TestApp:
def test_create_app(self):
app = routes.App.create_app(Interface(lambda x: x, "text", "text"))

View File

@ -533,14 +533,14 @@ def test_diff(old, new, expected_diff):
class TestFunctionParams:
def test_regular_function(self):
def func(a, b=10, c="default", d=None):
def func(a: int, b: int = 10, c: str = "default", d=None):
pass
assert get_function_params(func) == [
("a", False, None),
("b", True, 10),
("c", True, "default"),
("d", True, None),
("a", False, None, int),
("b", True, 10, int),
("c", True, "default", str),
("d", True, None, None),
]
def test_function_no_params(self):
@ -551,32 +551,38 @@ class TestFunctionParams:
def test_lambda_function(self):
assert get_function_params(lambda x, y: x + y) == [
("x", False, None),
("y", False, None),
("x", False, None, None),
("y", False, None, None),
]
def test_function_with_args(self):
def func(a, *args):
pass
assert get_function_params(func) == [("a", False, None)]
assert get_function_params(func) == [("a", False, None, None)]
def test_function_with_kwargs(self):
def func(a, **kwargs):
pass
assert get_function_params(func) == [("a", False, None)]
assert get_function_params(func) == [("a", False, None, None)]
def test_function_with_special_args(self):
def func(a, r: Request, b=10):
pass
assert get_function_params(func) == [("a", False, None), ("b", True, 10)]
assert get_function_params(func) == [
("a", False, None, None),
("b", True, 10, None),
]
def func2(a, r: Request | None = None, b="abc"):
pass
assert get_function_params(func2) == [("a", False, None), ("b", True, "abc")]
assert get_function_params(func2) == [
("a", False, None, None),
("b", True, "abc", None),
]
def test_class_method_skip_first_param(self):
class MyClass:
@ -584,8 +590,8 @@ class TestFunctionParams:
pass
assert get_function_params(MyClass().method) == [
("arg1", False, None),
("arg2", True, 42),
("arg1", False, None, None),
("arg2", True, 42, None),
]
def test_static_method_no_skip(self):
@ -595,8 +601,8 @@ class TestFunctionParams:
pass
assert get_function_params(MyClass.method) == [
("arg1", False, None),
("arg2", True, 42),
("arg1", False, None, None),
("arg2", True, 42, None),
]
def test_class_method_with_args(self):
@ -604,13 +610,13 @@ class TestFunctionParams:
def method(self, a, *args, b=42):
pass
assert get_function_params(MyClass().method) == [("a", False, None)]
assert get_function_params(MyClass().method) == [("a", False, None, None)]
def test_lambda_with_args(self):
assert get_function_params(lambda x, *args: x) == [("x", False, None)]
assert get_function_params(lambda x, *args: x) == [("x", False, None, None)]
def test_lambda_with_kwargs(self):
assert get_function_params(lambda x, **kwargs: x) == [("x", False, None)]
assert get_function_params(lambda x, **kwargs: x) == [("x", False, None, None)]
def test_parse_file_size():