diff --git a/.changeset/real-grapes-accept.md b/.changeset/real-grapes-accept.md
new file mode 100644
index 0000000000..ba1ae724e4
--- /dev/null
+++ b/.changeset/real-grapes-accept.md
@@ -0,0 +1,7 @@
+---
+"@gradio/dataframe": minor
+"gradio": minor
+"website": minor
+---
+
+feat:Allows updating the dataset of a `gr.Examples`
diff --git a/demo/image_mod/run.ipynb b/demo/image_mod/run.ipynb
index b7b80afe2b..2239efc6f3 100644
--- a/demo/image_mod/run.ipynb
+++ b/demo/image_mod/run.ipynb
@@ -1 +1 @@
-{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "\n", "demo = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/logo.png\"),\n", " os.path.join(os.path.abspath(''), \"images/tower.jpg\"),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
\ No newline at end of file
+{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "new_samples = [\n", " [os.path.join(os.path.abspath(''), \"images/logo.png\")],\n", " [os.path.join(os.path.abspath(''), \"images/tower.jpg\")],\n", "]\n", "\n", "with gr.Blocks() as demo:\n", " interface = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " ],\n", " )\n", "\n", " btn = gr.Button(\"Update Examples\")\n", " btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, 
"nbformat_minor": 5}
\ No newline at end of file
diff --git a/demo/image_mod/run.py b/demo/image_mod/run.py
index 88bfaf63f0..fa28c0e1fc 100644
--- a/demo/image_mod/run.py
+++ b/demo/image_mod/run.py
@@ -5,19 +5,25 @@ import os
def image_mod(image):
return image.rotate(45)
+new_samples = [
+ [os.path.join(os.path.dirname(__file__), "images/logo.png")],
+ [os.path.join(os.path.dirname(__file__), "images/tower.jpg")],
+]
-demo = gr.Interface(
- image_mod,
- gr.Image(type="pil"),
- "image",
- flagging_options=["blurry", "incorrect", "other"],
- examples=[
- os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
- os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
- os.path.join(os.path.dirname(__file__), "images/logo.png"),
- os.path.join(os.path.dirname(__file__), "images/tower.jpg"),
- ],
-)
+with gr.Blocks() as demo:
+ interface = gr.Interface(
+ image_mod,
+ gr.Image(type="pil"),
+ "image",
+ flagging_options=["blurry", "incorrect", "other"],
+ examples=[
+ os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
+ os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
+ ],
+ )
+
+ btn = gr.Button("Update Examples")
+ btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)
if __name__ == "__main__":
demo.launch()
diff --git a/gradio/blocks.py b/gradio/blocks.py
index a79d3b83a5..d2b771c747 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -1713,6 +1713,7 @@ Received outputs:
kwargs.pop("value", None)
kwargs.pop("__type__")
kwargs["render"] = False
+
state[block._id] = block.__class__(**kwargs)
prediction_value = postprocess_update_dict(
diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py
index 7eb4d2b1c9..1a093397f5 100644
--- a/gradio/components/dataset.py
+++ b/gradio/components/dataset.py
@@ -32,7 +32,7 @@ class Dataset(Component):
component_props: list[dict[str, Any]] | None = None,
samples: list[list[Any]] | None = None,
headers: list[str] | None = None,
- type: Literal["values", "index"] = "values",
+ type: Literal["values", "index", "tuple"] = "values",
samples_per_page: int = 10,
visible: bool = True,
elem_id: str | None = None,
@@ -51,7 +51,7 @@ class Dataset(Component):
components: Which component types to show in this dataset widget, can be passed in as a list of string names or Components instances. The following components are supported in a Dataset: Audio, Checkbox, CheckboxGroup, ColorPicker, Dataframe, Dropdown, File, HTML, Image, Markdown, Model3D, Number, Radio, Slider, Textbox, TimeSeries, Video
samples: a nested list of samples. Each sublist within the outer list represents a data sample, and each element within the sublist represents an value for each component
headers: Column headers in the Dataset widget, should be the same len as components. If not provided, inferred from component labels
- type: 'values' if clicking on a sample should pass the value of the sample, or "index" if it should pass the index of the sample
+ type: "values" if clicking on a sample should pass the value of the sample, "index" if it should pass the index of the sample, or "tuple" if it should pass both the index and the value of the sample.
samples_per_page: how many examples to show per page.
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
@@ -95,8 +95,10 @@ class Dataset(Component):
self.proxy_url = proxy_url
for component in self._components:
component.proxy_url = proxy_url
- self.samples = [[]] if samples is None else samples
- for example in self.samples:
+ self.raw_samples = [[]] if samples is None else samples
+ self.samples: list[list] = []
+ for example in self.raw_samples:
+ self.samples.append([])
for i, (component, ex) in enumerate(zip(self._components, example)):
# If proxy_url is set, that means it is being loaded from an external Gradio app
# which means that the example has already been processed.
@@ -104,9 +106,9 @@ class Dataset(Component):
# The `as_example()` method has been renamed to `process_example()` but we
# use the previous name to be backwards-compatible with previously-created
# custom components
- example[i] = component.as_example(ex)
- example[i] = processing_utils.move_files_to_cache(
- example[i], component, keep_in_cache=True
+ self.samples[-1].append(component.as_example(ex))
+ self.samples[-1][i] = processing_utils.move_files_to_cache(
+ self.samples[-1][i], component, keep_in_cache=True
)
self.type = type
self.label = label
@@ -137,19 +139,21 @@ class Dataset(Component):
return config
- def preprocess(self, payload: int | None) -> int | list | None:
+ def preprocess(self, payload: int | None) -> int | list | tuple[int, list] | None:
"""
Parameters:
payload: the index of the selected example in the dataset
Returns:
- Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
+ Passes the selected sample as a `list` of data corresponding to each input component (if `type` is "values"), as an `int` index (if `type` is "index"), or as a `tuple` of the index and the data (if `type` is "tuple").
"""
if payload is None:
return None
if self.type == "index":
return payload
elif self.type == "values":
- return self.samples[payload]
+ return self.raw_samples[payload]
+ elif self.type == "tuple":
+ return payload, self.raw_samples[payload]
def postprocess(self, sample: int | list | None) -> int | None:
"""
diff --git a/gradio/helpers.py b/gradio/helpers.py
index 4f5323d617..ac3d590c6b 100644
--- a/gradio/helpers.py
+++ b/gradio/helpers.py
@@ -5,6 +5,7 @@ Defines helper methods useful for loading and caching Interface examples.
from __future__ import annotations
import ast
+import copy
import csv
import inspect
import os
@@ -28,6 +29,7 @@ from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
from gradio.flagging import CSVLogger
+from gradio.utils import UnhashableKeyDict
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import Component
@@ -239,6 +241,7 @@ class Examples:
self.examples = examples
self.non_none_examples = non_none_examples
self.inputs = inputs
+ self.input_has_examples = input_has_examples
self.inputs_with_examples = inputs_with_examples
self.outputs = outputs or []
self.fn = fn
@@ -248,35 +251,15 @@ class Examples:
self.api_name: str | Literal[False] = api_name
self.batch = batch
self.example_labels = example_labels
-
- with utils.set_directory(working_directory):
- self.processed_examples = []
- for example in examples:
- sub = []
- for component, sample in zip(inputs, example):
- prediction_value = component.postprocess(sample)
- if isinstance(prediction_value, (GradioRootModel, GradioModel)):
- prediction_value = prediction_value.model_dump()
- prediction_value = processing_utils.move_files_to_cache(
- prediction_value,
- component,
- postprocess=True,
- )
- sub.append(prediction_value)
- self.processed_examples.append(sub)
-
- self.non_none_processed_examples = [
- [ex for (ex, keep) in zip(example, input_has_examples) if keep]
- for example in self.processed_examples
- ]
+ self.working_directory = working_directory
from gradio import components
with utils.set_directory(working_directory):
self.dataset = components.Dataset(
components=inputs_with_examples,
- samples=non_none_examples,
- type="index",
+ samples=copy.deepcopy(non_none_examples),
+ type="tuple",
label=label,
samples_per_page=examples_per_page,
elem_id=elem_id,
@@ -290,13 +273,38 @@ class Examples:
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
self.run_on_click = run_on_click
self.cache_event: Dependency | None = None
+ self.non_none_processed_examples = UnhashableKeyDict()
+
+ if self.dataset.samples:
+ for index, example in enumerate(self.non_none_examples):
+ self.non_none_processed_examples[self.dataset.samples[index]] = (
+ self._get_processed_example(example)
+ )
+
+ def _get_processed_example(self, example):
+ if example in self.non_none_processed_examples:
+ return self.non_none_processed_examples[example]
+ with utils.set_directory(self.working_directory):
+ sub = []
+ for component, sample in zip(self.inputs, example):
+ prediction_value = component.postprocess(sample)
+ if isinstance(prediction_value, (GradioRootModel, GradioModel)):
+ prediction_value = prediction_value.model_dump()
+ prediction_value = processing_utils.move_files_to_cache(
+ prediction_value,
+ component,
+ postprocess=True,
+ )
+ sub.append(prediction_value)
+ return [ex for (ex, keep) in zip(sub, self.input_has_examples) if keep]
def create(self) -> None:
"""Caches the examples if self.cache_examples is True and creates the Dataset
component to hold the examples"""
- async def load_example(example_id):
- processed_example = self.non_none_processed_examples[example_id]
+ async def load_example(example_tuple):
+ _, example_value = example_tuple
+ processed_example = self._get_processed_example(example_value)
if len(self.inputs_with_examples) == 1:
return update(
value=processed_example[0],
@@ -496,9 +504,9 @@ class Examples:
if self.outputs is None:
raise ValueError("self.outputs is missing")
- for example_id in range(len(self.examples)):
- print(f"Caching example {example_id + 1}/{len(self.examples)}")
- processed_input = self.processed_examples[example_id]
+ for i, example in enumerate(self.examples):
+ print(f"Caching example {i + 1}/{len(self.examples)}")
+ processed_input = self._get_processed_example(example)
if self.batch:
processed_input = [[value] for value in processed_input]
with utils.MatplotlibBackendMananger():
@@ -523,10 +531,11 @@ class Examples:
# method to be called independently of the create() method
blocks_config.fns.pop(self.load_input_event["id"])
- def load_example(example_id):
- processed_example = self.non_none_processed_examples[
- example_id
- ] + self.load_from_cache(example_id)
+ def load_example(example_tuple):
+ example_id, example_value = example_tuple
+ processed_example = self._get_processed_example(
+ example_value
+ ) + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)
self.cache_event = self.load_input_event = self.dataset.click(
diff --git a/gradio/utils.py b/gradio/utils.py
index e2a47c633d..8a463095cf 100644
--- a/gradio/utils.py
+++ b/gradio/utils.py
@@ -25,6 +25,7 @@ import urllib.parse
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
+from collections.abc import MutableMapping
from contextlib import contextmanager
from functools import wraps
from io import BytesIO
@@ -1424,3 +1425,42 @@ def error_payload(
content["duration"] = error.duration
content["visible"] = error.visible
return content
+
+
+class UnhashableKeyDict(MutableMapping):
+ """
+ A mapping backed by a list of (key, value) tuples, so keys do not need to be hashable.
+ Keys are matched with `deep_equal`, which makes every lookup a linear scan over the stored entries.
+ """
+
+ def __init__(self):
+ self.data = []
+
+ def __getitem__(self, key):
+ for k, v in self.data:
+ if deep_equal(k, key):
+ return v
+ raise KeyError(key)
+
+ def __setitem__(self, key, value):
+ for i, (k, _) in enumerate(self.data):
+ if deep_equal(k, key):
+ self.data[i] = (key, value)
+ return
+ self.data.append((key, value))
+
+ def __delitem__(self, key):
+ for i, (k, _) in enumerate(self.data):
+ if deep_equal(k, key):
+ del self.data[i]
+ return
+ raise KeyError(key)
+
+ def __iter__(self):
+ return (k for k, _ in self.data)
+
+ def __len__(self):
+ return len(self.data)
+
+ def as_list(self):
+ return [v for _, v in self.data]
diff --git a/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx b/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx
index e7f82d6ac6..ebc60da1f1 100644
--- a/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx
+++ b/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx
@@ -63,6 +63,31 @@ None
{cell} | diff --git a/test/components/test_dataset.py b/test/components/test_dataset.py index 3101784777..2eee6b56e1 100644 --- a/test/components/test_dataset.py +++ b/test/components/test_dataset.py @@ -23,7 +23,7 @@ class TestDataset: row = dataset.preprocess(1) assert row[0] == 15 assert row[1] == "hi" - assert row[2]["path"].endswith("bus.png") + assert row[2].endswith("bus.png") assert row[3] == "Italics" assert row[4] == "*Italics*" diff --git a/test/test_helpers.py b/test/test_helpers.py index b48a405506..81b951f757 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -24,15 +24,15 @@ from gradio import helpers, utils class TestExamples: def test_handle_single_input(self, patched_cache_folder): examples = gr.Examples(["hello", "hi"], gr.Textbox()) - assert examples.processed_examples == [["hello"], ["hi"]] + assert examples.non_none_processed_examples.as_list() == [["hello"], ["hi"]] examples = gr.Examples([["hello"]], gr.Textbox()) - assert examples.processed_examples == [["hello"]] + assert examples.non_none_processed_examples.as_list() == [["hello"]] examples = gr.Examples(["test/test_files/bus.png"], gr.Image()) assert ( client_utils.encode_file_to_base64( - examples.processed_examples[0][0]["path"] + examples.non_none_processed_examples.as_list()[0][0]["path"] ) == media_data.BASE64_IMAGE ) @@ -41,18 +41,18 @@ class TestExamples: examples = gr.Examples( [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()] ) - assert examples.processed_examples[0][0] == "hello" + assert examples.non_none_processed_examples.as_list()[0][0] == "hello" assert ( client_utils.encode_file_to_base64( - examples.processed_examples[0][1]["path"] + examples.non_none_processed_examples.as_list()[0][1]["path"] ) == media_data.BASE64_IMAGE ) def test_handle_directory(self, patched_cache_folder): examples = gr.Examples("test/test_files/images", gr.Image()) - assert len(examples.processed_examples) == 2 - for row in examples.processed_examples: + assert 
len(examples.non_none_processed_examples.as_list()) == 2 + for row in examples.non_none_processed_examples.as_list(): for output in row: assert ( client_utils.encode_file_to_base64(output["path"]) @@ -64,7 +64,7 @@ class TestExamples: "test/test_files/images_log", [gr.Image(label="im"), gr.Text()] ) ex = client_utils.traverse( - examples.processed_examples, + examples.non_none_processed_examples.as_list(), lambda s: client_utils.encode_file_to_base64(s["path"]), lambda x: isinstance(x, dict) and Path(x["path"]).exists(), ) diff --git a/test/test_utils.py b/test/test_utils.py index ba19a4bd35..06cc60c02d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import warnings from pathlib import Path from unittest.mock import MagicMock, patch +import numpy as np import pytest from typing_extensions import Literal @@ -14,6 +15,7 @@ from gradio import EventData, Request from gradio.external_utils import format_ner_list from gradio.utils import ( FileSize, + UnhashableKeyDict, _parse_file_size, abspath, append_unique_suffix, @@ -474,3 +476,62 @@ def test_parse_file_size(): assert _parse_file_size("1kb") == 1 * FileSize.KB assert _parse_file_size("1mb") == 1 * FileSize.MB assert _parse_file_size("505 Mb") == 505 * FileSize.MB + + +class TestUnhashableKeyDict: + def test_set_get_simple(self): + d = UnhashableKeyDict() + d["a"] = 1 + assert d["a"] == 1 + + def test_set_get_unhashable(self): + d = UnhashableKeyDict() + key = [1, 2, 3] + key2 = [1, 2, 3] + d[key] = "value" + assert d[key] == "value" + assert d[key2] == "value" + + def test_set_get_numpy_array(self): + d = UnhashableKeyDict() + key = np.array([1, 2, 3]) + key2 = np.array([1, 2, 3]) + d[key] = "numpy value" + assert d[key2] == "numpy value" + + def test_overwrite(self): + d = UnhashableKeyDict() + d["key"] = "old" + d["key"] = "new" + assert d["key"] == "new" + + def test_delete(self): + d = UnhashableKeyDict() + d["key"] = "value" + del d["key"] + assert len(d) == 0 + with 
pytest.raises(KeyError): + d["key"] + + def test_delete_nonexistent(self): + d = UnhashableKeyDict() + with pytest.raises(KeyError): + del d["nonexistent"] + + def test_len(self): + d = UnhashableKeyDict() + assert len(d) == 0 + d["a"] = 1 + d["b"] = 2 + assert len(d) == 2 + + def test_contains(self): + d = UnhashableKeyDict() + d["key"] = "value" + assert "key" in d + assert "nonexistent" not in d + + def test_get_nonexistent(self): + d = UnhashableKeyDict() + with pytest.raises(KeyError): + d["nonexistent"]