Improvements to gr.Examples: adds events as attributes and documents, them, adds sample_labels, and visible properties (#8733)

* events

* examples

* add changeset

* format

* add changeset

* add changeset

* format

* changes

* Update gradio/components/dataset.py

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>

* Update gradio/helpers.py

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>

* add test

* Update test/test_helpers.py

Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>

* changes

* format

* add to interface as well

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ali Abdalla <ali.si3luwa@gmail.com>
This commit is contained in:
Abubakar Abid 2024-07-10 16:35:36 -07:00 committed by GitHub
parent d15ada9a1c
commit fb0daf3730
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 100 additions and 8 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/dataset": minor
"gradio": minor
"website": minor
---
feat:Improvements to `gr.Examples`: adds events as attributes and documents, them, adds `sample_labels`, and `visible` properties

View File

@ -43,6 +43,7 @@ class Dataset(Component):
scale: int | None = None,
min_width: int = 160,
proxy_url: str | None = None,
sample_labels: list[str] | None = None,
):
"""
Parameters:
@ -61,6 +62,7 @@ class Dataset(Component):
scale: relative size compared to adjacent Components. For example if Components A and B are in a Row, and A has scale=2, and B has scale=1, A will be twice as wide as B. Should be an integer. scale applies in Rows, and to top-level Components in Blocks where fill_height=True.
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
proxy_url: The URL of the external Space used to load this component. Set automatically when using `gr.load()`. This should not be set manually.
sample_labels: A list of labels for each sample. If provided, the length of this list should be the same as the number of samples, and these labels will be used in the UI instead of rendering the sample values.
"""
super().__init__(
visible=visible,
@ -115,6 +117,7 @@ class Dataset(Component):
else:
self.headers = [c.label or "" for c in self._components]
self.samples_per_page = samples_per_page
self.sample_labels = sample_labels
def api_info(self) -> dict[str, str]:
return {"type": "integer", "description": "index of selected example"}
@ -124,6 +127,7 @@ class Dataset(Component):
config["components"] = []
config["component_props"] = self.component_props
config["sample_labels"] = self.sample_labels
config["component_ids"] = []
for component in self._components:

View File

@ -25,7 +25,7 @@ from gradio_client.documentation import document
from gradio import components, oauth, processing_utils, routes, utils, wasm_utils
from gradio.context import Context, LocalContext, get_blocks_context
from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import EventData
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
from gradio.flagging import CSVLogger
@ -50,6 +50,9 @@ def create_examples(
postprocess: bool = True,
api_name: str | Literal[False] = "load_example",
batch: bool = False,
*,
example_labels: list[str] | None = None,
visible: bool = True,
_defer_caching: bool = False,
):
"""Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
@ -69,6 +72,8 @@ def create_examples(
api_name=api_name,
batch=batch,
_defer_caching=_defer_caching,
example_labels=example_labels,
visible=visible,
_initiated_directly=False,
)
examples_obj.create()
@ -103,6 +108,9 @@ class Examples:
postprocess: bool = True,
api_name: str | Literal[False] = "load_example",
batch: bool = False,
*,
example_labels: list[str] | None = None,
visible: bool = True,
_defer_caching: bool = False,
_initiated_directly: bool = True,
):
@ -121,6 +129,8 @@ class Examples:
postprocess: if True, postprocesses the example output after running the prediction function and before caching. Only applies if `cache_examples` is not False.
api_name: Defines how the event associated with clicking on the examples appears in the API docs. Can be a string or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use the example function.
batch: If True, then the function should process a batch of inputs, meaning that it should accept a list of input values for each parameter. Used only if cache_examples is not False.
example_labels: A list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
visible: If False, the examples component will be hidden in the UI.
"""
if _initiated_directly:
warnings.warn(
@ -221,6 +231,10 @@ class Examples:
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
for example in examples
]
if example_labels is not None and len(example_labels) != len(examples):
raise ValueError(
"If `example_labels` are provided, the length of `example_labels` must be the same as the number of examples."
)
self.examples = examples
self.non_none_examples = non_none_examples
@ -233,6 +247,7 @@ class Examples:
self.postprocess = postprocess
self.api_name: str | Literal[False] = api_name
self.batch = batch
self.example_labels = example_labels
with utils.set_directory(working_directory):
self.processed_examples = []
@ -265,6 +280,8 @@ class Examples:
label=label,
samples_per_page=examples_per_page,
elem_id=elem_id,
visible=visible,
sample_labels=example_labels,
)
self.cache_logger = CSVLogger(simplify_file_data=False)
@ -272,6 +289,7 @@ class Examples:
self.cached_file = Path(self.cached_folder) / "log.csv"
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
self.run_on_click = run_on_click
self.cache_event: Dependency | None = None
def create(self) -> None:
"""Caches the examples if self.cache_examples is True and creates the Dataset
@ -380,7 +398,7 @@ class Examples:
lazy_cache_fn = self.async_lazy_cache
else:
lazy_cache_fn = self.sync_lazy_cache
self.load_input_event.then(
self.cache_event = self.load_input_event.then(
lazy_cache_fn,
inputs=[self.dataset] + self.inputs,
outputs=self.outputs,
@ -466,7 +484,7 @@ class Examples:
# create a fake dependency to process the examples and get the predictions
from gradio.events import EventListenerMethod
dependency, fn_index = blocks_config.set_event_trigger(
_, fn_index = blocks_config.set_event_trigger(
[EventListenerMethod(Context.root_block, "load")],
fn=fn,
inputs=self.inputs_with_examples, # type: ignore
@ -511,7 +529,7 @@ class Examples:
] + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)
self.load_input_event = self.dataset.click(
self.cache_event = self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore

View File

@ -132,6 +132,7 @@ class Interface(Blocks):
clear_btn: str | Button | None = "Clear",
delete_cache: tuple[int, int] | None = None,
show_progress: Literal["full", "minimal", "hidden"] = "full",
example_labels: list[str] | None = None,
**kwargs,
):
"""
@ -168,6 +169,7 @@ class Interface(Blocks):
clear_btn: The button to use for clearing the inputs. Defaults to a `gr.Button("Clear", variant="secondary")`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization). Can be set to None, which hides the button.
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
show_progress: whether to show progress animation while running. Has no effect if the interface is `live`.
example_labels: A list of labels for each example. If provided, the length of this list should be the same as the number of examples, and these labels will be used in the UI instead of rendering the example values.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -320,6 +322,7 @@ class Interface(Blocks):
self.examples = examples
self.examples_per_page = examples_per_page
self.example_labels = example_labels
if isinstance(submit_btn, Button):
self.submit_btn_parms = submit_btn.recover_kwargs(submit_btn.get_config())
@ -879,6 +882,7 @@ class Interface(Blocks):
examples_per_page=self.examples_per_page,
_api_mode=self.api_mode,
batch=self.batch,
example_labels=self.example_labels,
)
def __str__(self):

