mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Add gr.on
listener method (#5639)
* changes * changes * changes * changes * changes * add changeset * changes * changes * changes * changes * changes * changes --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
6a36c3b786
commit
e1874aff81
6
.changeset/calm-mangos-send.md
Normal file
6
.changeset/calm-mangos-send.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"@gradio/app": minor
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Add `gr.on` listener method
|
1
demo/on_listener_basic/run.ipynb
Normal file
1
demo/on_listener_basic/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: on_listener_basic"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " greet_btn = gr.Button(\"Greet\")\n", "\n", " def greet(name):\n", " return \"Hello \" + name + \"!\"\n", "\n", " gr.on(\n", " triggers=[name.submit, greet_btn.click],\n", " fn=greet,\n", " inputs=name,\n", " outputs=output,\n", " )\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
20
demo/on_listener_basic/run.py
Normal file
20
demo/on_listener_basic/run.py
Normal file
@ -0,0 +1,20 @@
|
||||
import gradio as gr
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
name = gr.Textbox(label="Name")
|
||||
output = gr.Textbox(label="Output Box")
|
||||
greet_btn = gr.Button("Greet")
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!"
|
||||
|
||||
gr.on(
|
||||
triggers=[name.submit, greet_btn.click],
|
||||
fn=greet,
|
||||
inputs=name,
|
||||
outputs=output,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
1
demo/on_listener_decorator/run.ipynb
Normal file
1
demo/on_listener_decorator/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: on_listener_decorator"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " greet_btn = gr.Button(\"Greet\")\n", "\n", " @gr.on(triggers=[name.submit, greet_btn.click], inputs=name, outputs=output)\n", " def greet(name):\n", " return \"Hello \" + name + \"!\"\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
14
demo/on_listener_decorator/run.py
Normal file
14
demo/on_listener_decorator/run.py
Normal file
@ -0,0 +1,14 @@
|
||||
import gradio as gr
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
name = gr.Textbox(label="Name")
|
||||
output = gr.Textbox(label="Output Box")
|
||||
greet_btn = gr.Button("Greet")
|
||||
|
||||
@gr.on(triggers=[name.submit, greet_btn.click], inputs=name, outputs=output)
|
||||
def greet(name):
|
||||
return "Hello " + name + "!"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
1
demo/on_listener_live/run.ipynb
Normal file
1
demo/on_listener_live/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: on_listener_live"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " num1 = gr.Slider(1, 10)\n", " num2 = gr.Slider(1, 10)\n", " num3 = gr.Slider(1, 10)\n", " output = gr.Number(label=\"Sum\")\n", "\n", " @gr.on(inputs=[num1, num2, num3], outputs=output)\n", " def sum(a, b, c):\n", " return a + b + c\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
16
demo/on_listener_live/run.py
Normal file
16
demo/on_listener_live/run.py
Normal file
@ -0,0 +1,16 @@
|
||||
import gradio as gr
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
num1 = gr.Slider(1, 10)
|
||||
num2 = gr.Slider(1, 10)
|
||||
num3 = gr.Slider(1, 10)
|
||||
output = gr.Number(label="Sum")
|
||||
|
||||
@gr.on(inputs=[num1, num2, num3], outputs=output)
|
||||
def sum(a, b, c):
|
||||
return a + b + c
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stable-diffusion\n", "### Note: This is a simplified version of the code needed to create the Stable Diffusion demo. See full code here: https://hf.co/spaces/stabilityai/stable-diffusion/tree/main\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio diffusers transformers nvidia-ml-py3 ftfy torch"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import torch\n", "from diffusers import StableDiffusionPipeline\n", "from PIL import Image\n", "import os\n", "\n", "auth_token = os.getenv(\"auth_token\")\n", "model_id = \"CompVis/stable-diffusion-v1-4\"\n", "device = \"cpu\"\n", "pipe = StableDiffusionPipeline.from_pretrained(\n", " model_id, use_auth_token=auth_token, revision=\"fp16\", torch_dtype=torch.float16\n", ")\n", "pipe = pipe.to(device)\n", "\n", "\n", "def infer(prompt, samples, steps, scale, seed):\n", " generator = torch.Generator(device=device).manual_seed(seed)\n", " images_list = pipe(\n", " [prompt] * samples,\n", " num_inference_steps=steps,\n", " guidance_scale=scale,\n", " generator=generator,\n", " )\n", " images = []\n", " safe_image = Image.open(r\"unsafe.png\")\n", " for i, image in enumerate(images_list[\"sample\"]):\n", " if images_list[\"nsfw_content_detected\"][i]:\n", " images.append(safe_image)\n", " else:\n", " images.append(image)\n", " return images\n", "\n", "\n", "block = gr.Blocks()\n", "\n", "with block:\n", " with gr.Group():\n", " with gr.Row():\n", " text = gr.Textbox(\n", " label=\"Enter your prompt\",\n", " max_lines=1,\n", " placeholder=\"Enter your prompt\",\n", " container=False,\n", " )\n", " btn = gr.Button(\"Generate image\")\n", " gallery = gr.Gallery(\n", " label=\"Generated images\",\n", " show_label=False,\n", " elem_id=\"gallery\",\n", " columns=[2],\n", " height=\"auto\",\n", " )\n", "\n", " advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n", "\n", " with gr.Row(elem_id=\"advanced-options\"):\n", " samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=4, step=1)\n", " steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=45, step=1)\n", " scale = gr.Slider(\n", " label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n", " )\n", " seed = gr.Slider(\n", " label=\"Seed\",\n", " minimum=0,\n", " maximum=2147483647,\n", " step=1,\n", " randomize=True,\n", " )\n", " text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)\n", " btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)\n", " advanced_button.click(\n", " None,\n", " [],\n", " text,\n", " )\n", "\n", "block.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stable-diffusion\n", "### Note: This is a simplified version of the code needed to create the Stable Diffusion demo. See full code here: https://hf.co/spaces/stabilityai/stable-diffusion/tree/main\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio diffusers transformers nvidia-ml-py3 ftfy torch"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import torch\n", "from diffusers import StableDiffusionPipeline\n", "from PIL import Image\n", "import os\n", "\n", "auth_token = os.getenv(\"auth_token\")\n", "model_id = \"CompVis/stable-diffusion-v1-4\"\n", "device = \"cpu\"\n", "pipe = StableDiffusionPipeline.from_pretrained(\n", " model_id, use_auth_token=auth_token, revision=\"fp16\", torch_dtype=torch.float16\n", ")\n", "pipe = pipe.to(device)\n", "\n", "\n", "def infer(prompt, samples, steps, scale, seed):\n", " generator = torch.Generator(device=device).manual_seed(seed)\n", " images_list = pipe(\n", " [prompt] * samples,\n", " num_inference_steps=steps,\n", " guidance_scale=scale,\n", " generator=generator,\n", " )\n", " images = []\n", " safe_image = Image.open(r\"unsafe.png\")\n", " for i, image in enumerate(images_list[\"sample\"]):\n", " if images_list[\"nsfw_content_detected\"][i]:\n", " images.append(safe_image)\n", " else:\n", " images.append(image)\n", " return images\n", "\n", "\n", "block = gr.Blocks()\n", "\n", "with block:\n", " with gr.Group():\n", " with gr.Row():\n", " text = gr.Textbox(\n", " label=\"Enter your prompt\",\n", " max_lines=1,\n", " placeholder=\"Enter your prompt\",\n", " container=False,\n", " )\n", " btn = gr.Button(\"Generate image\")\n", " gallery = gr.Gallery(\n", " label=\"Generated images\",\n", " show_label=False,\n", " elem_id=\"gallery\",\n", " columns=[2],\n", " height=\"auto\",\n", " )\n", "\n", " advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n", "\n", " with gr.Row(elem_id=\"advanced-options\"):\n", " samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=4, step=1)\n", " steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=45, step=1)\n", " scale = gr.Slider(\n", " label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n", " )\n", " seed = gr.Slider(\n", " label=\"Seed\",\n", " minimum=0,\n", " maximum=2147483647,\n", " step=1,\n", " randomize=True,\n", " )\n", " gr.on([text.submit, btn.click], infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)\n", " advanced_button.click(\n", " None,\n", " [],\n", " text,\n", " )\n", "\n", "block.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -66,8 +66,7 @@ with block:
|
||||
step=1,
|
||||
randomize=True,
|
||||
)
|
||||
text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
|
||||
btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
|
||||
gr.on([text.submit, btn.click], infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
|
||||
advanced_button.click(
|
||||
None,
|
||||
[],
|
||||
|
@ -60,7 +60,7 @@ from gradio.components import (
|
||||
component,
|
||||
)
|
||||
from gradio.deploy_space import deploy
|
||||
from gradio.events import LikeData, SelectData
|
||||
from gradio.events import LikeData, SelectData, on
|
||||
from gradio.exceptions import Error
|
||||
from gradio.external import load
|
||||
from gradio.flagging import (
|
||||
|
@ -165,7 +165,14 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
if not analytics_enabled():
|
||||
return
|
||||
|
||||
blocks_telemetry, inputs_telemetry, outputs_telemetry, targets_telemetry = (
|
||||
(
|
||||
blocks_telemetry,
|
||||
inputs_telemetry,
|
||||
outputs_telemetry,
|
||||
targets_telemetry,
|
||||
events_telemetry,
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
@ -182,9 +189,12 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
for x in blocks.dependencies:
|
||||
targets_telemetry = targets_telemetry + [
|
||||
# Sometimes the target can be the Blocks object itself, so we need to check if its in blocks.blocks
|
||||
str(blocks.blocks[y])
|
||||
str(blocks.blocks[y[0]])
|
||||
for y in x["targets"]
|
||||
if y in blocks.blocks
|
||||
if y[0] in blocks.blocks
|
||||
]
|
||||
events_telemetry = events_telemetry + [
|
||||
y[1] for y in x["targets"] if y[0] in blocks.blocks
|
||||
]
|
||||
inputs_telemetry = inputs_telemetry + [
|
||||
str(blocks.blocks[y]) for y in x["inputs"] if y in blocks.blocks
|
||||
@ -209,7 +219,7 @@ def launched_analytics(blocks: gradio.Blocks, data: dict[str, Any]) -> None:
|
||||
else outputs_telemetry,
|
||||
"targets": targets_telemetry,
|
||||
"blocks": blocks_telemetry,
|
||||
"events": [str(x["trigger"]) for x in blocks.dependencies],
|
||||
"events": events_telemetry,
|
||||
"is_wasm": wasm_utils.IS_WASM,
|
||||
}
|
||||
|
||||
|
322
gradio/blocks.py
322
gradio/blocks.py
@ -16,7 +16,7 @@ from collections import defaultdict
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Literal, Sequence, cast
|
||||
|
||||
import anyio
|
||||
import requests
|
||||
@ -77,6 +77,7 @@ if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from fastapi.applications import FastAPI
|
||||
|
||||
from gradio.components import Component
|
||||
from gradio.events import EventListenerMethod
|
||||
|
||||
BUILT_IN_THEMES: dict[str, Theme] = {
|
||||
t.name: t
|
||||
@ -209,142 +210,6 @@ class Block:
|
||||
def get_expected_parent(self) -> type[BlockContext] | None:
|
||||
return None
|
||||
|
||||
def set_event_trigger(
|
||||
self,
|
||||
event_name: str,
|
||||
fn: Callable | None,
|
||||
inputs: Component | list[Component] | set[Component] | None,
|
||||
outputs: Component | list[Component] | None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
scroll_to_output: bool = False,
|
||||
show_progress: str = "full",
|
||||
api_name: str | None | Literal[False] = None,
|
||||
js: str | None = None,
|
||||
no_target: bool = False,
|
||||
queue: bool | None = None,
|
||||
batch: bool = False,
|
||||
max_batch_size: int = 4,
|
||||
cancels: list[int] | None = None,
|
||||
every: float | None = None,
|
||||
collects_event_data: bool | None = None,
|
||||
trigger_after: int | None = None,
|
||||
trigger_only_on_success: bool = False,
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
"""
|
||||
Adds an event to the component's dependencies.
|
||||
Parameters:
|
||||
event_name: event name
|
||||
fn: Callable function
|
||||
inputs: input list
|
||||
outputs: output list
|
||||
preprocess: whether to run the preprocess methods of components
|
||||
postprocess: whether to run the postprocess methods of components
|
||||
scroll_to_output: whether to scroll to output of dependency on trigger
|
||||
show_progress: whether to show progress animation while running.
|
||||
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
js: Experimental parameter (API may change): Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
|
||||
no_target: if True, sets "targets" to [], used for Blocks "load" event
|
||||
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
|
||||
batch: whether this function takes in a batch of inputs
|
||||
max_batch_size: the maximum batch size to send to the function
|
||||
cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
|
||||
collects_event_data: whether to collect event data for this event
|
||||
trigger_after: if set, this event will be triggered after 'trigger_after' function index
|
||||
trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set)
|
||||
Returns: dependency information, dependency index
|
||||
"""
|
||||
# Support for singular parameter
|
||||
if isinstance(inputs, set):
|
||||
inputs_as_dict = True
|
||||
inputs = sorted(inputs, key=lambda x: x._id)
|
||||
else:
|
||||
inputs_as_dict = False
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
elif not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
|
||||
if isinstance(outputs, set):
|
||||
outputs = sorted(outputs, key=lambda x: x._id)
|
||||
else:
|
||||
if outputs is None:
|
||||
outputs = []
|
||||
elif not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
if fn is not None and not cancels:
|
||||
check_function_inputs_match(fn, inputs, inputs_as_dict)
|
||||
|
||||
if Context.root_block is None:
|
||||
raise AttributeError(
|
||||
f"{event_name}() and other events can only be called within a Blocks context."
|
||||
)
|
||||
if every is not None and every <= 0:
|
||||
raise ValueError("Parameter every must be positive or None")
|
||||
if every and batch:
|
||||
raise ValueError(
|
||||
f"Cannot run {event_name} event in a batch and every {every} seconds. "
|
||||
"Either batch is True or every is non-zero but not both."
|
||||
)
|
||||
|
||||
if every and fn:
|
||||
fn = get_continuous_fn(fn, every)
|
||||
elif every:
|
||||
raise ValueError("Cannot set a value for `every` without a `fn`.")
|
||||
|
||||
_, progress_index, event_data_index = (
|
||||
special_args(fn) if fn else (None, None, None)
|
||||
)
|
||||
Context.root_block.fns.append(
|
||||
BlockFunction(
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
preprocess,
|
||||
postprocess,
|
||||
inputs_as_dict,
|
||||
progress_index is not None,
|
||||
)
|
||||
)
|
||||
if api_name is not None and api_name is not False:
|
||||
api_name_ = utils.append_unique_suffix(
|
||||
api_name, [dep["api_name"] for dep in Context.root_block.dependencies]
|
||||
)
|
||||
if api_name != api_name_:
|
||||
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
|
||||
api_name = api_name_
|
||||
|
||||
if collects_event_data is None:
|
||||
collects_event_data = event_data_index is not None
|
||||
|
||||
dependency = {
|
||||
"targets": [self._id] if not no_target else [],
|
||||
"trigger": event_name,
|
||||
"inputs": [block._id for block in inputs],
|
||||
"outputs": [block._id for block in outputs],
|
||||
"backend_fn": fn is not None,
|
||||
"js": js,
|
||||
"queue": False if fn is None else queue,
|
||||
"api_name": api_name,
|
||||
"scroll_to_output": False if utils.get_space() else scroll_to_output,
|
||||
"show_progress": show_progress,
|
||||
"every": every,
|
||||
"batch": batch,
|
||||
"max_batch_size": max_batch_size,
|
||||
"cancels": cancels or [],
|
||||
"types": {
|
||||
"continuous": bool(every),
|
||||
"generator": inspect.isgeneratorfunction(fn) or bool(every),
|
||||
},
|
||||
"collects_event_data": collects_event_data,
|
||||
"trigger_after": trigger_after,
|
||||
"trigger_only_on_success": trigger_only_on_success,
|
||||
}
|
||||
Context.root_block.dependencies.append(dependency)
|
||||
return dependency, len(Context.root_block.dependencies) - 1
|
||||
|
||||
def get_config(self):
|
||||
config = {}
|
||||
signature = inspect.signature(self.__class__.__init__)
|
||||
@ -908,12 +773,24 @@ class Blocks(BlockContext):
|
||||
# We fixed the issue by removing "fake_event" from the config in examples.py
|
||||
# but we still need to skip these events when loading the config to support
|
||||
# older demos
|
||||
if dependency["trigger"] == "fake_event":
|
||||
if "trigger" in dependency and dependency["trigger"] == "fake_event":
|
||||
continue
|
||||
for field in derived_fields:
|
||||
dependency.pop(field, None)
|
||||
targets = dependency.pop("targets")
|
||||
trigger = dependency.pop("trigger")
|
||||
|
||||
# older versions had a separate trigger field, but now it is part of the
|
||||
# targets field
|
||||
_targets = dependency.pop("targets")
|
||||
trigger = dependency.pop("trigger", None)
|
||||
targets = [
|
||||
getattr(
|
||||
original_mapping[
|
||||
target if isinstance(target, int) else target[0]
|
||||
],
|
||||
trigger if isinstance(target, int) else target[1],
|
||||
)
|
||||
for target in _targets
|
||||
]
|
||||
dependency.pop("backend_fn")
|
||||
dependency.pop("documentation", None)
|
||||
dependency["inputs"] = [
|
||||
@ -926,12 +803,11 @@ class Blocks(BlockContext):
|
||||
dependency["preprocess"] = False
|
||||
dependency["postprocess"] = False
|
||||
|
||||
for target in targets:
|
||||
dependency = original_mapping[target].set_event_trigger(
|
||||
event_name=trigger, fn=fn, **dependency
|
||||
)[0]
|
||||
if first_dependency is None:
|
||||
first_dependency = dependency
|
||||
dependency = blocks.set_event_trigger(
|
||||
targets=targets, fn=fn, **dependency
|
||||
)[0]
|
||||
if first_dependency is None:
|
||||
first_dependency = dependency
|
||||
|
||||
# Allows some use of Interface-specific methods with loaded Spaces
|
||||
if first_dependency and Context.root_block:
|
||||
@ -976,6 +852,143 @@ class Blocks(BlockContext):
|
||||
for block in self.blocks.values()
|
||||
)
|
||||
|
||||
def set_event_trigger(
|
||||
self,
|
||||
targets: Sequence[EventListenerMethod],
|
||||
fn: Callable | None,
|
||||
inputs: Component | list[Component] | set[Component] | None,
|
||||
outputs: Component | list[Component] | None,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
scroll_to_output: bool = False,
|
||||
show_progress: str = "full",
|
||||
api_name: str | None | Literal[False] = None,
|
||||
js: str | None = None,
|
||||
no_target: bool = False,
|
||||
queue: bool | None = None,
|
||||
batch: bool = False,
|
||||
max_batch_size: int = 4,
|
||||
cancels: list[int] | None = None,
|
||||
every: float | None = None,
|
||||
collects_event_data: bool | None = None,
|
||||
trigger_after: int | None = None,
|
||||
trigger_only_on_success: bool = False,
|
||||
) -> tuple[dict[str, Any], int]:
|
||||
"""
|
||||
Adds an event to the component's dependencies.
|
||||
Parameters:
|
||||
event_name: event name
|
||||
fn: Callable function
|
||||
inputs: input list
|
||||
outputs: output list
|
||||
preprocess: whether to run the preprocess methods of components
|
||||
postprocess: whether to run the postprocess methods of components
|
||||
scroll_to_output: whether to scroll to output of dependency on trigger
|
||||
show_progress: whether to show progress animation while running.
|
||||
api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
js: Experimental parameter (API may change): Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components
|
||||
no_target: if True, sets "targets" to [], used for Blocks "load" event
|
||||
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
|
||||
batch: whether this function takes in a batch of inputs
|
||||
max_batch_size: the maximum batch size to send to the function
|
||||
cancels: a list of other events to cancel when this event is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method.
|
||||
every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
|
||||
collects_event_data: whether to collect event data for this event
|
||||
trigger_after: if set, this event will be triggered after 'trigger_after' function index
|
||||
trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set)
|
||||
Returns: dependency information, dependency index
|
||||
"""
|
||||
# Support for singular parameter
|
||||
_targets = [
|
||||
(
|
||||
target.trigger._id if target.trigger and not no_target else None,
|
||||
target.event_name,
|
||||
)
|
||||
for target in targets
|
||||
]
|
||||
if isinstance(inputs, set):
|
||||
inputs_as_dict = True
|
||||
inputs = sorted(inputs, key=lambda x: x._id)
|
||||
else:
|
||||
inputs_as_dict = False
|
||||
if inputs is None:
|
||||
inputs = []
|
||||
elif not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
|
||||
if isinstance(outputs, set):
|
||||
outputs = sorted(outputs, key=lambda x: x._id)
|
||||
else:
|
||||
if outputs is None:
|
||||
outputs = []
|
||||
elif not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
|
||||
if fn is not None and not cancels:
|
||||
check_function_inputs_match(fn, inputs, inputs_as_dict)
|
||||
if every is not None and every <= 0:
|
||||
raise ValueError("Parameter every must be positive or None")
|
||||
if every and batch:
|
||||
raise ValueError(
|
||||
f"Cannot run event in a batch and every {every} seconds. "
|
||||
"Either batch is True or every is non-zero but not both."
|
||||
)
|
||||
|
||||
if every and fn:
|
||||
fn = get_continuous_fn(fn, every)
|
||||
elif every:
|
||||
raise ValueError("Cannot set a value for `every` without a `fn`.")
|
||||
|
||||
_, progress_index, event_data_index = (
|
||||
special_args(fn) if fn else (None, None, None)
|
||||
)
|
||||
self.fns.append(
|
||||
BlockFunction(
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
preprocess,
|
||||
postprocess,
|
||||
inputs_as_dict,
|
||||
progress_index is not None,
|
||||
)
|
||||
)
|
||||
if api_name is not None and api_name is not False:
|
||||
api_name_ = utils.append_unique_suffix(
|
||||
api_name, [dep["api_name"] for dep in self.dependencies]
|
||||
)
|
||||
if api_name != api_name_:
|
||||
warnings.warn(f"api_name {api_name} already exists, using {api_name_}")
|
||||
api_name = api_name_
|
||||
|
||||
if collects_event_data is None:
|
||||
collects_event_data = event_data_index is not None
|
||||
|
||||
dependency = {
|
||||
"targets": _targets,
|
||||
"inputs": [block._id for block in inputs],
|
||||
"outputs": [block._id for block in outputs],
|
||||
"backend_fn": fn is not None,
|
||||
"js": js,
|
||||
"queue": False if fn is None else queue,
|
||||
"api_name": api_name,
|
||||
"scroll_to_output": False if utils.get_space() else scroll_to_output,
|
||||
"show_progress": show_progress,
|
||||
"every": every,
|
||||
"batch": batch,
|
||||
"max_batch_size": max_batch_size,
|
||||
"cancels": cancels or [],
|
||||
"types": {
|
||||
"continuous": bool(every),
|
||||
"generator": inspect.isgeneratorfunction(fn) or bool(every),
|
||||
},
|
||||
"collects_event_data": collects_event_data,
|
||||
"trigger_after": trigger_after,
|
||||
"trigger_only_on_success": trigger_only_on_success,
|
||||
}
|
||||
self.dependencies.append(dependency)
|
||||
return dependency, len(self.dependencies) - 1
|
||||
|
||||
def render(self):
|
||||
if Context.root_block is not None:
|
||||
if self._id in Context.root_block.blocks:
|
||||
@ -1703,10 +1716,15 @@ Received outputs:
|
||||
name=name, src=src, hf_token=api_key, alias=alias, **kwargs
|
||||
)
|
||||
else:
|
||||
from gradio.events import Dependency
|
||||
from gradio.events import Dependency, EventListenerMethod
|
||||
|
||||
dep, dep_index = self.set_event_trigger(
|
||||
event_name="load",
|
||||
if Context.root_block is None:
|
||||
raise AttributeError(
|
||||
"Cannot call load() outside of a gradio.Blocks context."
|
||||
)
|
||||
|
||||
dep, dep_index = Context.root_block.set_event_trigger(
|
||||
targets=[EventListenerMethod(self, "load")],
|
||||
fn=fn,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
@ -1722,7 +1740,7 @@ Received outputs:
|
||||
every=every,
|
||||
no_target=True,
|
||||
)
|
||||
return Dependency(self, dep, dep_index, fn)
|
||||
return Dependency(dep, dep_index, fn)
|
||||
|
||||
def clear(self):
|
||||
"""Resets the layout of the Blocks object."""
|
||||
@ -2325,8 +2343,10 @@ Received outputs:
|
||||
):
|
||||
load_fn, every = component.load_event_to_attach
|
||||
# Use set_event_trigger to avoid ambiguity between load class/instance method
|
||||
from gradio.events import EventListenerMethod
|
||||
|
||||
dep = self.set_event_trigger(
|
||||
"load",
|
||||
[EventListenerMethod(self, "load")],
|
||||
load_fn,
|
||||
None,
|
||||
component,
|
||||
|
@ -22,7 +22,7 @@ from gradio.components import (
|
||||
Textbox,
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.events import Dependency, EventListenerMethod
|
||||
from gradio.events import Dependency, EventListenerMethod, on
|
||||
from gradio.helpers import create_examples as Examples # noqa: N812
|
||||
from gradio.layouts import Accordion, Column, Group, Row
|
||||
from gradio.themes import ThemeClass as Theme
|
||||
@ -245,8 +245,14 @@ class ChatInterface(Blocks):
|
||||
|
||||
def _setup_events(self) -> None:
|
||||
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
|
||||
submit_triggers = (
|
||||
[self.textbox.submit, self.submit_btn.click]
|
||||
if self.submit_btn
|
||||
else [self.textbox.submit]
|
||||
)
|
||||
submit_event = (
|
||||
self.textbox.submit(
|
||||
on(
|
||||
submit_triggers,
|
||||
self._clear_and_save_textbox,
|
||||
[self.textbox],
|
||||
[self.textbox, self.saved_input],
|
||||
@ -267,32 +273,7 @@ class ChatInterface(Blocks):
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events(self.textbox.submit, submit_event)
|
||||
|
||||
if self.submit_btn:
|
||||
click_event = (
|
||||
self.submit_btn.click(
|
||||
self._clear_and_save_textbox,
|
||||
[self.textbox],
|
||||
[self.textbox, self.saved_input],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
self._display_input,
|
||||
[self.saved_input, self.chatbot_state],
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
.then(
|
||||
submit_fn,
|
||||
[self.saved_input, self.chatbot_state] + self.additional_inputs,
|
||||
[self.chatbot, self.chatbot_state],
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events(self.submit_btn.click, click_event)
|
||||
self._setup_stop_events(submit_triggers, submit_event)
|
||||
|
||||
if self.retry_btn:
|
||||
retry_event = (
|
||||
@ -317,7 +298,7 @@ class ChatInterface(Blocks):
|
||||
api_name=False,
|
||||
)
|
||||
)
|
||||
self._setup_stop_events(self.retry_btn.click, retry_event)
|
||||
self._setup_stop_events([self.retry_btn.click], retry_event)
|
||||
|
||||
if self.undo_btn:
|
||||
self.undo_btn.click(
|
||||
@ -344,17 +325,21 @@ class ChatInterface(Blocks):
|
||||
)
|
||||
|
||||
def _setup_stop_events(
|
||||
self, event_trigger: EventListenerMethod, event_to_cancel: Dependency
|
||||
self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
|
||||
) -> None:
|
||||
if self.stop_btn and self.is_generator:
|
||||
if self.submit_btn:
|
||||
event_trigger(
|
||||
lambda: (Button.update(visible=False), Button.update(visible=True)),
|
||||
None,
|
||||
[self.submit_btn, self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
for event_trigger in event_triggers:
|
||||
event_trigger(
|
||||
lambda: (
|
||||
Button.update(visible=False),
|
||||
Button.update(visible=True),
|
||||
),
|
||||
None,
|
||||
[self.submit_btn, self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
event_to_cancel.then(
|
||||
lambda: (Button.update(visible=True), Button.update(visible=False)),
|
||||
None,
|
||||
@ -363,13 +348,14 @@ class ChatInterface(Blocks):
|
||||
queue=False,
|
||||
)
|
||||
else:
|
||||
event_trigger(
|
||||
lambda: Button.update(visible=True),
|
||||
None,
|
||||
[self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
for event_trigger in event_triggers:
|
||||
event_trigger(
|
||||
lambda: Button.update(visible=True),
|
||||
None,
|
||||
[self.stop_btn],
|
||||
api_name=False,
|
||||
queue=False,
|
||||
)
|
||||
event_to_cancel.then(
|
||||
lambda: Button.update(visible=False),
|
||||
None,
|
||||
|
134
gradio/events.py
134
gradio/events.py
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
|
||||
from gradio.blocks import Block
|
||||
from gradio.context import Context
|
||||
from gradio.deprecation import warn_deprecation
|
||||
from gradio.helpers import EventData
|
||||
from gradio.utils import get_cancel_function
|
||||
@ -20,14 +21,21 @@ set_documentation_group("events")
|
||||
|
||||
|
||||
def set_cancel_events(
|
||||
block: Block, event_name: str, cancels: None | dict[str, Any] | list[dict[str, Any]]
|
||||
triggers: Sequence[EventListenerMethod],
|
||||
cancels: None | dict[str, Any] | list[dict[str, Any]],
|
||||
):
|
||||
if cancels:
|
||||
if not isinstance(cancels, list):
|
||||
cancels = [cancels]
|
||||
cancel_fn, fn_indices_to_cancel = get_cancel_function(cancels)
|
||||
block.set_event_trigger(
|
||||
event_name,
|
||||
|
||||
if Context.root_block is None:
|
||||
raise AttributeError(
|
||||
"Cannot cancel {self.event_name} outside of a gradio.Blocks context."
|
||||
)
|
||||
|
||||
Context.root_block.set_event_trigger(
|
||||
triggers,
|
||||
cancel_fn,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
@ -45,12 +53,11 @@ class EventListener(Block):
|
||||
|
||||
|
||||
class Dependency(dict):
|
||||
def __init__(self, trigger, key_vals, dep_index, fn):
|
||||
def __init__(self, key_vals, dep_index, fn):
|
||||
super().__init__(key_vals)
|
||||
self.fn = fn
|
||||
self.trigger = trigger
|
||||
self.then = EventListenerMethod(
|
||||
self.trigger,
|
||||
None,
|
||||
"then",
|
||||
trigger_after=dep_index,
|
||||
trigger_only_on_success=False,
|
||||
@ -59,7 +66,7 @@ class Dependency(dict):
|
||||
Triggered after directly preceding event is completed, regardless of success or failure.
|
||||
"""
|
||||
self.success = EventListenerMethod(
|
||||
self.trigger,
|
||||
None,
|
||||
"success",
|
||||
trigger_after=dep_index,
|
||||
trigger_only_on_success=True,
|
||||
@ -79,7 +86,7 @@ class EventListenerMethod:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trigger: Block,
|
||||
trigger: Block | None,
|
||||
event_name: str,
|
||||
show_progress: Literal["full", "minimal", "hidden"] = "full",
|
||||
callback: Callable | None = None,
|
||||
@ -155,7 +162,7 @@ class EventListenerMethod:
|
||||
|
||||
return inner
|
||||
|
||||
return Dependency(None, {}, None, wrapper)
|
||||
return Dependency({}, None, wrapper)
|
||||
|
||||
if status_tracker:
|
||||
warn_deprecation(
|
||||
@ -171,8 +178,13 @@ class EventListenerMethod:
|
||||
if isinstance(show_progress, bool):
|
||||
show_progress = "full" if show_progress else "hidden"
|
||||
|
||||
dep, dep_index = self.trigger.set_event_trigger(
|
||||
self.event_name,
|
||||
if Context.root_block is None:
|
||||
raise AttributeError(
|
||||
"Cannot call {self.event_name} outside of a gradio.Blocks context."
|
||||
)
|
||||
|
||||
dep, dep_index = Context.root_block.set_event_trigger(
|
||||
[self],
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
@ -191,10 +203,106 @@ class EventListenerMethod:
|
||||
trigger_after=self.trigger_after,
|
||||
trigger_only_on_success=self.trigger_only_on_success,
|
||||
)
|
||||
set_cancel_events(self.trigger, self.event_name, cancels)
|
||||
set_cancel_events([self], cancels)
|
||||
if self.callback:
|
||||
self.callback()
|
||||
return Dependency(self.trigger, dep, dep_index, fn)
|
||||
return Dependency(dep, dep_index, fn)
|
||||
|
||||
|
||||
def on(
|
||||
triggers: Sequence[EventListenerMethod] | EventListenerMethod | None = None,
|
||||
fn: Callable | None | Literal["decorator"] = "decorator",
|
||||
inputs: Component | list[Component] | set[Component] | None = None,
|
||||
outputs: Component | list[Component] | None = None,
|
||||
*,
|
||||
api_name: str | None | Literal[False] = None,
|
||||
scroll_to_output: bool = False,
|
||||
show_progress: Literal["full", "minimal", "hidden"] = "full",
|
||||
queue: bool | None = None,
|
||||
batch: bool = False,
|
||||
max_batch_size: int = 4,
|
||||
preprocess: bool = True,
|
||||
postprocess: bool = True,
|
||||
cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
|
||||
every: float | None = None,
|
||||
_js: str | None = None,
|
||||
) -> Dependency:
|
||||
"""
|
||||
Parameters:
|
||||
triggers: List of triggers to listen to, e.g. [btn.click, number.change]. If None, will listen to changes to any inputs.
|
||||
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
|
||||
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
|
||||
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
|
||||
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
|
||||
scroll_to_output: If True, will scroll to output component on completion
|
||||
show_progress: If True, will show progress animation while pending
|
||||
queue: If True, will place the request on the queue, if the queue has been enabled. If False, will not put this event on the queue, even if the queue has been enabled. If None, will use the queue setting of the gradio app.
|
||||
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. The lists should be of equal length (and be up to length `max_batch_size`). The function is then *required* to return a tuple of lists (even if there is only 1 output component), with each list in the tuple corresponding to one output component.
|
||||
max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
|
||||
preprocess: If False, will not run preprocessing of component data before running 'fn' (e.g. leaving it as a base64 string if this method is called with the `Image` component).
|
||||
postprocess: If False, will not run postprocessing of component data before returning 'fn' output to the browser.
|
||||
cancels: A list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish.
|
||||
every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
|
||||
"""
|
||||
from gradio.components.base import Component
|
||||
|
||||
if isinstance(triggers, EventListenerMethod):
|
||||
triggers = [triggers]
|
||||
if isinstance(inputs, Component):
|
||||
inputs = [inputs]
|
||||
|
||||
if fn == "decorator":
|
||||
|
||||
def wrapper(func):
|
||||
on(
|
||||
triggers,
|
||||
fn=func,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
api_name=api_name,
|
||||
scroll_to_output=scroll_to_output,
|
||||
show_progress=show_progress,
|
||||
queue=queue,
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
preprocess=preprocess,
|
||||
postprocess=postprocess,
|
||||
cancels=cancels,
|
||||
every=every,
|
||||
_js=_js,
|
||||
)
|
||||
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
return Dependency({}, None, wrapper)
|
||||
|
||||
if Context.root_block is None:
|
||||
raise Exception("Cannot call on() outside of a gradio.Blocks context.")
|
||||
if triggers is None:
|
||||
triggers = [input.change for input in inputs] if inputs is not None else []
|
||||
|
||||
dep, dep_index = Context.root_block.set_event_trigger(
|
||||
triggers,
|
||||
fn,
|
||||
inputs,
|
||||
outputs,
|
||||
preprocess=preprocess,
|
||||
postprocess=postprocess,
|
||||
scroll_to_output=scroll_to_output,
|
||||
show_progress=show_progress,
|
||||
api_name=api_name,
|
||||
js=_js,
|
||||
queue=queue,
|
||||
batch=batch,
|
||||
max_batch_size=max_batch_size,
|
||||
every=every,
|
||||
)
|
||||
set_cancel_events(triggers, cancels)
|
||||
return Dependency(dep, dep_index, fn)
|
||||
|
||||
|
||||
@document("*change", inherit=True)
|
||||
|
@ -315,8 +315,10 @@ class Examples:
|
||||
fn = self.fn
|
||||
|
||||
# create a fake dependency to process the examples and get the predictions
|
||||
from gradio.events import EventListenerMethod
|
||||
|
||||
dependency, fn_index = Context.root_block.set_event_trigger(
|
||||
event_name="fake_event",
|
||||
[EventListenerMethod(Context.root_block, "load")],
|
||||
fn=fn,
|
||||
inputs=self.inputs_with_examples, # type: ignore
|
||||
outputs=self.outputs, # type: ignore
|
||||
|
@ -27,7 +27,7 @@ from gradio.components import (
|
||||
)
|
||||
from gradio.data_classes import InterfaceTypes
|
||||
from gradio.deprecation import warn_deprecation
|
||||
from gradio.events import Changeable, Streamable, Submittable
|
||||
from gradio.events import Changeable, Streamable, Submittable, on
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
|
||||
from gradio.layouts import Column, Row, Tab, Tabs
|
||||
from gradio.pipelines import load_from_pipeline
|
||||
@ -38,6 +38,8 @@ set_documentation_group("interface")
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from transformers.pipelines.base import Pipeline
|
||||
|
||||
from gradio.events import EventListenerMethod
|
||||
|
||||
|
||||
@document("launch", "load", "from_pipeline", "integrate", "queue")
|
||||
class Interface(Blocks):
|
||||
@ -624,26 +626,21 @@ class Interface(Blocks):
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
else:
|
||||
events: list[EventListenerMethod] = []
|
||||
for component in self.input_components:
|
||||
if isinstance(component, Streamable) and component.streaming:
|
||||
component.stream(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name=self.api_name,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
continue
|
||||
if isinstance(component, Changeable):
|
||||
component.change(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name=self.api_name,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
events.append(component.stream)
|
||||
elif isinstance(component, Changeable):
|
||||
events.append(component.change)
|
||||
on(
|
||||
events,
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name=self.api_name,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
)
|
||||
else:
|
||||
assert submit_btn is not None, "Submit button not rendered"
|
||||
fn = self.fn
|
||||
@ -654,7 +651,6 @@ class Interface(Blocks):
|
||||
for component in self.input_components
|
||||
if isinstance(component, Submittable)
|
||||
]
|
||||
predict_events = []
|
||||
|
||||
if stop_btn:
|
||||
extra_output = [submit_btn, stop_btn]
|
||||
@ -662,57 +658,54 @@ class Interface(Blocks):
|
||||
def cleanup():
|
||||
return [Button.update(visible=True), Button.update(visible=False)]
|
||||
|
||||
for i, trigger in enumerate(triggers):
|
||||
predict_event = trigger(
|
||||
lambda: (
|
||||
submit_btn.update(visible=False),
|
||||
stop_btn.update(visible=True),
|
||||
),
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
queue=False,
|
||||
).then(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name=self.api_name if i == 0 else None,
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
predict_events.append(predict_event)
|
||||
predict_event = on(
|
||||
triggers,
|
||||
lambda: (
|
||||
submit_btn.update(visible=False),
|
||||
stop_btn.update(visible=True),
|
||||
),
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
queue=False,
|
||||
).then(
|
||||
self.fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name=self.api_name,
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
|
||||
predict_event.then(
|
||||
cleanup,
|
||||
inputs=None,
|
||||
outputs=extra_output, # type: ignore
|
||||
queue=False,
|
||||
)
|
||||
predict_event.then(
|
||||
cleanup,
|
||||
inputs=None,
|
||||
outputs=extra_output, # type: ignore
|
||||
queue=False,
|
||||
)
|
||||
|
||||
stop_btn.click(
|
||||
cleanup,
|
||||
inputs=None,
|
||||
outputs=[submit_btn, stop_btn],
|
||||
cancels=predict_events,
|
||||
cancels=predict_event,
|
||||
queue=False,
|
||||
)
|
||||
else:
|
||||
for i, trigger in enumerate(triggers):
|
||||
predict_events.append(
|
||||
trigger(
|
||||
fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name=self.api_name if i == 0 else None,
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
)
|
||||
on(
|
||||
triggers,
|
||||
fn,
|
||||
self.input_components,
|
||||
self.output_components,
|
||||
api_name=self.api_name,
|
||||
scroll_to_output=True,
|
||||
preprocess=not (self.api_mode),
|
||||
postprocess=not (self.api_mode),
|
||||
batch=self.batch,
|
||||
max_batch_size=self.max_batch_size,
|
||||
)
|
||||
|
||||
def attach_clear_events(
|
||||
self,
|
||||
|
@ -371,7 +371,7 @@ def assert_configs_are_equivalent_besides_ids(
|
||||
|
||||
for d1, d2 in zip(config1["dependencies"], config2["dependencies"]):
|
||||
for t1, t2 in zip(d1.pop("targets"), d2.pop("targets")):
|
||||
assert_same_components(t1, t2)
|
||||
assert_same_components(t1[0], t2[0])
|
||||
for i1, i2 in zip(d1.pop("inputs"), d2.pop("inputs")):
|
||||
assert_same_components(i1, i2)
|
||||
for o1, o2 in zip(d1.pop("outputs"), d2.pop("outputs")):
|
||||
|
@ -162,3 +162,21 @@ In the 2 player tic-tac-toe demo below, a user can select a cell in the `DataFra
|
||||
|
||||
$code_tictactoe
|
||||
$demo_tictactoe
|
||||
|
||||
## Binding Multiple Triggers to a Function
|
||||
|
||||
Often times, you may want to bind multiple triggers to the same function. For example, you may want to allow a user to click a submit button, or press enter to submit a form. You can do this using the `gr.on` method and passing a list of triggers to the `trigger`.
|
||||
|
||||
$code_on_listener_basic
|
||||
$demo_on_listener_basic
|
||||
|
||||
You can use decorator syntax as well:
|
||||
|
||||
$code_on_listener_decorator
|
||||
|
||||
You can use `gr.on` to create "live" events by binding to the change event of all components. If you do not specify any triggers, the function will automatically bind to the change event of all input components.
|
||||
|
||||
$code_on_listener_live
|
||||
$demo_on_listener_live
|
||||
|
||||
You can follow `gr.on` with `.then`, just like any regular event listener. This handy method should save you from having to write a lot of repetitive code!
|
@ -543,9 +543,7 @@
|
||||
|
||||
$: target_map = dependencies.reduce(
|
||||
(acc, dep, i) => {
|
||||
let { targets, trigger } = dep;
|
||||
|
||||
targets.forEach((id) => {
|
||||
dep.targets.forEach(([id, trigger]) => {
|
||||
if (!acc[id]) {
|
||||
acc[id] = {};
|
||||
}
|
||||
@ -577,7 +575,7 @@
|
||||
|
||||
// handle load triggers
|
||||
dependencies.forEach((dep, i) => {
|
||||
if (dep.targets.length === 0 && dep.trigger === "load") {
|
||||
if (dep.targets.length === 1 && dep.targets[0][1] === "load") {
|
||||
trigger_api_call(i);
|
||||
}
|
||||
});
|
||||
|
@ -23,8 +23,7 @@ export interface DependencyTypes {
|
||||
}
|
||||
|
||||
export interface Dependency {
|
||||
trigger: string;
|
||||
targets: number[];
|
||||
targets: [number, string][];
|
||||
inputs: number[];
|
||||
outputs: number[];
|
||||
backend_fn: boolean;
|
||||
|
@ -576,7 +576,7 @@ class TestComponentsInBlocks:
|
||||
else:
|
||||
assert component.load_event_to_attach
|
||||
dependencies_on_load = [
|
||||
dep["trigger"] == "load" for dep in demo.config["dependencies"]
|
||||
dep["targets"][0][1] == "load" for dep in demo.config["dependencies"]
|
||||
]
|
||||
assert all(dependencies_on_load)
|
||||
assert len(dependencies_on_load) == 2
|
||||
@ -592,7 +592,9 @@ class TestComponentsInBlocks:
|
||||
)
|
||||
|
||||
dependencies_on_load = [
|
||||
dep for dep in interface.config["dependencies"] if dep["trigger"] == "load"
|
||||
dep
|
||||
for dep in interface.config["dependencies"]
|
||||
if dep["targets"][0][1] == "load"
|
||||
]
|
||||
assert len(dependencies_on_load) == len(io_components)
|
||||
assert all(dep["every"] == 1 for dep in dependencies_on_load)
|
||||
@ -1429,7 +1431,7 @@ class TestCancel:
|
||||
class TestEvery:
|
||||
def test_raise_exception_if_parameters_invalid(self):
|
||||
with pytest.raises(
|
||||
ValueError, match="Cannot run change event in a batch and every 0.5 seconds"
|
||||
ValueError, match="Cannot run event in a batch and every 0.5 seconds"
|
||||
):
|
||||
with gr.Blocks():
|
||||
num = gr.Number()
|
||||
|
@ -37,7 +37,7 @@ class TestOAuthButtons:
|
||||
button = gr.LoginButton()
|
||||
|
||||
login_event = demo.dependencies[0]
|
||||
assert login_event["trigger"] == "click"
|
||||
assert login_event["targets"][0][1] == "click"
|
||||
assert not login_event["backend_fn"] # No Python code
|
||||
assert login_event["js"] # But JS code instead
|
||||
assert login_event["inputs"] == [button._id]
|
||||
|
@ -53,26 +53,22 @@ class TestInit:
|
||||
chatbot = gr.ChatInterface(double)
|
||||
dependencies = chatbot.dependencies
|
||||
textbox = chatbot.textbox._id
|
||||
submit_btn = chatbot.submit_btn._id
|
||||
assert next(
|
||||
(
|
||||
d
|
||||
for d in dependencies
|
||||
if d["targets"] == [textbox] and d["trigger"] == "submit"
|
||||
if d["targets"] == [(textbox, "submit"), (submit_btn, "click")]
|
||||
),
|
||||
None,
|
||||
)
|
||||
for btn_id in [
|
||||
chatbot.submit_btn._id,
|
||||
chatbot.retry_btn._id,
|
||||
chatbot.clear_btn._id,
|
||||
chatbot.undo_btn._id,
|
||||
]:
|
||||
assert next(
|
||||
(
|
||||
d
|
||||
for d in dependencies
|
||||
if d["targets"] == [btn_id] and d["trigger"] == "click"
|
||||
),
|
||||
(d for d in dependencies if d["targets"][0] == (btn_id, "click")),
|
||||
None,
|
||||
)
|
||||
|
||||
|
@ -16,7 +16,7 @@ class TestEvent:
|
||||
|
||||
img.clear(fn_img_cleared, [], [])
|
||||
|
||||
assert demo.config["dependencies"][0]["trigger"] == "clear"
|
||||
assert demo.config["dependencies"][0]["targets"][0][1] == "clear"
|
||||
|
||||
def test_event_data(self):
|
||||
with gr.Blocks() as demo:
|
||||
@ -69,6 +69,42 @@ class TestEvent:
|
||||
assert not parent.config["dependencies"][2]["trigger_only_on_success"]
|
||||
assert parent.config["dependencies"][3]["trigger_only_on_success"]
|
||||
|
||||
def test_on_listener(self):
|
||||
with gr.Blocks() as demo:
|
||||
name = gr.Textbox(label="Name")
|
||||
output = gr.Textbox(label="Output Box")
|
||||
greet_btn = gr.Button("Greet")
|
||||
|
||||
def greet(name):
|
||||
return "Hello " + name + "!"
|
||||
|
||||
gr.on(
|
||||
triggers=[name.submit, greet_btn.click],
|
||||
fn=greet,
|
||||
inputs=name,
|
||||
outputs=output,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
num1 = gr.Slider(1, 10)
|
||||
num2 = gr.Slider(1, 10)
|
||||
num3 = gr.Slider(1, 10)
|
||||
output = gr.Number(label="Sum")
|
||||
|
||||
@gr.on(inputs=[num1, num2, num3], outputs=output)
|
||||
def sum(a, b, c):
|
||||
return a + b + c
|
||||
|
||||
assert demo.config["dependencies"][0]["targets"] == [
|
||||
(name._id, "submit"),
|
||||
(greet_btn._id, "click"),
|
||||
]
|
||||
assert demo.config["dependencies"][1]["targets"] == [
|
||||
(num1._id, "change"),
|
||||
(num2._id, "change"),
|
||||
(num3._id, "change"),
|
||||
]
|
||||
|
||||
def test_load_chaining(self):
|
||||
calls = 0
|
||||
|
||||
@ -83,9 +119,9 @@ class TestEvent:
|
||||
increment, inputs=None, outputs=out
|
||||
)
|
||||
|
||||
assert demo.config["dependencies"][0]["trigger"] == "load"
|
||||
assert demo.config["dependencies"][0]["targets"][0][1] == "load"
|
||||
assert demo.config["dependencies"][0]["trigger_after"] is None
|
||||
assert demo.config["dependencies"][1]["trigger"] == "then"
|
||||
assert demo.config["dependencies"][1]["targets"][0][1] == "then"
|
||||
assert demo.config["dependencies"][1]["trigger_after"] == 0
|
||||
|
||||
def test_load_chaining_reuse(self):
|
||||
@ -105,9 +141,9 @@ class TestEvent:
|
||||
with gr.Blocks() as demo2:
|
||||
demo.render()
|
||||
|
||||
assert demo2.config["dependencies"][0]["trigger"] == "load"
|
||||
assert demo2.config["dependencies"][0]["targets"][0][1] == "load"
|
||||
assert demo2.config["dependencies"][0]["trigger_after"] is None
|
||||
assert demo2.config["dependencies"][1]["trigger"] == "then"
|
||||
assert demo2.config["dependencies"][1]["targets"][0][1] == "then"
|
||||
assert demo2.config["dependencies"][1]["trigger_after"] == 0
|
||||
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
{
|
||||
"version": "3.44.1",
|
||||
"version": "3.44.4",
|
||||
"mode": "blocks",
|
||||
"app_id": 3299865391549106311,
|
||||
"app_id": 5570383359191764024,
|
||||
"dev_mode": false,
|
||||
"analytics_enabled": false,
|
||||
"components": [
|
||||
{
|
||||
"id": 31,
|
||||
"id": 1,
|
||||
"type": "markdown",
|
||||
"props": {
|
||||
"value": "# Detect Disease From Scan\nWith this model you can lorem ipsum\n- ipsum 1\n- ipsum 2",
|
||||
@ -35,7 +35,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 32,
|
||||
"id": 2,
|
||||
"type": "checkboxgroup",
|
||||
"props": {
|
||||
"choices": [
|
||||
@ -82,7 +82,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 33,
|
||||
"id": 3,
|
||||
"type": "tabs",
|
||||
"props": {
|
||||
"visible": true,
|
||||
@ -91,7 +91,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 34,
|
||||
"id": 4,
|
||||
"type": "tabitem",
|
||||
"props": {
|
||||
"label": "X-ray",
|
||||
@ -100,7 +100,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 35,
|
||||
"id": 5,
|
||||
"type": "row",
|
||||
"props": {
|
||||
"variant": "default",
|
||||
@ -110,7 +110,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 36,
|
||||
"id": 6,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
@ -145,7 +145,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 37,
|
||||
"id": 7,
|
||||
"type": "json",
|
||||
"props": {
|
||||
"show_label": true,
|
||||
@ -171,7 +171,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 38,
|
||||
"id": 8,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"value": "Run",
|
||||
@ -193,7 +193,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 39,
|
||||
"id": 9,
|
||||
"type": "tabitem",
|
||||
"props": {
|
||||
"label": "CT Scan",
|
||||
@ -202,7 +202,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 40,
|
||||
"id": 10,
|
||||
"type": "row",
|
||||
"props": {
|
||||
"variant": "default",
|
||||
@ -212,7 +212,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 41,
|
||||
"id": 11,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
@ -247,7 +247,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 42,
|
||||
"id": 12,
|
||||
"type": "json",
|
||||
"props": {
|
||||
"show_label": true,
|
||||
@ -273,7 +273,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 43,
|
||||
"id": 13,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"value": "Run",
|
||||
@ -295,7 +295,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 44,
|
||||
"id": 14,
|
||||
"type": "textbox",
|
||||
"props": {
|
||||
"value": "",
|
||||
@ -326,7 +326,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 45,
|
||||
"id": 15,
|
||||
"type": "form",
|
||||
"props": {
|
||||
"scale": 0,
|
||||
@ -335,7 +335,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 46,
|
||||
"id": 16,
|
||||
"type": "form",
|
||||
"props": {
|
||||
"scale": 0,
|
||||
@ -357,67 +357,67 @@
|
||||
],
|
||||
"theme": "default",
|
||||
"layout": {
|
||||
"id": 30,
|
||||
"id": 0,
|
||||
"children": [
|
||||
{
|
||||
"id": 31
|
||||
"id": 1
|
||||
},
|
||||
{
|
||||
"id": 45,
|
||||
"id": 15,
|
||||
"children": [
|
||||
{
|
||||
"id": 32
|
||||
"id": 2
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 33,
|
||||
"id": 3,
|
||||
"children": [
|
||||
{
|
||||
"id": 34,
|
||||
"id": 4,
|
||||
"children": [
|
||||
{
|
||||
"id": 35,
|
||||
"id": 5,
|
||||
"children": [
|
||||
{
|
||||
"id": 36
|
||||
"id": 6
|
||||
},
|
||||
{
|
||||
"id": 37
|
||||
"id": 7
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 38
|
||||
"id": 8
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 39,
|
||||
"id": 9,
|
||||
"children": [
|
||||
{
|
||||
"id": 40,
|
||||
"id": 10,
|
||||
"children": [
|
||||
{
|
||||
"id": 41
|
||||
"id": 11
|
||||
},
|
||||
{
|
||||
"id": 42
|
||||
"id": 12
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 43
|
||||
"id": 13
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 46,
|
||||
"id": 16,
|
||||
"children": [
|
||||
{
|
||||
"id": 44
|
||||
"id": 14
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -426,15 +426,17 @@
|
||||
"dependencies": [
|
||||
{
|
||||
"targets": [
|
||||
38
|
||||
[
|
||||
8,
|
||||
"click"
|
||||
]
|
||||
],
|
||||
"trigger": "click",
|
||||
"inputs": [
|
||||
32,
|
||||
36
|
||||
2,
|
||||
6
|
||||
],
|
||||
"outputs": [
|
||||
37
|
||||
7
|
||||
],
|
||||
"backend_fn": true,
|
||||
"js": null,
|
||||
@ -456,15 +458,17 @@
|
||||
},
|
||||
{
|
||||
"targets": [
|
||||
43
|
||||
[
|
||||
13,
|
||||
"click"
|
||||
]
|
||||
],
|
||||
"trigger": "click",
|
||||
"inputs": [
|
||||
32,
|
||||
41
|
||||
2,
|
||||
11
|
||||
],
|
||||
"outputs": [
|
||||
42
|
||||
12
|
||||
],
|
||||
"backend_fn": true,
|
||||
"js": null,
|
||||
@ -486,10 +490,9 @@
|
||||
},
|
||||
{
|
||||
"targets": [],
|
||||
"trigger": "load",
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
44
|
||||
14
|
||||
],
|
||||
"backend_fn": true,
|
||||
"js": null,
|
||||
|
@ -1,12 +1,12 @@
|
||||
{
|
||||
"version": "3.44.1",
|
||||
"version": "3.44.4",
|
||||
"mode": "blocks",
|
||||
"app_id": 3299865391549106311,
|
||||
"app_id": 5570383359191764024,
|
||||
"dev_mode": false,
|
||||
"analytics_enabled": false,
|
||||
"components": [
|
||||
{
|
||||
"id": 1031,
|
||||
"id": 101,
|
||||
"type": "markdown",
|
||||
"props": {
|
||||
"value": "# Detect Disease From Scan\nWith this model you can lorem ipsum\n- ipsum 1\n- ipsum 2",
|
||||
@ -35,7 +35,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1032,
|
||||
"id": 102,
|
||||
"type": "checkboxgroup",
|
||||
"props": {
|
||||
"choices": [
|
||||
@ -82,7 +82,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1033,
|
||||
"id": 103,
|
||||
"type": "tabs",
|
||||
"props": {
|
||||
"visible": true,
|
||||
@ -91,7 +91,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1034,
|
||||
"id": 104,
|
||||
"type": "tabitem",
|
||||
"props": {
|
||||
"label": "X-ray",
|
||||
@ -100,7 +100,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1035,
|
||||
"id": 105,
|
||||
"type": "row",
|
||||
"props": {
|
||||
"variant": "default",
|
||||
@ -110,7 +110,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1036,
|
||||
"id": 106,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
@ -145,7 +145,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1037,
|
||||
"id": 107,
|
||||
"type": "json",
|
||||
"props": {
|
||||
"show_label": true,
|
||||
@ -171,7 +171,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1038,
|
||||
"id": 108,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"value": "Run",
|
||||
@ -193,7 +193,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1039,
|
||||
"id": 109,
|
||||
"type": "tabitem",
|
||||
"props": {
|
||||
"label": "CT Scan",
|
||||
@ -202,7 +202,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1040,
|
||||
"id": 1010,
|
||||
"type": "row",
|
||||
"props": {
|
||||
"variant": "default",
|
||||
@ -212,7 +212,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1041,
|
||||
"id": 1011,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
@ -247,7 +247,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1042,
|
||||
"id": 1012,
|
||||
"type": "json",
|
||||
"props": {
|
||||
"show_label": true,
|
||||
@ -273,7 +273,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1043,
|
||||
"id": 1013,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"value": "Run",
|
||||
@ -295,7 +295,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1044,
|
||||
"id": 1014,
|
||||
"type": "textbox",
|
||||
"props": {
|
||||
"value": "",
|
||||
@ -326,7 +326,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1045,
|
||||
"id": 1015,
|
||||
"type": "form",
|
||||
"props": {
|
||||
"scale": 0,
|
||||
@ -335,7 +335,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 1046,
|
||||
"id": 1016,
|
||||
"type": "form",
|
||||
"props": {
|
||||
"scale": 0,
|
||||
@ -357,67 +357,67 @@
|
||||
],
|
||||
"theme": "default",
|
||||
"layout": {
|
||||
"id": 1030,
|
||||
"id": 100,
|
||||
"children": [
|
||||
{
|
||||
"id": 1031
|
||||
"id": 101
|
||||
},
|
||||
{
|
||||
"id": 1045,
|
||||
"id": 1015,
|
||||
"children": [
|
||||
{
|
||||
"id": 1032
|
||||
"id": 102
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 1033,
|
||||
"id": 103,
|
||||
"children": [
|
||||
{
|
||||
"id": 1034,
|
||||
"id": 104,
|
||||
"children": [
|
||||
{
|
||||
"id": 1035,
|
||||
"id": 105,
|
||||
"children": [
|
||||
{
|
||||
"id": 1036
|
||||
"id": 106
|
||||
},
|
||||
{
|
||||
"id": 1037
|
||||
"id": 107
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 1038
|
||||
"id": 108
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 1039,
|
||||
"id": 109,
|
||||
"children": [
|
||||
{
|
||||
"id": 1040,
|
||||
"id": 1010,
|
||||
"children": [
|
||||
{
|
||||
"id": 1041
|
||||
"id": 1011
|
||||
},
|
||||
{
|
||||
"id": 1042
|
||||
"id": 1012
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 1043
|
||||
"id": 1013
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 1046,
|
||||
"id": 1016,
|
||||
"children": [
|
||||
{
|
||||
"id": 1044
|
||||
"id": 1014
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -426,15 +426,17 @@
|
||||
"dependencies": [
|
||||
{
|
||||
"targets": [
|
||||
1038
|
||||
[
|
||||
108,
|
||||
"click"
|
||||
]
|
||||
],
|
||||
"trigger": "click",
|
||||
"inputs": [
|
||||
1032,
|
||||
1036
|
||||
102,
|
||||
106
|
||||
],
|
||||
"outputs": [
|
||||
1037
|
||||
107
|
||||
],
|
||||
"backend_fn": true,
|
||||
"js": null,
|
||||
@ -456,15 +458,17 @@
|
||||
},
|
||||
{
|
||||
"targets": [
|
||||
1043
|
||||
[
|
||||
1013,
|
||||
"click"
|
||||
]
|
||||
],
|
||||
"trigger": "click",
|
||||
"inputs": [
|
||||
1032,
|
||||
1041
|
||||
102,
|
||||
1011
|
||||
],
|
||||
"outputs": [
|
||||
1042
|
||||
1012
|
||||
],
|
||||
"backend_fn": true,
|
||||
"js": null,
|
||||
@ -486,10 +490,9 @@
|
||||
},
|
||||
{
|
||||
"targets": [],
|
||||
"trigger": "load",
|
||||
"inputs": [],
|
||||
"outputs": [
|
||||
1044
|
||||
1014
|
||||
],
|
||||
"backend_fn": true,
|
||||
"js": null,
|
||||
|
@ -274,7 +274,6 @@ class TestProcessExamples:
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
assert not any(d["trigger"] == "fake_event" for d in io.config["dependencies"])
|
||||
assert prediction == [
|
||||
{"lines": 4, "__type__": "update", "mode": "static"},
|
||||
{"label": "lion"},
|
||||
|
@ -238,7 +238,7 @@ class TestInterfaceInterpretation:
|
||||
interpretation_dep = next(
|
||||
d
|
||||
for d in iface.config["dependencies"]
|
||||
if d["targets"] == [interpretation_id]
|
||||
if d["targets"][0][0] == interpretation_id
|
||||
)
|
||||
interpretation_comps = [
|
||||
c["id"]
|
||||
@ -268,7 +268,7 @@ class TestInterfaceInterpretation:
|
||||
fn_index = next(
|
||||
i
|
||||
for i, d in enumerate(iface.config["dependencies"])
|
||||
if d["targets"] == [btn]
|
||||
if d["targets"][0][0] == btn
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
|
Loading…
Reference in New Issue
Block a user