diff --git a/demo/blocks_xray/run.py b/demo/blocks_xray/run.py
index 4f1a8024be..4a98f39cb5 100644
--- a/demo/blocks_xray/run.py
+++ b/demo/blocks_xray/run.py
@@ -38,6 +38,7 @@ With this model you can lorem ipsum
inputs=[disease, xray_scan],
outputs=xray_results,
status_tracker=xray_progress,
+ api_name="xray_model"
)
with gr.TabItem("CT Scan"):
@@ -51,6 +52,7 @@ With this model you can lorem ipsum
inputs=[disease, ct_scan],
outputs=ct_results,
status_tracker=ct_progress,
+ api_name="ct_model"
)
upload_btn = gr.Button("Upload Results")
diff --git a/gradio/blocks.py b/gradio/blocks.py
index 621871173c..2318ef7ebc 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -8,7 +8,7 @@ import random
import sys
import time
import webbrowser
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
import anyio
from anyio import CapacityLimiter
@@ -85,6 +85,7 @@ class Block:
outputs: Optional[Component | List[Component]],
preprocess: bool = True,
postprocess: bool = True,
+ api_name: Optional[AnyStr] = None,
js: Optional[str] = False,
no_target: bool = False,
status_tracker: Optional[StatusTracker] = None,
@@ -99,6 +100,7 @@ class Block:
outputs: output list
preprocess: whether to run the preprocess methods of components
postprocess: whether to run the postprocess methods of components
+ api_name: Defining this parameter exposes the endpoint in the api docs
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
no_target: if True, sets "targets" to [], used for Blocks "load" event
status_tracker: StatusTracker to visualize function progress
@@ -115,20 +117,25 @@ class Block:
outputs = [outputs]
Context.root_block.fns.append(BlockFunction(fn, preprocess, postprocess))
- Context.root_block.dependencies.append(
- {
- "targets": [self._id] if not no_target else [],
- "trigger": event_name,
- "inputs": [block._id for block in inputs],
- "outputs": [block._id for block in outputs],
- "backend_fn": fn is not None,
- "js": js,
- "status_tracker": status_tracker._id
- if status_tracker is not None
- else None,
- "queue": queue,
- }
- )
+ dependency = {
+ "targets": [self._id] if not no_target else [],
+ "trigger": event_name,
+ "inputs": [block._id for block in inputs],
+ "outputs": [block._id for block in outputs],
+ "backend_fn": fn is not None,
+ "js": js,
+ "status_tracker": status_tracker._id
+ if status_tracker is not None
+ else None,
+ "queue": queue,
+ "api_name": api_name,
+ }
+ if api_name is not None:
+ dependency["documentation"] = [
+ [component.document_parameters("input") for component in inputs],
+ [component.document_parameters("output") for component in outputs],
+ ]
+ Context.root_block.dependencies.append(dependency)
def get_config(self):
return {
diff --git a/gradio/components.py b/gradio/components.py
index 8bff6374c9..84a18361fd 100644
--- a/gradio/components.py
+++ b/gradio/components.py
@@ -244,6 +244,21 @@ class IOComponent(Component):
self._style["container"] = container
return self
+ @classmethod
+ def document_parameters(cls, target):
+ if target == "input":
+ doc = inspect.getdoc(cls.preprocess)
+ if "Parameters:\nx (" in doc:
+ return doc.split("Parameters:\nx ")[1].split("\n")[0]
+ return None
+ elif target == "output":
+ doc = inspect.getdoc(cls.postprocess)
+ if "Returns:\n" in doc:
+ return doc.split("Returns:\n")[1].split("\n")[0]
+ return None
+ else:
+ raise ValueError("Invalid doumentation target.")
+
class Textbox(Changeable, Submittable, IOComponent):
"""
@@ -329,6 +344,10 @@ class Textbox(Changeable, Submittable, IOComponent):
def preprocess(self, x: str | None) -> Any:
"""
Any preprocessing needed to be performed on function input.
+ Parameters:
+ x (str): text
+ Returns:
+ (str): text
"""
if x is None:
return None
@@ -415,6 +434,10 @@ class Textbox(Changeable, Submittable, IOComponent):
def postprocess(self, y: str | None):
"""
Any postprocessing needed to be performed on function output.
+ Parameters:
+ y (str | None): text
+ Returns:
+ (str | None): text
"""
if y is None:
return None
@@ -518,21 +541,21 @@ class Number(Changeable, Submittable, IOComponent):
"__type__": "update",
}
- def preprocess(self, x: int | float | None) -> int | float | None:
+ def preprocess(self, x: float | None) -> float | None:
"""
Parameters:
- x (int | float | None): numeric input as a string
+ x (float | None): numeric input
Returns:
- (int | float | None): number representing function input
+ (float | None): number representing function input
"""
if x is None:
return None
return self.round_to_precision(x, self.precision)
- def preprocess_example(self, x: int | float | None) -> int | float | None:
+ def preprocess_example(self, x: float | None) -> float | None:
"""
Returns:
- (int | float | None): Number representing function input
+ (float | None): Number representing function input
"""
if x is None:
return None
@@ -584,18 +607,18 @@ class Number(Changeable, Submittable, IOComponent):
interpretation.insert(int(len(interpretation) / 2), [x, None])
return interpretation
- def generate_sample(self) -> int | float:
+ def generate_sample(self) -> float:
return self.round_to_precision(1, self.precision)
# Output Functionalities
- def postprocess(self, y: int | float | None) -> int | float | None:
+ def postprocess(self, y: float | None) -> float | None:
"""
Any postprocessing needed to be performed on function output.
Parameters:
- y (int | float | None): numeric output
+ y (float | None): numeric output
Returns:
- (int | float | None): number representing function output
+ (float | None): number representing function output
"""
if y is None:
return None
@@ -740,9 +763,13 @@ class Slider(Changeable, IOComponent):
# Output Functionalities
- def postprocess(self, y: int | float | None):
+ def postprocess(self, y: float | None):
"""
Any postprocessing needed to be performed on function output.
+ Parameters:
+ y (float | None): numeric output
+ Returns:
+ (float): numeric output or minimum number if None
"""
return self.minimum if y is None else y
@@ -867,6 +894,10 @@ class Checkbox(Changeable, IOComponent):
def postprocess(self, y):
"""
Any postprocessing needed to be performed on function output.
+ Parameters:
+ y (bool): boolean output
+ Returns:
+ (bool): boolean output
"""
return y
@@ -1015,6 +1046,10 @@ class CheckboxGroup(Changeable, IOComponent):
def postprocess(self, y):
"""
Any postprocessing needed to be performed on function output.
+ Parameters:
+ y (List[str]): List of selected choices
+ Returns:
+ (List[str]): List of selected choices
"""
return [] if y is None else y
@@ -1160,6 +1195,10 @@ class Radio(Changeable, IOComponent):
def postprocess(self, y):
"""
Any postprocessing needed to be performed on function output.
+ Parameters:
+ y (str): string of choice
+ Returns:
+ (str): string of choice
"""
return (
y if y is not None else self.choices[0] if len(self.choices) > 0 else None
@@ -3407,9 +3446,7 @@ class Model3D(Changeable, Editable, Clearable, IOComponent):
Parameters:
y (str): path to the model
Returns:
- (str): file name
- (str): file extension
- (str): base64 url data
+ (Dict[name (str): file name, data (str): base64 url data] | None)
"""
if y is None:
return y
@@ -3494,8 +3531,7 @@ class Plot(Changeable, Clearable, IOComponent):
Parameters:
y (str): plot data
Returns:
- (str): plot type
- (str): plot base64 or json
+ (Dict[type (str): plot type, plot (str): plot base64 | json] | None)
"""
if y is None:
return None
diff --git a/gradio/events.py b/gradio/events.py
index 795a1db2e2..118a7ec3df 100644
--- a/gradio/events.py
+++ b/gradio/events.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, AnyStr, Callable, Dict, List, Optional, Tuple
from gradio.blocks import Block
@@ -15,6 +15,7 @@ class Changeable(Block):
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
+ api_name: AnyStr = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -26,6 +27,7 @@ class Changeable(Block):
inputs: List of inputs
outputs: List of outputs
status_tracker: StatusTracker to visualize function progress
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of input and outputs components, return should be a list of values for output component.
Returns: None
"""
@@ -35,6 +37,7 @@ class Changeable(Block):
inputs,
outputs,
status_tracker=status_tracker,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
@@ -49,6 +52,7 @@ class Clickable(Block):
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
+ api_name: AnyStr = None,
queue=None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -60,6 +64,7 @@ class Clickable(Block):
inputs: List of inputs
outputs: List of outputs
status_tracker: StatusTracker to visualize function progress
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
_preprocess: If False, will not run preprocessing of component data before running 'fn'.
_postprocess: If False, will not run postprocessing of component data before returning 'fn' output.
@@ -71,6 +76,7 @@ class Clickable(Block):
inputs,
outputs,
status_tracker=status_tracker,
+ api_name=api_name,
queue=queue,
js=_js,
preprocess=_preprocess,
@@ -85,6 +91,7 @@ class Submittable(Block):
inputs: List[Component],
outputs: List[Component],
status_tracker: Optional[StatusTracker] = None,
+ api_name: AnyStr = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -96,6 +103,7 @@ class Submittable(Block):
inputs: List of inputs
outputs: List of outputs
status_tracker: StatusTracker to visualize function progress
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
Returns: None
"""
@@ -105,6 +113,7 @@ class Submittable(Block):
inputs,
outputs,
status_tracker=status_tracker,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
@@ -118,6 +127,7 @@ class Editable(Block):
fn: Callable,
inputs: List[Component],
outputs: List[Component],
+ api_name: AnyStr = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -128,6 +138,7 @@ class Editable(Block):
fn: Callable function
inputs: List of inputs
outputs: List of outputs
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
Returns: None
"""
@@ -136,6 +147,7 @@ class Editable(Block):
fn,
inputs,
outputs,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
@@ -149,6 +161,7 @@ class Clearable(Block):
fn: Callable,
inputs: List[Component],
outputs: List[Component],
+ api_name: AnyStr = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -159,6 +172,7 @@ class Clearable(Block):
fn: Callable function
inputs: List of inputs
outputs: List of outputs
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
Returns: None
"""
@@ -167,6 +181,7 @@ class Clearable(Block):
fn,
inputs,
outputs,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
@@ -180,6 +195,7 @@ class Playable(Block):
fn: Callable,
inputs: List[Component],
outputs: List[Component],
+ api_name: AnyStr = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -190,6 +206,7 @@ class Playable(Block):
fn: Callable function
inputs: List of inputs
outputs: List of outputs
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
Returns: None
"""
@@ -198,6 +215,7 @@ class Playable(Block):
fn,
inputs,
outputs,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
@@ -209,6 +227,7 @@ class Playable(Block):
fn: Callable,
inputs: List[Component],
outputs: List[Component],
+ api_name: Optional[AnyStr] = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -219,6 +238,7 @@ class Playable(Block):
fn: Callable function
inputs: List of inputs
outputs: List of outputs
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
Returns: None
"""
@@ -227,6 +247,7 @@ class Playable(Block):
fn,
inputs,
outputs,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
@@ -238,6 +259,7 @@ class Playable(Block):
fn: Callable,
inputs: List[Component],
outputs: List[Component],
+ api_name: AnyStr = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -248,6 +270,7 @@ class Playable(Block):
fn: Callable function
inputs: List of inputs
outputs: List of outputs
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
Returns: None
"""
@@ -256,6 +279,7 @@ class Playable(Block):
fn,
inputs,
outputs,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
@@ -269,6 +293,7 @@ class Streamable(Block):
fn: Callable,
inputs: List[Component],
outputs: List[Component],
+ api_name: AnyStr = None,
queue: Optional[bool] = None,
_js: Optional[str] = None,
_preprocess: bool = True,
@@ -279,6 +304,7 @@ class Streamable(Block):
fn: Callable function
inputs: List of inputs
outputs: List of outputs
+ api_name: Defining this parameter exposes the endpoint in the api docs
_js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
Returns: None
"""
@@ -288,6 +314,7 @@ class Streamable(Block):
fn,
inputs,
outputs,
+ api_name=api_name,
js=_js,
preprocess=_preprocess,
postprocess=_postprocess,
diff --git a/gradio/interface.py b/gradio/interface.py
index 0b73a722f6..048bad5300 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -519,6 +519,7 @@ class Interface(Blocks):
submit_fn,
self.input_components,
self.output_components,
+ api_name="predict",
status_tracker=status_tracker,
)
clear_btn.click(
diff --git a/gradio/routes.py b/gradio/routes.py
index 32ca616b64..8d03e9f372 100644
--- a/gradio/routes.py
+++ b/gradio/routes.py
@@ -226,47 +226,19 @@ class App(FastAPI):
if Path(app.cwd).resolve() in Path(path).resolve().parents:
return FileResponse(Path(path).resolve())
- @app.get("/api", response_class=HTMLResponse) # Needed for Spaces
- @app.get("/api/", response_class=HTMLResponse)
- def api_docs(request: Request):
- inputs = [type(inp) for inp in app.blocks.input_components]
- outputs = [type(out) for out in app.blocks.output_components]
- input_types_doc, input_types = get_types(inputs, "input")
- output_types_doc, output_types = get_types(outputs, "output")
- input_names = [inp.get_block_name() for inp in app.blocks.input_components]
- output_names = [
- out.get_block_name() for out in app.blocks.output_components
- ]
- if app.blocks.examples is not None:
- sample_inputs = app.blocks.examples[0]
- else:
- sample_inputs = [
- inp.generate_sample() for inp in app.blocks.input_components
- ]
- docs = {
- "inputs": input_names,
- "outputs": output_names,
- "len_inputs": len(inputs),
- "len_outputs": len(outputs),
- "inputs_lower": [name.lower() for name in input_names],
- "outputs_lower": [name.lower() for name in output_names],
- "input_types": input_types,
- "output_types": output_types,
- "input_types_doc": input_types_doc,
- "output_types_doc": output_types_doc,
- "sample_inputs": sample_inputs,
- "auth": app.blocks.auth,
- "local_login_url": urllib.parse.urljoin(app.blocks.local_url, "login"),
- "local_api_url": urllib.parse.urljoin(
- app.blocks.local_url, "api/predict"
- ),
- }
- return templates.TemplateResponse(
- "api_docs.html", {"request": request, **docs}
- )
+ @app.post("/api/queue/push/", dependencies=[Depends(login_check)])
+ async def queue_push(body: QueuePushBody):
+ job_hash, queue_position = queueing.push(body)
+ return {"hash": job_hash, "queue_position": queue_position}
- @app.post("/api/predict/", dependencies=[Depends(login_check)])
- async def predict(body: PredictBody, username: str = Depends(get_current_user)):
+ @app.post("/api/queue/status/", dependencies=[Depends(login_check)])
+ async def queue_status(body: QueueStatusBody):
+ status, data = queueing.get_status(body.hash)
+ return {"status": status, "data": data}
+
+ async def run_predict(
+ body: PredictBody, username: str = Depends(get_current_user)
+ ):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
@@ -291,15 +263,24 @@ class App(FastAPI):
raise error
return output
- @app.post("/api/queue/push/", dependencies=[Depends(login_check)])
- async def queue_push(body: QueuePushBody):
- job_hash, queue_position = queueing.push(body)
- return {"hash": job_hash, "queue_position": queue_position}
-
- @app.post("/api/queue/status/", dependencies=[Depends(login_check)])
- async def queue_status(body: QueueStatusBody):
- status, data = queueing.get_status(body.hash)
- return {"status": status, "data": data}
+ @app.post("/api/{api_name}", dependencies=[Depends(login_check)])
+ @app.post("/api/{api_name}/", dependencies=[Depends(login_check)])
+ async def predict(
+ api_name: str, body: PredictBody, username: str = Depends(get_current_user)
+ ):
+ if body.fn_index is None:
+ for i, fn in enumerate(app.blocks.dependencies):
+ if fn["api_name"] == api_name:
+ body.fn_index = i
+ break
+ if body.fn_index is None:
+ return JSONResponse(
+ content={
+ "error": f"This app has no endpoint /api/{api_name}/."
+ },
+ status_code=500,
+ )
+ return await run_predict(body=body, username=username)
return app
@@ -329,19 +310,14 @@ def safe_join(directory: str, path: str) -> Optional[str]:
return posixpath.join(directory, filename)
-def get_types(cls_set: List[Type], component: str):
+def get_types(cls_set: List[Type]):
docset = []
types = []
- if component == "input":
- for cls in cls_set:
- doc = inspect.getdoc(cls.preprocess)
- doc_lines = doc.split("\n")
- docset.append(doc_lines[1].split(":")[-1])
- types.append(doc_lines[1].split(")")[0].split("(")[-1])
- else:
- for cls in cls_set:
- doc = inspect.getdoc(cls.postprocess)
- doc_lines = doc.split("\n")
- docset.append(doc_lines[-1].split(":")[-1])
- types.append(doc_lines[-1].split(")")[0].split("(")[-1])
+ for cls in cls_set:
+ doc = inspect.getdoc(cls)
+ doc_lines = doc.split("\n")
+ for line in doc_lines:
+ if "value (" in line:
+ types.append(line.split("value (")[1].split(")")[0])
+ docset.append(doc_lines[1].split(":")[-1])
return docset, types
diff --git a/gradio/templates/api_docs.html b/gradio/templates/api_docs.html
deleted file mode 100644
index d850733816..0000000000
--- a/gradio/templates/api_docs.html
+++ /dev/null
@@ -1,716 +0,0 @@
-
-
-
This interface takes in {{ len_inputs }} input(s) and returns
- {{ len_outputs }} output(s).
-
The URL endpoint is:
-
-
-
-
Input(s): [{%for i in range(0, len_inputs)%} {{inputs[i]}}{% if i !=
- len_inputs - 1 %} ,{% endif %}{%endfor%} ]
-
-
- {%for i in range(0, len_inputs)%}
-
{{inputs[i]}} accepts the {{ input_types_doc[i]}} as type {{ input_types[i] }}
- {%endfor%}
-
-
-
Output(s): [{%for i in range(0, len_outputs)%} {{outputs[i]}}{% if i !=
- len_outputs - 1 %} ,{% endif %}{%endfor%} ]
-
-
- {%for i in range(0, len_outputs)%}
-
{{outputs[i]}} returns the {{ output_types_doc[i]}} as type {{ output_types[i]}}
- {%endfor%}
-
-
-
-
Payload:
-
-
-
{
-
"data": [{%for i in range(0, len_inputs)%} {{ input_types[i]
- }}{% if i != len_inputs - 1 %} ,{% endif %}{%endfor%} ]
-
}
-
-
-
- {% if auth is not none %}
-
Note: This interface requires authentication. This means you will have to first post to the login api before you can post to the predict endpoint. See below for more info
- {% endif %}
-
Response:
-
-
-
{
-
"data": [{%for i in range(0, len_outputs)%} {{ output_types[i]
- }}{% if i != len_outputs - 1 %} ,{% endif %}{%endfor%} ],
-
"durations": [ float ], # the time taken for the prediction to complete
-
"avg_durations": [ float ] # the average time taken for all predictions so far (used to estimate the runtime)