View File

@ -10,6 +10,27 @@
import { style_formatted_text } from "$lib/text";
let obj = get_object("examples");
obj["attributes"] = [
{
name: "dataset",
annotation: "gradio.Dataset",
doc: "The `gr.Dataset` component corresponding to this Examples object.",
kwargs: null
},
{
name: "load_input_event",
annotation: "gradio.events.Dependency",
doc: "The Gradio event that populates the input values when the examples are clicked. You can attach a `.then()` or a `.success()` to this event to trigger subsequent events to fire after this event.",
kwargs: null
},
{
name: "cache_event",
annotation: "gradio.events.Dependency | None",
doc: "The Gradio event that populates the cached output values when the examples are clicked. You can attach a `.then()` or a `.success()` to this event to trigger subsequent events to fire after this event. This event is `None` if `cache_examples` if False, and is the same as `load_input_event` if `cache_examples` is `'lazy'`.",
kwargs: null
}
]
</script>
<!--- Title -->
@ -37,6 +58,10 @@ None
### Initialization
<ParamTable parameters={obj.parameters} />
<!--- Attributes -->
### Attributes
<ParamTable parameters={obj.attributes} />
{#if obj.demos && obj.demos.length > 0}
<!--- Demos -->

View File

@ -2,6 +2,7 @@
import { Block } from "@gradio/atoms";
import type { SvelteComponent, ComponentType } from "svelte";
import type { Gradio, SelectData } from "@gradio/utils";
import { BaseExample } from "@gradio/textbox";
export let components: string[];
export let component_props: Record<string, any>[];
export let component_map: Map<
@ -13,6 +14,7 @@
export let label = "Examples";
export let headers: string[];
export let samples: any[][] | null = null;
export let sample_labels: string[] | null = null;
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
@ -33,7 +35,7 @@
? `/proxy=${proxy_url}file=`
: `${root}/file=`;
let page = 0;
$: gallery = components.length < 2;
$: gallery = components.length < 2 || sample_labels !== null;
let paginate = samples ? samples.length > samples_per_page : false;
let selected_samples: any[][];
@ -51,7 +53,7 @@
}
$: {
samples = samples || [];
samples = sample_labels ? sample_labels.map((e) => [e]) : samples || [];
paginate = samples.length > samples_per_page;
if (paginate) {
visible_pages = [];
@ -146,7 +148,13 @@
on:mouseenter={() => handle_mouseenter(i)}
on:mouseleave={() => handle_mouseleave()}
>
{#if component_meta.length && component_map.get(components[0])}
{#if sample_labels}
<BaseExample
value={sample_row[0]}
selected={current_hover === i}
type="gallery"
/>
{:else if component_meta.length && component_map.get(components[0])}
<svelte:component
this={component_meta[0][0].component}
{...component_props[0]}

View File

@ -15,7 +15,8 @@
"@gradio/atoms": "workspace:^",
"@gradio/client": "workspace:^",
"@gradio/utils": "workspace:^",
"@gradio/upload": "workspace:^"
"@gradio/upload": "workspace:^",
"@gradio/textbox": "workspace:^"
},
"devDependencies": {
"@gradio/preview": "workspace:^"

3
pnpm-lock.yaml generated
View File

@ -1047,6 +1047,9 @@ importers:
'@gradio/client':
specifier: workspace:^
version: link:../../client/js
'@gradio/textbox':
specifier: workspace:^
version: link:../textbox
'@gradio/upload':
specifier: workspace:^
version: link:../upload

View File

@ -168,6 +168,28 @@ class TestExamplesDataset:
)
assert examples.dataset.headers == ["im", ""]
def test_example_labels(self, patched_cache_folder):
examples = gr.Examples(
examples=[
[5, "add", 3],
[4, "divide", 2],
[-4, "multiply", 2.5],
[0, "subtract", 1.2],
],
inputs=[
gr.Number(),
gr.Radio(["add", "divide", "multiply", "subtract"]),
gr.Number(),
],
example_labels=["add", "divide", "multiply", "subtract"],
)
assert examples.dataset.sample_labels == [
"add",
"divide",
"multiply",
"subtract",
]
def test_example_caching_relaunch(connect):
def combine(a, b):