Allows updating the dataset of a gr.Examples (#8745)

* helpers

* add changeset

* changes

* add changeset

* changes

* tweak

* format

* example to docs

* add changeset

* fixes

* add tuple

* add changeset

* print

* format

* clean'

* clean

* format

* format backend

* fix backend tests

* format

* notebooks

* comment

* delete demo

* add changeset

* docstring

* docstring

* changes

* add changeset

* components

* changes

* changes

* format

* add test

* fix python test

* use deep_equal

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-07-15 09:19:52 -07:00 committed by GitHub
parent 2d179f647b
commit 4030f28af6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 241 additions and 69 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/dataframe": minor
"gradio": minor
"website": minor
---
feat:Allows updating the dataset of a `gr.Examples`

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "\n", "demo = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/logo.png\"),\n", " os.path.join(os.path.abspath(''), \"images/tower.jpg\"),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "new_samples = [\n", " [os.path.join(os.path.abspath(''), \"images/logo.png\")],\n", " [os.path.join(os.path.abspath(''), \"images/tower.jpg\")],\n", "]\n", "\n", "with gr.Blocks() as demo:\n", " interface = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " ],\n", " )\n", "\n", " btn = gr.Button(\"Update Examples\")\n", " btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, 
"nbformat_minor": 5}

View File

@ -5,19 +5,25 @@ import os
def image_mod(image):
return image.rotate(45)
new_samples = [
[os.path.join(os.path.dirname(__file__), "images/logo.png")],
[os.path.join(os.path.dirname(__file__), "images/tower.jpg")],
]
demo = gr.Interface(
image_mod,
gr.Image(type="pil"),
"image",
flagging_options=["blurry", "incorrect", "other"],
examples=[
os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
os.path.join(os.path.dirname(__file__), "images/logo.png"),
os.path.join(os.path.dirname(__file__), "images/tower.jpg"),
],
)
with gr.Blocks() as demo:
interface = gr.Interface(
image_mod,
gr.Image(type="pil"),
"image",
flagging_options=["blurry", "incorrect", "other"],
examples=[
os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
],
)
btn = gr.Button("Update Examples")
btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)
if __name__ == "__main__":
demo.launch()

View File

@ -1713,6 +1713,7 @@ Received outputs:
kwargs.pop("value", None)
kwargs.pop("__type__")
kwargs["render"] = False
state[block._id] = block.__class__(**kwargs)
prediction_value = postprocess_update_dict(

View File

@ -32,7 +32,7 @@ class Dataset(Component):
component_props: list[dict[str, Any]] | None = None,
samples: list[list[Any]] | None = None,
headers: list[str] | None = None,
type: Literal["values", "index"] = "values",
type: Literal["values", "index", "tuple"] = "values",
samples_per_page: int = 10,
visible: bool = True,
elem_id: str | None = None,
@ -51,7 +51,7 @@ class Dataset(Component):
components: Which component types to show in this dataset widget, can be passed in as a list of string names or Components instances. The following components are supported in a Dataset: Audio, Checkbox, CheckboxGroup, ColorPicker, Dataframe, Dropdown, File, HTML, Image, Markdown, Model3D, Number, Radio, Slider, Textbox, TimeSeries, Video
samples: a nested list of samples. Each sublist within the outer list represents a data sample, and each element within the sublist represents a value for each component
headers: Column headers in the Dataset widget, should be the same len as components. If not provided, inferred from component labels
type: 'values' if clicking on a sample should pass the value of the sample, or "index" if it should pass the index of the sample
type: "values" if clicking on a sample should pass the value of the sample, "index" if it should pass the index of the sample, or "tuple" if it should pass both the index and the value of the sample.
samples_per_page: how many examples to show per page.
visible: If False, component will be hidden.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
@ -95,8 +95,10 @@ class Dataset(Component):
self.proxy_url = proxy_url
for component in self._components:
component.proxy_url = proxy_url
self.samples = [[]] if samples is None else samples
for example in self.samples:
self.raw_samples = [[]] if samples is None else samples
self.samples: list[list] = []
for example in self.raw_samples:
self.samples.append([])
for i, (component, ex) in enumerate(zip(self._components, example)):
# If proxy_url is set, that means it is being loaded from an external Gradio app
# which means that the example has already been processed.
@ -104,9 +106,9 @@ class Dataset(Component):
# The `as_example()` method has been renamed to `process_example()` but we
# use the previous name to be backwards-compatible with previously-created
# custom components
example[i] = component.as_example(ex)
example[i] = processing_utils.move_files_to_cache(
example[i], component, keep_in_cache=True
self.samples[-1].append(component.as_example(ex))
self.samples[-1][i] = processing_utils.move_files_to_cache(
self.samples[-1][i], component, keep_in_cache=True
)
self.type = type
self.label = label
@ -137,19 +139,21 @@ class Dataset(Component):
return config
def preprocess(self, payload: int | None) -> int | list | None:
def preprocess(self, payload: int | None) -> int | list | tuple[int, list] | None:
"""
Parameters:
payload: the index of the selected example in the dataset
Returns:
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
Passes the selected sample as a `list` of data corresponding to each input component (if `type` is "values"), as an `int` index (if `type` is "index"), or as a `tuple` of the index and the data (if `type` is "tuple").
"""
if payload is None:
return None
if self.type == "index":
return payload
elif self.type == "values":
return self.samples[payload]
return self.raw_samples[payload]
elif self.type == "tuple":
return payload, self.raw_samples[payload]
def postprocess(self, sample: int | list | None) -> int | None:
"""

View File

@ -5,6 +5,7 @@ Defines helper methods useful for loading and caching Interface examples.
from __future__ import annotations
import ast
import copy
import csv
import inspect
import os
@ -28,6 +29,7 @@ from gradio.data_classes import GradioModel, GradioRootModel
from gradio.events import Dependency, EventData
from gradio.exceptions import Error
from gradio.flagging import CSVLogger
from gradio.utils import UnhashableKeyDict
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import Component
@ -239,6 +241,7 @@ class Examples:
self.examples = examples
self.non_none_examples = non_none_examples
self.inputs = inputs
self.input_has_examples = input_has_examples
self.inputs_with_examples = inputs_with_examples
self.outputs = outputs or []
self.fn = fn
@ -248,35 +251,15 @@ class Examples:
self.api_name: str | Literal[False] = api_name
self.batch = batch
self.example_labels = example_labels
with utils.set_directory(working_directory):
self.processed_examples = []
for example in examples:
sub = []
for component, sample in zip(inputs, example):
prediction_value = component.postprocess(sample)
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
prediction_value = prediction_value.model_dump()
prediction_value = processing_utils.move_files_to_cache(
prediction_value,
component,
postprocess=True,
)
sub.append(prediction_value)
self.processed_examples.append(sub)
self.non_none_processed_examples = [
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
for example in self.processed_examples
]
self.working_directory = working_directory
from gradio import components
with utils.set_directory(working_directory):
self.dataset = components.Dataset(
components=inputs_with_examples,
samples=non_none_examples,
type="index",
samples=copy.deepcopy(non_none_examples),
type="tuple",
label=label,
samples_per_page=examples_per_page,
elem_id=elem_id,
@ -290,13 +273,38 @@ class Examples:
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
self.run_on_click = run_on_click
self.cache_event: Dependency | None = None
self.non_none_processed_examples = UnhashableKeyDict()
if self.dataset.samples:
for index, example in enumerate(self.non_none_examples):
self.non_none_processed_examples[self.dataset.samples[index]] = (
self._get_processed_example(example)
)
def _get_processed_example(self, example):
    """
    Returns the postprocessed form of a single example, keeping only the values
    for inputs flagged in `self.input_has_examples`.

    If `example` is already a key of `self.non_none_processed_examples`, the stored
    result is returned as-is. Otherwise every sample is passed through its component's
    `postprocess()`, Gradio data models are dumped to plain data, and the result is
    passed through `processing_utils.move_files_to_cache`. All of this happens with
    `self.working_directory` set as the directory.

    Parameters:
        example: the sample values of one example, paired positionally with `self.inputs`
    Returns:
        the list of processed values whose `self.input_has_examples` flag is truthy
    """
    # NOTE(review): `example` is zipped against `self.inputs` and then filtered by
    # `self.input_has_examples`, so it presumably needs one entry per input (including
    # inputs without examples) -- confirm that every caller passes that shape.
    try:
        # Single lookup; a missing key simply means the example was not processed yet.
        return self.non_none_processed_examples[example]
    except KeyError:
        pass
    with utils.set_directory(self.working_directory):
        processed = []
        for component, sample in zip(self.inputs, example):
            value = component.postprocess(sample)
            if isinstance(value, (GradioRootModel, GradioModel)):
                value = value.model_dump()
            processed.append(
                processing_utils.move_files_to_cache(
                    value,
                    component,
                    postprocess=True,
                )
            )
        return [
            value
            for value, keep in zip(processed, self.input_has_examples)
            if keep
        ]
def create(self) -> None:
"""Caches the examples if self.cache_examples is True and creates the Dataset
component to hold the examples"""
async def load_example(example_id):
processed_example = self.non_none_processed_examples[example_id]
async def load_example(example_tuple):
_, example_value = example_tuple
processed_example = self._get_processed_example(example_value)
if len(self.inputs_with_examples) == 1:
return update(
value=processed_example[0],
@ -496,9 +504,9 @@ class Examples:
if self.outputs is None:
raise ValueError("self.outputs is missing")
for example_id in range(len(self.examples)):
print(f"Caching example {example_id + 1}/{len(self.examples)}")
processed_input = self.processed_examples[example_id]
for i, example in enumerate(self.examples):
print(f"Caching example {i + 1}/{len(self.examples)}")
processed_input = self._get_processed_example(example)
if self.batch:
processed_input = [[value] for value in processed_input]
with utils.MatplotlibBackendMananger():
@ -523,10 +531,11 @@ class Examples:
# method to be called independently of the create() method
blocks_config.fns.pop(self.load_input_event["id"])
def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + self.load_from_cache(example_id)
def load_example(example_tuple):
example_id, example_value = example_tuple
processed_example = self._get_processed_example(
example_value
) + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)
self.cache_event = self.load_input_event = self.dataset.click(

View File

@ -25,6 +25,7 @@ import urllib.parse
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import MutableMapping
from contextlib import contextmanager
from functools import wraps
from io import BytesIO
@ -1424,3 +1425,42 @@ def error_payload(
content["duration"] = error.duration
content["visible"] = error.visible
return content
class UnhashableKeyDict(MutableMapping):
    """
    A dictionary-like container whose keys do not have to be hashable.

    Entries are kept as (key, value) tuples in the `data` list, in insertion order.
    Every lookup is a linear scan over that list that compares keys with `deep_equal`.
    """

    def __init__(self):
        self.data = []

    def _find(self, key):
        """Returns the position of the first entry whose key deep-equals `key`, or None if there is none."""
        for position, (stored_key, _) in enumerate(self.data):
            if deep_equal(stored_key, key):
                return position
        return None

    def __getitem__(self, key):
        position = self._find(key)
        if position is None:
            raise KeyError(key)
        return self.data[position][1]

    def __setitem__(self, key, value):
        position = self._find(key)
        if position is None:
            self.data.append((key, value))
        else:
            # Overwrite in place; the newly supplied key object replaces the stored one.
            self.data[position] = (key, value)

    def __delitem__(self, key):
        position = self._find(key)
        if position is None:
            raise KeyError(key)
        del self.data[position]

    def __iter__(self):
        return (entry[0] for entry in self.data)

    def __len__(self):
        return len(self.data)

    def as_list(self):
        """Returns the stored values (without their keys) as a list, in entry order."""
        return [entry[1] for entry in self.data]

View File

@ -63,6 +63,31 @@ None
<ParamTable parameters={obj.attributes} />
### Examples
**Updating Examples**
In this demo, we show how to update the examples by updating the samples of the underlying dataset. Note that this only works if `cache_examples=False` as updating the underlying dataset does not update the cache.
```py
import gradio as gr
def update_examples(country):
if country == "USA":
return gr.Dataset(samples=[["Chicago"], ["Little Rock"], ["San Francisco"]])
else:
return gr.Dataset(samples=[["Islamabad"], ["Karachi"], ["Lahore"]])
with gr.Blocks() as demo:
dropdown = gr.Dropdown(label="Country", choices=["USA", "Pakistan"], value="USA")
textbox = gr.Textbox()
examples = gr.Examples([["Chicago"], ["Little Rock"], ["San Francisco"]], textbox)
dropdown.change(update_examples, dropdown, examples.dataset)
demo.launch()
```
{#if obj.demos && obj.demos.length > 0}
<!--- Demos -->
### Demos

View File

@ -0,0 +1,20 @@
import { test, expect } from "@gradio/tootils";
// Checks that replacing the samples of the Dataset behind the examples updates
// both the rendered gallery and the value that is loaded when a sample is clicked.
test("examples_get_updated_correctly", async ({ page }) => {
	// Locators are lazy and re-resolved on every use, so they are created once
	// up front and need neither `await` nor re-assignment.
	const first_example = page.locator(".gallery-item").first();
	const example_image = page.locator(".gallery-item").locator("img").first();
	const image = page.getByTestId("image").locator("img").first();

	// Clicking the first of the initial examples loads it into the input image.
	await first_example.click();
	// `toHaveAttribute` retries until the attribute matches, unlike a one-shot
	// `getAttribute` read, which can observe the image before `src` is updated.
	await expect(image).toHaveAttribute("src", /cheetah1\.jpg/);

	// After the update the first gallery item shows the first of the new samples.
	await page.getByRole("button", { name: "Update Examples" }).click();
	await expect(example_image).toHaveAttribute("src", /logo\.png/);

	// Clicking the updated example loads the new sample rather than the old one.
	await example_image.click();
	await expect(image).toHaveAttribute("src", /logo\.png/);
});

View File

@ -5,8 +5,7 @@
export let index: number;
let hovered = false;
let loaded_value: (string | number)[][] | string = value;
let loaded = Array.isArray(loaded_value);
let loaded = Array.isArray(value);
</script>
{#if loaded}
@ -19,11 +18,11 @@
on:mouseenter={() => (hovered = true)}
on:mouseleave={() => (hovered = false)}
>
{#if typeof loaded_value === "string"}
{loaded_value}
{#if typeof value === "string"}
{value}
{:else}
<table class="">
{#each loaded_value.slice(0, 3) as row, i}
{#each value.slice(0, 3) as row, i}
<tr>
{#each row.slice(0, 3) as cell, j}
<td>{cell}</td>

View File

@ -23,7 +23,7 @@ class TestDataset:
row = dataset.preprocess(1)
assert row[0] == 15
assert row[1] == "hi"
assert row[2]["path"].endswith("bus.png")
assert row[2].endswith("bus.png")
assert row[3] == "<i>Italics</i>"
assert row[4] == "*Italics*"

View File

@ -24,15 +24,15 @@ from gradio import helpers, utils
class TestExamples:
def test_handle_single_input(self, patched_cache_folder):
examples = gr.Examples(["hello", "hi"], gr.Textbox())
assert examples.processed_examples == [["hello"], ["hi"]]
assert examples.non_none_processed_examples.as_list() == [["hello"], ["hi"]]
examples = gr.Examples([["hello"]], gr.Textbox())
assert examples.processed_examples == [["hello"]]
assert examples.non_none_processed_examples.as_list() == [["hello"]]
examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
assert (
client_utils.encode_file_to_base64(
examples.processed_examples[0][0]["path"]
examples.non_none_processed_examples.as_list()[0][0]["path"]
)
== media_data.BASE64_IMAGE
)
@ -41,18 +41,18 @@ class TestExamples:
examples = gr.Examples(
[["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
)
assert examples.processed_examples[0][0] == "hello"
assert examples.non_none_processed_examples.as_list()[0][0] == "hello"
assert (
client_utils.encode_file_to_base64(
examples.processed_examples[0][1]["path"]
examples.non_none_processed_examples.as_list()[0][1]["path"]
)
== media_data.BASE64_IMAGE
)
def test_handle_directory(self, patched_cache_folder):
examples = gr.Examples("test/test_files/images", gr.Image())
assert len(examples.processed_examples) == 2
for row in examples.processed_examples:
assert len(examples.non_none_processed_examples.as_list()) == 2
for row in examples.non_none_processed_examples.as_list():
for output in row:
assert (
client_utils.encode_file_to_base64(output["path"])
@ -64,7 +64,7 @@ class TestExamples:
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
)
ex = client_utils.traverse(
examples.processed_examples,
examples.non_none_processed_examples.as_list(),
lambda s: client_utils.encode_file_to_base64(s["path"]),
lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
)

View File

@ -7,6 +7,7 @@ import warnings
from pathlib import Path
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from typing_extensions import Literal
@ -14,6 +15,7 @@ from gradio import EventData, Request
from gradio.external_utils import format_ner_list
from gradio.utils import (
FileSize,
UnhashableKeyDict,
_parse_file_size,
abspath,
append_unique_suffix,
@ -474,3 +476,62 @@ def test_parse_file_size():
assert _parse_file_size("1kb") == 1 * FileSize.KB
assert _parse_file_size("1mb") == 1 * FileSize.MB
assert _parse_file_size("505 Mb") == 505 * FileSize.MB
class TestUnhashableKeyDict:
    """Exercises the mapping operations of `UnhashableKeyDict` with hashable and unhashable keys."""

    def test_set_get_simple(self):
        """A value stored under a plain string key can be read back."""
        mapping = UnhashableKeyDict()
        mapping["a"] = 1
        assert mapping["a"] == 1

    def test_set_get_unhashable(self):
        """A list key is found both through the same object and through an equal copy."""
        mapping = UnhashableKeyDict()
        original_key = [1, 2, 3]
        equal_key = [1, 2, 3]
        mapping[original_key] = "value"
        assert mapping[original_key] == "value"
        assert mapping[equal_key] == "value"

    def test_set_get_numpy_array(self):
        """A numpy array key is found through a separate array with the same contents."""
        mapping = UnhashableKeyDict()
        original_key = np.array([1, 2, 3])
        equal_key = np.array([1, 2, 3])
        mapping[original_key] = "numpy value"
        assert mapping[equal_key] == "numpy value"

    def test_overwrite(self):
        """Assigning to an existing key replaces its value."""
        mapping = UnhashableKeyDict()
        mapping["key"] = "old"
        mapping["key"] = "new"
        assert mapping["key"] == "new"

    def test_delete(self):
        """Deleting a key removes the entry, and reading it afterwards raises KeyError."""
        mapping = UnhashableKeyDict()
        mapping["key"] = "value"
        del mapping["key"]
        assert len(mapping) == 0
        with pytest.raises(KeyError):
            mapping["key"]

    def test_delete_nonexistent(self):
        """Deleting a key that was never stored raises KeyError."""
        mapping = UnhashableKeyDict()
        with pytest.raises(KeyError):
            del mapping["nonexistent"]

    def test_len(self):
        """The length tracks the number of stored entries."""
        mapping = UnhashableKeyDict()
        assert len(mapping) == 0
        mapping["a"] = 1
        mapping["b"] = 2
        assert len(mapping) == 2

    def test_contains(self):
        """The `in` operator reports stored keys and rejects unknown ones."""
        mapping = UnhashableKeyDict()
        mapping["key"] = "value"
        assert "key" in mapping
        assert "nonexistent" not in mapping

    def test_get_nonexistent(self):
        """Reading a key that was never stored raises KeyError."""
        mapping = UnhashableKeyDict()
        with pytest.raises(KeyError):
            mapping["nonexistent"]