diff --git a/.changeset/real-grapes-accept.md b/.changeset/real-grapes-accept.md new file mode 100644 index 0000000000..ba1ae724e4 --- /dev/null +++ b/.changeset/real-grapes-accept.md @@ -0,0 +1,7 @@ +--- +"@gradio/dataframe": minor +"gradio": minor +"website": minor +--- + +feat:Allows updating the dataset of a `gr.Examples` diff --git a/demo/image_mod/run.ipynb b/demo/image_mod/run.ipynb index b7b80afe2b..2239efc6f3 100644 --- a/demo/image_mod/run.ipynb +++ b/demo/image_mod/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "\n", "demo = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/logo.png\"),\n", " os.path.join(os.path.abspath(''), \"images/tower.jpg\"),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "new_samples = [\n", " [os.path.join(os.path.abspath(''), \"images/logo.png\")],\n", " [os.path.join(os.path.abspath(''), \"images/tower.jpg\")],\n", "]\n", "\n", "with gr.Blocks() as demo:\n", " interface = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " ],\n", " )\n", "\n", " btn = gr.Button(\"Update Examples\")\n", " btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/image_mod/run.py b/demo/image_mod/run.py index 88bfaf63f0..fa28c0e1fc 100644 --- a/demo/image_mod/run.py +++ b/demo/image_mod/run.py @@ -5,19 +5,25 @@ import os def image_mod(image): return image.rotate(45) +new_samples = [ + [os.path.join(os.path.dirname(__file__), "images/logo.png")], + [os.path.join(os.path.dirname(__file__), "images/tower.jpg")], +] -demo = gr.Interface( - image_mod, - gr.Image(type="pil"), - "image", - flagging_options=["blurry", "incorrect", "other"], - examples=[ - os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"), - os.path.join(os.path.dirname(__file__), "images/lion.jpg"), - os.path.join(os.path.dirname(__file__), "images/logo.png"), - os.path.join(os.path.dirname(__file__), "images/tower.jpg"), - ], -) +with gr.Blocks() as demo: + interface = gr.Interface( + image_mod, + gr.Image(type="pil"), + "image", + flagging_options=["blurry", "incorrect", "other"], + examples=[ + os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"), + os.path.join(os.path.dirname(__file__), "images/lion.jpg"), + ], + ) + + btn = gr.Button("Update Examples") + btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset) if __name__ == "__main__": demo.launch() diff --git a/gradio/blocks.py b/gradio/blocks.py index a79d3b83a5..d2b771c747 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1713,6 +1713,7 @@ Received outputs: kwargs.pop("value", None) kwargs.pop("__type__") kwargs["render"] = False + state[block._id] = block.__class__(**kwargs) prediction_value = postprocess_update_dict( diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index 7eb4d2b1c9..1a093397f5 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -32,7 +32,7 @@ class Dataset(Component): component_props: list[dict[str, Any]] | None = None, samples: list[list[Any]] | None = None, headers: list[str] | None = None, - type: Literal["values", "index"] = "values", + type: Literal["values", "index", "tuple"] = "values", samples_per_page: int = 10, visible: bool = True, elem_id: str | None = None, @@ -51,7 +51,7 @@ class Dataset(Component): components: Which component types to show in this dataset widget, can be passed in as a list of string names or Components instances. The following components are supported in a Dataset: Audio, Checkbox, CheckboxGroup, ColorPicker, Dataframe, Dropdown, File, HTML, Image, Markdown, Model3D, Number, Radio, Slider, Textbox, TimeSeries, Video samples: a nested list of samples. Each sublist within the outer list represents a data sample, and each element within the sublist represents an value for each component headers: Column headers in the Dataset widget, should be the same len as components. If not provided, inferred from component labels - type: 'values' if clicking on a sample should pass the value of the sample, or "index" if it should pass the index of the sample + type: "values" if clicking on a sample should pass the value of the sample, "index" if it should pass the index of the sample, or "tuple" if it should pass both the index and the value of the sample. samples_per_page: how many examples to show per page. visible: If False, component will be hidden. elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. @@ -95,8 +95,10 @@ class Dataset(Component): self.proxy_url = proxy_url for component in self._components: component.proxy_url = proxy_url - self.samples = [[]] if samples is None else samples - for example in self.samples: + self.raw_samples = [[]] if samples is None else samples + self.samples: list[list] = [] + for example in self.raw_samples: + self.samples.append([]) for i, (component, ex) in enumerate(zip(self._components, example)): # If proxy_url is set, that means it is being loaded from an external Gradio app # which means that the example has already been processed. @@ -104,9 +106,9 @@ class Dataset(Component): # The `as_example()` method has been renamed to `process_example()` but we # use the previous name to be backwards-compatible with previously-created # custom components - example[i] = component.as_example(ex) - example[i] = processing_utils.move_files_to_cache( - example[i], component, keep_in_cache=True + self.samples[-1].append(component.as_example(ex)) + self.samples[-1][i] = processing_utils.move_files_to_cache( + self.samples[-1][i], component, keep_in_cache=True ) self.type = type self.label = label @@ -137,19 +139,21 @@ class Dataset(Component): return config - def preprocess(self, payload: int | None) -> int | list | None: + def preprocess(self, payload: int | None) -> int | list | tuple[int, list] | None: """ Parameters: payload: the index of the selected example in the dataset Returns: - Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index") + Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index"), or as a `tuple` of the index and the data (if `type` is "tuple"). """ if payload is None: return None if self.type == "index": return payload elif self.type == "values": - return self.samples[payload] + return self.raw_samples[payload] + elif self.type == "tuple": + return payload, self.raw_samples[payload] def postprocess(self, sample: int | list | None) -> int | None: """ diff --git a/gradio/helpers.py b/gradio/helpers.py index 4f5323d617..ac3d590c6b 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -5,6 +5,7 @@ Defines helper methods useful for loading and caching Interface examples. from __future__ import annotations import ast +import copy import csv import inspect import os @@ -28,6 +29,7 @@ from gradio.data_classes import GradioModel, GradioRootModel from gradio.events import Dependency, EventData from gradio.exceptions import Error from gradio.flagging import CSVLogger +from gradio.utils import UnhashableKeyDict if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). from gradio.components import Component @@ -239,6 +241,7 @@ class Examples: self.examples = examples self.non_none_examples = non_none_examples self.inputs = inputs + self.input_has_examples = input_has_examples self.inputs_with_examples = inputs_with_examples self.outputs = outputs or [] self.fn = fn @@ -248,35 +251,15 @@ class Examples: self.api_name: str | Literal[False] = api_name self.batch = batch self.example_labels = example_labels - - with utils.set_directory(working_directory): - self.processed_examples = [] - for example in examples: - sub = [] - for component, sample in zip(inputs, example): - prediction_value = component.postprocess(sample) - if isinstance(prediction_value, (GradioRootModel, GradioModel)): - prediction_value = prediction_value.model_dump() - prediction_value = processing_utils.move_files_to_cache( - prediction_value, - component, - postprocess=True, - ) - sub.append(prediction_value) - self.processed_examples.append(sub) - - self.non_none_processed_examples = [ - [ex for (ex, keep) in zip(example, input_has_examples) if keep] - for example in self.processed_examples - ] + self.working_directory = working_directory from gradio import components with utils.set_directory(working_directory): self.dataset = components.Dataset( components=inputs_with_examples, - samples=non_none_examples, - type="index", + samples=copy.deepcopy(non_none_examples), + type="tuple", label=label, samples_per_page=examples_per_page, elem_id=elem_id, @@ -290,13 +273,38 @@ class Examples: self.cached_indices_file = Path(self.cached_folder) / "indices.csv" self.run_on_click = run_on_click self.cache_event: Dependency | None = None + self.non_none_processed_examples = UnhashableKeyDict() + + if self.dataset.samples: + for index, example in enumerate(self.non_none_examples): + self.non_none_processed_examples[self.dataset.samples[index]] = ( + self._get_processed_example(example) + ) + + def _get_processed_example(self, example): + if example in self.non_none_processed_examples: + return self.non_none_processed_examples[example] + with utils.set_directory(self.working_directory): + sub = [] + for component, sample in zip(self.inputs, example): + prediction_value = component.postprocess(sample) + if isinstance(prediction_value, (GradioRootModel, GradioModel)): + prediction_value = prediction_value.model_dump() + prediction_value = processing_utils.move_files_to_cache( + prediction_value, + component, + postprocess=True, + ) + sub.append(prediction_value) + return [ex for (ex, keep) in zip(sub, self.input_has_examples) if keep] def create(self) -> None: """Caches the examples if self.cache_examples is True and creates the Dataset component to hold the examples""" - async def load_example(example_id): - processed_example = self.non_none_processed_examples[example_id] + async def load_example(example_tuple): + _, example_value = example_tuple + processed_example = self._get_processed_example(example_value) if len(self.inputs_with_examples) == 1: return update( value=processed_example[0], @@ -496,9 +504,9 @@ class Examples: if self.outputs is None: raise ValueError("self.outputs is missing") - for example_id in range(len(self.examples)): - print(f"Caching example {example_id + 1}/{len(self.examples)}") - processed_input = self.processed_examples[example_id] + for i, example in enumerate(self.examples): + print(f"Caching example {i + 1}/{len(self.examples)}") + processed_input = self._get_processed_example(example) if self.batch: processed_input = [[value] for value in processed_input] with utils.MatplotlibBackendMananger(): @@ -523,10 +531,11 @@ class Examples: # method to be called independently of the create() method blocks_config.fns.pop(self.load_input_event["id"]) - def load_example(example_id): - processed_example = self.non_none_processed_examples[ - example_id - ] + self.load_from_cache(example_id) + def load_example(example_tuple): + example_id, example_value = example_tuple + processed_example = self._get_processed_example( + example_value + ) + self.load_from_cache(example_id) return utils.resolve_singleton(processed_example) self.cache_event = self.load_input_event = self.dataset.click( diff --git a/gradio/utils.py b/gradio/utils.py index e2a47c633d..8a463095cf 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -25,6 +25,7 @@ import urllib.parse import warnings from abc import ABC, abstractmethod from collections import OrderedDict +from collections.abc import MutableMapping from contextlib import contextmanager from functools import wraps from io import BytesIO @@ -1424,3 +1425,42 @@ def error_payload( content["duration"] = error.duration content["visible"] = error.visible return content + + +class UnhashableKeyDict(MutableMapping): + """ + Essentially a list of key-value tuples that allows for keys that are not hashable, + but acts like a dictionary for convenience. + """ + + def __init__(self): + self.data = [] + + def __getitem__(self, key): + for k, v in self.data: + if deep_equal(k, key): + return v + raise KeyError(key) + + def __setitem__(self, key, value): + for i, (k, _) in enumerate(self.data): + if deep_equal(k, key): + self.data[i] = (key, value) + return + self.data.append((key, value)) + + def __delitem__(self, key): + for i, (k, _) in enumerate(self.data): + if deep_equal(k, key): + del self.data[i] + return + raise KeyError(key) + + def __iter__(self): + return (k for k, _ in self.data) + + def __len__(self): + return len(self.data) + + def as_list(self): + return [v for _, v in self.data] diff --git a/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx b/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx index e7f82d6ac6..ebc60da1f1 100644 --- a/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx +++ b/js/_website/src/lib/templates/gradio/04_helpers/11_examples.svx @@ -63,6 +63,31 @@ None +### Examples + +**Updating Examples** + +In this demo, we show how to update the examples by updating the samples of the underlying dataset. Note that this only works if `cache_examples=False` as updating the underlying dataset does not update the cache. + +```py +import gradio as gr + +def update_examples(country): + if country == "USA": + return gr.Dataset(samples=[["Chicago"], ["Little Rock"], ["San Francisco"]]) + else: + return gr.Dataset(samples=[["Islamabad"], ["Karachi"], ["Lahore"]]) + +with gr.Blocks() as demo: + dropdown = gr.Dropdown(label="Country", choices=["USA", "Pakistan"], value="USA") + textbox = gr.Textbox() + examples = gr.Examples([["Chicago"], ["Little Rock"], ["San Francisco"]], textbox) + dropdown.change(update_examples, dropdown, examples.dataset) + +demo.launch() +``` + + {#if obj.demos && obj.demos.length > 0} ### Demos diff --git a/js/app/test/image_mod.spec.ts b/js/app/test/image_mod.spec.ts new file mode 100644 index 0000000000..3ee40981f4 --- /dev/null +++ b/js/app/test/image_mod.spec.ts @@ -0,0 +1,20 @@ +import { test, expect } from "@gradio/tootils"; + +test("examples_get_updated_correctly", async ({ page }) => { + await page.locator(".gallery-item").first().click(); + let image = await page.getByTestId("image").locator("img").first(); + await expect(await image.getAttribute("src")).toContain("cheetah1.jpg"); + await page.getByRole("button", { name: "Update Examples" }).click(); + + let example_image; + await expect(async () => { + example_image = await page.locator(".gallery-item").locator("img").first(); + await expect(await example_image.getAttribute("src")).toContain("logo.png"); + }).toPass(); + + await example_image.click(); + await expect(async () => { + image = await page.getByTestId("image").locator("img").first(); + await expect(await image.getAttribute("src")).toContain("logo.png"); + }).toPass(); +}); diff --git a/js/dataframe/Example.svelte b/js/dataframe/Example.svelte index 84eb33de4f..faf029da18 100644 --- a/js/dataframe/Example.svelte +++ b/js/dataframe/Example.svelte @@ -5,8 +5,7 @@ export let index: number; let hovered = false; - let loaded_value: (string | number)[][] | string = value; - let loaded = Array.isArray(loaded_value); + let loaded = Array.isArray(value); {#if loaded} @@ -19,11 +18,11 @@ on:mouseenter={() => (hovered = true)} on:mouseleave={() => (hovered = false)} > - {#if typeof loaded_value === "string"} - {loaded_value} + {#if typeof value === "string"} + {value} {:else} - {#each loaded_value.slice(0, 3) as row, i} + {#each value.slice(0, 3) as row, i} {#each row.slice(0, 3) as cell, j} diff --git a/test/components/test_dataset.py b/test/components/test_dataset.py index 3101784777..2eee6b56e1 100644 --- a/test/components/test_dataset.py +++ b/test/components/test_dataset.py @@ -23,7 +23,7 @@ class TestDataset: row = dataset.preprocess(1) assert row[0] == 15 assert row[1] == "hi" - assert row[2]["path"].endswith("bus.png") + assert row[2].endswith("bus.png") assert row[3] == "Italics" assert row[4] == "*Italics*" diff --git a/test/test_helpers.py b/test/test_helpers.py index b48a405506..81b951f757 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -24,15 +24,15 @@ from gradio import helpers, utils class TestExamples: def test_handle_single_input(self, patched_cache_folder): examples = gr.Examples(["hello", "hi"], gr.Textbox()) - assert examples.processed_examples == [["hello"], ["hi"]] + assert examples.non_none_processed_examples.as_list() == [["hello"], ["hi"]] examples = gr.Examples([["hello"]], gr.Textbox()) - assert examples.processed_examples == [["hello"]] + assert examples.non_none_processed_examples.as_list() == [["hello"]] examples = gr.Examples(["test/test_files/bus.png"], gr.Image()) assert ( client_utils.encode_file_to_base64( - examples.processed_examples[0][0]["path"] + examples.non_none_processed_examples.as_list()[0][0]["path"] ) == media_data.BASE64_IMAGE ) @@ -41,18 +41,18 @@ class TestExamples: examples = gr.Examples( [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()] ) - assert examples.processed_examples[0][0] == "hello" + assert examples.non_none_processed_examples.as_list()[0][0] == "hello" assert ( client_utils.encode_file_to_base64( - examples.processed_examples[0][1]["path"] + examples.non_none_processed_examples.as_list()[0][1]["path"] ) == media_data.BASE64_IMAGE ) def test_handle_directory(self, patched_cache_folder): examples = gr.Examples("test/test_files/images", gr.Image()) - assert len(examples.processed_examples) == 2 - for row in examples.processed_examples: + assert len(examples.non_none_processed_examples.as_list()) == 2 + for row in examples.non_none_processed_examples.as_list(): for output in row: assert ( client_utils.encode_file_to_base64(output["path"]) @@ -64,7 +64,7 @@ class TestExamples: "test/test_files/images_log", [gr.Image(label="im"), gr.Text()] ) ex = client_utils.traverse( - examples.processed_examples, + examples.non_none_processed_examples.as_list(), lambda s: client_utils.encode_file_to_base64(s["path"]), lambda x: isinstance(x, dict) and Path(x["path"]).exists(), ) diff --git a/test/test_utils.py b/test/test_utils.py index ba19a4bd35..06cc60c02d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import warnings from pathlib import Path from unittest.mock import MagicMock, patch +import numpy as np import pytest from typing_extensions import Literal @@ -14,6 +15,7 @@ from gradio import EventData, Request from gradio.external_utils import format_ner_list from gradio.utils import ( FileSize, + UnhashableKeyDict, _parse_file_size, abspath, append_unique_suffix, @@ -474,3 +476,62 @@ def test_parse_file_size(): assert _parse_file_size("1kb") == 1 * FileSize.KB assert _parse_file_size("1mb") == 1 * FileSize.MB assert _parse_file_size("505 Mb") == 505 * FileSize.MB + + +class TestUnhashableKeyDict: + def test_set_get_simple(self): + d = UnhashableKeyDict() + d["a"] = 1 + assert d["a"] == 1 + + def test_set_get_unhashable(self): + d = UnhashableKeyDict() + key = [1, 2, 3] + key2 = [1, 2, 3] + d[key] = "value" + assert d[key] == "value" + assert d[key2] == "value" + + def test_set_get_numpy_array(self): + d = UnhashableKeyDict() + key = np.array([1, 2, 3]) + key2 = np.array([1, 2, 3]) + d[key] = "numpy value" + assert d[key2] == "numpy value" + + def test_overwrite(self): + d = UnhashableKeyDict() + d["key"] = "old" + d["key"] = "new" + assert d["key"] == "new" + + def test_delete(self): + d = UnhashableKeyDict() + d["key"] = "value" + del d["key"] + assert len(d) == 0 + with pytest.raises(KeyError): + d["key"] + + def test_delete_nonexistent(self): + d = UnhashableKeyDict() + with pytest.raises(KeyError): + del d["nonexistent"] + + def test_len(self): + d = UnhashableKeyDict() + assert len(d) == 0 + d["a"] = 1 + d["b"] = 2 + assert len(d) == 2 + + def test_contains(self): + d = UnhashableKeyDict() + d["key"] = "value" + assert "key" in d + assert "nonexistent" not in d + + def test_get_nonexistent(self): + d = UnhashableKeyDict() + with pytest.raises(KeyError): + d["nonexistent"]
{cell}