mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Allows updating the dataset of a gr.Examples
(#8745)
* helpers * add changeset * changes * add changeset * changes * tweak * format * example to docs * add changeset * fixes * add tuple * add changeset * print * format * clean' * clean * format * format backend * fix backend tests * format * notebooks * comment * delete demo * add changeset * docstring * docstring * changes * add changeset * components * changes * changes * format * add test * fix python test * use deep_equal --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
2d179f647b
commit
4030f28af6
7
.changeset/real-grapes-accept.md
Normal file
7
.changeset/real-grapes-accept.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/dataframe": minor
|
||||
"gradio": minor
|
||||
"website": minor
|
||||
---
|
||||
|
||||
feat:Allows updating the dataset of a `gr.Examples`
|
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "\n", "demo = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/logo.png\"),\n", " os.path.join(os.path.abspath(''), \"images/tower.jpg\"),\n", " ],\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: image_mod"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/lion.jpg\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/logo.png\n", "!wget -q -O images/tower.jpg https://github.com/gradio-app/gradio/raw/main/demo/image_mod/images/tower.jpg"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def image_mod(image):\n", " return image.rotate(45)\n", "\n", "new_samples = [\n", " [os.path.join(os.path.abspath(''), \"images/logo.png\")],\n", " [os.path.join(os.path.abspath(''), \"images/tower.jpg\")],\n", "]\n", "\n", "with gr.Blocks() as demo:\n", " interface = gr.Interface(\n", " image_mod,\n", " gr.Image(type=\"pil\"),\n", " \"image\",\n", " flagging_options=[\"blurry\", \"incorrect\", \"other\"],\n", " examples=[\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " ],\n", " )\n", "\n", " btn = gr.Button(\"Update Examples\")\n", " btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -5,19 +5,25 @@ import os
|
||||
def image_mod(image):
|
||||
return image.rotate(45)
|
||||
|
||||
new_samples = [
|
||||
[os.path.join(os.path.dirname(__file__), "images/logo.png")],
|
||||
[os.path.join(os.path.dirname(__file__), "images/tower.jpg")],
|
||||
]
|
||||
|
||||
demo = gr.Interface(
|
||||
image_mod,
|
||||
gr.Image(type="pil"),
|
||||
"image",
|
||||
flagging_options=["blurry", "incorrect", "other"],
|
||||
examples=[
|
||||
os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
|
||||
os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
|
||||
os.path.join(os.path.dirname(__file__), "images/logo.png"),
|
||||
os.path.join(os.path.dirname(__file__), "images/tower.jpg"),
|
||||
],
|
||||
)
|
||||
with gr.Blocks() as demo:
|
||||
interface = gr.Interface(
|
||||
image_mod,
|
||||
gr.Image(type="pil"),
|
||||
"image",
|
||||
flagging_options=["blurry", "incorrect", "other"],
|
||||
examples=[
|
||||
os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
|
||||
os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
|
||||
],
|
||||
)
|
||||
|
||||
btn = gr.Button("Update Examples")
|
||||
btn.click(lambda : gr.Dataset(samples=new_samples), None, interface.examples_handler.dataset)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
||||
|
@ -1713,6 +1713,7 @@ Received outputs:
|
||||
kwargs.pop("value", None)
|
||||
kwargs.pop("__type__")
|
||||
kwargs["render"] = False
|
||||
|
||||
state[block._id] = block.__class__(**kwargs)
|
||||
|
||||
prediction_value = postprocess_update_dict(
|
||||
|
@ -32,7 +32,7 @@ class Dataset(Component):
|
||||
component_props: list[dict[str, Any]] | None = None,
|
||||
samples: list[list[Any]] | None = None,
|
||||
headers: list[str] | None = None,
|
||||
type: Literal["values", "index"] = "values",
|
||||
type: Literal["values", "index", "tuple"] = "values",
|
||||
samples_per_page: int = 10,
|
||||
visible: bool = True,
|
||||
elem_id: str | None = None,
|
||||
@ -51,7 +51,7 @@ class Dataset(Component):
|
||||
components: Which component types to show in this dataset widget, can be passed in as a list of string names or Components instances. The following components are supported in a Dataset: Audio, Checkbox, CheckboxGroup, ColorPicker, Dataframe, Dropdown, File, HTML, Image, Markdown, Model3D, Number, Radio, Slider, Textbox, TimeSeries, Video
|
||||
samples: a nested list of samples. Each sublist within the outer list represents a data sample, and each element within the sublist represents an value for each component
|
||||
headers: Column headers in the Dataset widget, should be the same len as components. If not provided, inferred from component labels
|
||||
type: 'values' if clicking on a sample should pass the value of the sample, or "index" if it should pass the index of the sample
|
||||
type: "values" if clicking on a sample should pass the value of the sample, "index" if it should pass the index of the sample, or "tuple" if it should pass both the index and the value of the sample.
|
||||
samples_per_page: how many examples to show per page.
|
||||
visible: If False, component will be hidden.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
@ -95,8 +95,10 @@ class Dataset(Component):
|
||||
self.proxy_url = proxy_url
|
||||
for component in self._components:
|
||||
component.proxy_url = proxy_url
|
||||
self.samples = [[]] if samples is None else samples
|
||||
for example in self.samples:
|
||||
self.raw_samples = [[]] if samples is None else samples
|
||||
self.samples: list[list] = []
|
||||
for example in self.raw_samples:
|
||||
self.samples.append([])
|
||||
for i, (component, ex) in enumerate(zip(self._components, example)):
|
||||
# If proxy_url is set, that means it is being loaded from an external Gradio app
|
||||
# which means that the example has already been processed.
|
||||
@ -104,9 +106,9 @@ class Dataset(Component):
|
||||
# The `as_example()` method has been renamed to `process_example()` but we
|
||||
# use the previous name to be backwards-compatible with previously-created
|
||||
# custom components
|
||||
example[i] = component.as_example(ex)
|
||||
example[i] = processing_utils.move_files_to_cache(
|
||||
example[i], component, keep_in_cache=True
|
||||
self.samples[-1].append(component.as_example(ex))
|
||||
self.samples[-1][i] = processing_utils.move_files_to_cache(
|
||||
self.samples[-1][i], component, keep_in_cache=True
|
||||
)
|
||||
self.type = type
|
||||
self.label = label
|
||||
@ -137,19 +139,21 @@ class Dataset(Component):
|
||||
|
||||
return config
|
||||
|
||||
def preprocess(self, payload: int | None) -> int | list | None:
|
||||
def preprocess(self, payload: int | None) -> int | list | tuple[int, list] | None:
|
||||
"""
|
||||
Parameters:
|
||||
payload: the index of the selected example in the dataset
|
||||
Returns:
|
||||
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
|
||||
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index"), or as a `tuple` of the index and the data (if `type` is "tuple").
|
||||
"""
|
||||
if payload is None:
|
||||
return None
|
||||
if self.type == "index":
|
||||
return payload
|
||||
elif self.type == "values":
|
||||
return self.samples[payload]
|
||||
return self.raw_samples[payload]
|
||||
elif self.type == "tuple":
|
||||
return payload, self.raw_samples[payload]
|
||||
|
||||
def postprocess(self, sample: int | list | None) -> int | None:
|
||||
"""
|
||||
|
@ -5,6 +5,7 @@ Defines helper methods useful for loading and caching Interface examples.
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import csv
|
||||
import inspect
|
||||
import os
|
||||
@ -28,6 +29,7 @@ from gradio.data_classes import GradioModel, GradioRootModel
|
||||
from gradio.events import Dependency, EventData
|
||||
from gradio.exceptions import Error
|
||||
from gradio.flagging import CSVLogger
|
||||
from gradio.utils import UnhashableKeyDict
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio.components import Component
|
||||
@ -239,6 +241,7 @@ class Examples:
|
||||
self.examples = examples
|
||||
self.non_none_examples = non_none_examples
|
||||
self.inputs = inputs
|
||||
self.input_has_examples = input_has_examples
|
||||
self.inputs_with_examples = inputs_with_examples
|
||||
self.outputs = outputs or []
|
||||
self.fn = fn
|
||||
@ -248,35 +251,15 @@ class Examples:
|
||||
self.api_name: str | Literal[False] = api_name
|
||||
self.batch = batch
|
||||
self.example_labels = example_labels
|
||||
|
||||
with utils.set_directory(working_directory):
|
||||
self.processed_examples = []
|
||||
for example in examples:
|
||||
sub = []
|
||||
for component, sample in zip(inputs, example):
|
||||
prediction_value = component.postprocess(sample)
|
||||
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
|
||||
prediction_value = prediction_value.model_dump()
|
||||
prediction_value = processing_utils.move_files_to_cache(
|
||||
prediction_value,
|
||||
component,
|
||||
postprocess=True,
|
||||
)
|
||||
sub.append(prediction_value)
|
||||
self.processed_examples.append(sub)
|
||||
|
||||
self.non_none_processed_examples = [
|
||||
[ex for (ex, keep) in zip(example, input_has_examples) if keep]
|
||||
for example in self.processed_examples
|
||||
]
|
||||
self.working_directory = working_directory
|
||||
|
||||
from gradio import components
|
||||
|
||||
with utils.set_directory(working_directory):
|
||||
self.dataset = components.Dataset(
|
||||
components=inputs_with_examples,
|
||||
samples=non_none_examples,
|
||||
type="index",
|
||||
samples=copy.deepcopy(non_none_examples),
|
||||
type="tuple",
|
||||
label=label,
|
||||
samples_per_page=examples_per_page,
|
||||
elem_id=elem_id,
|
||||
@ -290,13 +273,38 @@ class Examples:
|
||||
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
|
||||
self.run_on_click = run_on_click
|
||||
self.cache_event: Dependency | None = None
|
||||
self.non_none_processed_examples = UnhashableKeyDict()
|
||||
|
||||
if self.dataset.samples:
|
||||
for index, example in enumerate(self.non_none_examples):
|
||||
self.non_none_processed_examples[self.dataset.samples[index]] = (
|
||||
self._get_processed_example(example)
|
||||
)
|
||||
|
||||
def _get_processed_example(self, example):
|
||||
if example in self.non_none_processed_examples:
|
||||
return self.non_none_processed_examples[example]
|
||||
with utils.set_directory(self.working_directory):
|
||||
sub = []
|
||||
for component, sample in zip(self.inputs, example):
|
||||
prediction_value = component.postprocess(sample)
|
||||
if isinstance(prediction_value, (GradioRootModel, GradioModel)):
|
||||
prediction_value = prediction_value.model_dump()
|
||||
prediction_value = processing_utils.move_files_to_cache(
|
||||
prediction_value,
|
||||
component,
|
||||
postprocess=True,
|
||||
)
|
||||
sub.append(prediction_value)
|
||||
return [ex for (ex, keep) in zip(sub, self.input_has_examples) if keep]
|
||||
|
||||
def create(self) -> None:
|
||||
"""Caches the examples if self.cache_examples is True and creates the Dataset
|
||||
component to hold the examples"""
|
||||
|
||||
async def load_example(example_id):
|
||||
processed_example = self.non_none_processed_examples[example_id]
|
||||
async def load_example(example_tuple):
|
||||
_, example_value = example_tuple
|
||||
processed_example = self._get_processed_example(example_value)
|
||||
if len(self.inputs_with_examples) == 1:
|
||||
return update(
|
||||
value=processed_example[0],
|
||||
@ -496,9 +504,9 @@ class Examples:
|
||||
|
||||
if self.outputs is None:
|
||||
raise ValueError("self.outputs is missing")
|
||||
for example_id in range(len(self.examples)):
|
||||
print(f"Caching example {example_id + 1}/{len(self.examples)}")
|
||||
processed_input = self.processed_examples[example_id]
|
||||
for i, example in enumerate(self.examples):
|
||||
print(f"Caching example {i + 1}/{len(self.examples)}")
|
||||
processed_input = self._get_processed_example(example)
|
||||
if self.batch:
|
||||
processed_input = [[value] for value in processed_input]
|
||||
with utils.MatplotlibBackendMananger():
|
||||
@ -523,10 +531,11 @@ class Examples:
|
||||
# method to be called independently of the create() method
|
||||
blocks_config.fns.pop(self.load_input_event["id"])
|
||||
|
||||
def load_example(example_id):
|
||||
processed_example = self.non_none_processed_examples[
|
||||
example_id
|
||||
] + self.load_from_cache(example_id)
|
||||
def load_example(example_tuple):
|
||||
example_id, example_value = example_tuple
|
||||
processed_example = self._get_processed_example(
|
||||
example_value
|
||||
) + self.load_from_cache(example_id)
|
||||
return utils.resolve_singleton(processed_example)
|
||||
|
||||
self.cache_event = self.load_input_event = self.dataset.click(
|
||||
|
@ -25,6 +25,7 @@ import urllib.parse
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from collections.abc import MutableMapping
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from io import BytesIO
|
||||
@ -1424,3 +1425,42 @@ def error_payload(
|
||||
content["duration"] = error.duration
|
||||
content["visible"] = error.visible
|
||||
return content
|
||||
|
||||
|
||||
class UnhashableKeyDict(MutableMapping):
|
||||
"""
|
||||
Essentially a list of key-value tuples that allows for keys that are not hashable,
|
||||
but acts like a dictionary for convenience.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.data = []
|
||||
|
||||
def __getitem__(self, key):
|
||||
for k, v in self.data:
|
||||
if deep_equal(k, key):
|
||||
return v
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
for i, (k, _) in enumerate(self.data):
|
||||
if deep_equal(k, key):
|
||||
self.data[i] = (key, value)
|
||||
return
|
||||
self.data.append((key, value))
|
||||
|
||||
def __delitem__(self, key):
|
||||
for i, (k, _) in enumerate(self.data):
|
||||
if deep_equal(k, key):
|
||||
del self.data[i]
|
||||
return
|
||||
raise KeyError(key)
|
||||
|
||||
def __iter__(self):
|
||||
return (k for k, _ in self.data)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def as_list(self):
|
||||
return [v for _, v in self.data]
|
||||
|
@ -63,6 +63,31 @@ None
|
||||
<ParamTable parameters={obj.attributes} />
|
||||
|
||||
|
||||
### Examples
|
||||
|
||||
**Updating Examples**
|
||||
|
||||
In this demo, we show how to update the examples by updating the samples of the underlying dataset. Note that this only works if `cache_examples=False` as updating the underlying dataset does not update the cache.
|
||||
|
||||
```py
|
||||
import gradio as gr
|
||||
|
||||
def update_examples(country):
|
||||
if country == "USA":
|
||||
return gr.Dataset(samples=[["Chicago"], ["Little Rock"], ["San Francisco"]])
|
||||
else:
|
||||
return gr.Dataset(samples=[["Islamabad"], ["Karachi"], ["Lahore"]])
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
dropdown = gr.Dropdown(label="Country", choices=["USA", "Pakistan"], value="USA")
|
||||
textbox = gr.Textbox()
|
||||
examples = gr.Examples([["Chicago"], ["Little Rock"], ["San Francisco"]], textbox)
|
||||
dropdown.change(update_examples, dropdown, examples.dataset)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
|
||||
{#if obj.demos && obj.demos.length > 0}
|
||||
<!--- Demos -->
|
||||
### Demos
|
||||
|
20
js/app/test/image_mod.spec.ts
Normal file
20
js/app/test/image_mod.spec.ts
Normal file
@ -0,0 +1,20 @@
|
||||
import { test, expect } from "@gradio/tootils";
|
||||
|
||||
test("examples_get_updated_correctly", async ({ page }) => {
|
||||
await page.locator(".gallery-item").first().click();
|
||||
let image = await page.getByTestId("image").locator("img").first();
|
||||
await expect(await image.getAttribute("src")).toContain("cheetah1.jpg");
|
||||
await page.getByRole("button", { name: "Update Examples" }).click();
|
||||
|
||||
let example_image;
|
||||
await expect(async () => {
|
||||
example_image = await page.locator(".gallery-item").locator("img").first();
|
||||
await expect(await example_image.getAttribute("src")).toContain("logo.png");
|
||||
}).toPass();
|
||||
|
||||
await example_image.click();
|
||||
await expect(async () => {
|
||||
image = await page.getByTestId("image").locator("img").first();
|
||||
await expect(await image.getAttribute("src")).toContain("logo.png");
|
||||
}).toPass();
|
||||
});
|
@ -5,8 +5,7 @@
|
||||
export let index: number;
|
||||
|
||||
let hovered = false;
|
||||
let loaded_value: (string | number)[][] | string = value;
|
||||
let loaded = Array.isArray(loaded_value);
|
||||
let loaded = Array.isArray(value);
|
||||
</script>
|
||||
|
||||
{#if loaded}
|
||||
@ -19,11 +18,11 @@
|
||||
on:mouseenter={() => (hovered = true)}
|
||||
on:mouseleave={() => (hovered = false)}
|
||||
>
|
||||
{#if typeof loaded_value === "string"}
|
||||
{loaded_value}
|
||||
{#if typeof value === "string"}
|
||||
{value}
|
||||
{:else}
|
||||
<table class="">
|
||||
{#each loaded_value.slice(0, 3) as row, i}
|
||||
{#each value.slice(0, 3) as row, i}
|
||||
<tr>
|
||||
{#each row.slice(0, 3) as cell, j}
|
||||
<td>{cell}</td>
|
||||
|
@ -23,7 +23,7 @@ class TestDataset:
|
||||
row = dataset.preprocess(1)
|
||||
assert row[0] == 15
|
||||
assert row[1] == "hi"
|
||||
assert row[2]["path"].endswith("bus.png")
|
||||
assert row[2].endswith("bus.png")
|
||||
assert row[3] == "<i>Italics</i>"
|
||||
assert row[4] == "*Italics*"
|
||||
|
||||
|
@ -24,15 +24,15 @@ from gradio import helpers, utils
|
||||
class TestExamples:
|
||||
def test_handle_single_input(self, patched_cache_folder):
|
||||
examples = gr.Examples(["hello", "hi"], gr.Textbox())
|
||||
assert examples.processed_examples == [["hello"], ["hi"]]
|
||||
assert examples.non_none_processed_examples.as_list() == [["hello"], ["hi"]]
|
||||
|
||||
examples = gr.Examples([["hello"]], gr.Textbox())
|
||||
assert examples.processed_examples == [["hello"]]
|
||||
assert examples.non_none_processed_examples.as_list() == [["hello"]]
|
||||
|
||||
examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
|
||||
assert (
|
||||
client_utils.encode_file_to_base64(
|
||||
examples.processed_examples[0][0]["path"]
|
||||
examples.non_none_processed_examples.as_list()[0][0]["path"]
|
||||
)
|
||||
== media_data.BASE64_IMAGE
|
||||
)
|
||||
@ -41,18 +41,18 @@ class TestExamples:
|
||||
examples = gr.Examples(
|
||||
[["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
|
||||
)
|
||||
assert examples.processed_examples[0][0] == "hello"
|
||||
assert examples.non_none_processed_examples.as_list()[0][0] == "hello"
|
||||
assert (
|
||||
client_utils.encode_file_to_base64(
|
||||
examples.processed_examples[0][1]["path"]
|
||||
examples.non_none_processed_examples.as_list()[0][1]["path"]
|
||||
)
|
||||
== media_data.BASE64_IMAGE
|
||||
)
|
||||
|
||||
def test_handle_directory(self, patched_cache_folder):
|
||||
examples = gr.Examples("test/test_files/images", gr.Image())
|
||||
assert len(examples.processed_examples) == 2
|
||||
for row in examples.processed_examples:
|
||||
assert len(examples.non_none_processed_examples.as_list()) == 2
|
||||
for row in examples.non_none_processed_examples.as_list():
|
||||
for output in row:
|
||||
assert (
|
||||
client_utils.encode_file_to_base64(output["path"])
|
||||
@ -64,7 +64,7 @@ class TestExamples:
|
||||
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
|
||||
)
|
||||
ex = client_utils.traverse(
|
||||
examples.processed_examples,
|
||||
examples.non_none_processed_examples.as_list(),
|
||||
lambda s: client_utils.encode_file_to_base64(s["path"]),
|
||||
lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
|
||||
)
|
||||
|
@ -7,6 +7,7 @@ import warnings
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from typing_extensions import Literal
|
||||
|
||||
@ -14,6 +15,7 @@ from gradio import EventData, Request
|
||||
from gradio.external_utils import format_ner_list
|
||||
from gradio.utils import (
|
||||
FileSize,
|
||||
UnhashableKeyDict,
|
||||
_parse_file_size,
|
||||
abspath,
|
||||
append_unique_suffix,
|
||||
@ -474,3 +476,62 @@ def test_parse_file_size():
|
||||
assert _parse_file_size("1kb") == 1 * FileSize.KB
|
||||
assert _parse_file_size("1mb") == 1 * FileSize.MB
|
||||
assert _parse_file_size("505 Mb") == 505 * FileSize.MB
|
||||
|
||||
|
||||
class TestUnhashableKeyDict:
|
||||
def test_set_get_simple(self):
|
||||
d = UnhashableKeyDict()
|
||||
d["a"] = 1
|
||||
assert d["a"] == 1
|
||||
|
||||
def test_set_get_unhashable(self):
|
||||
d = UnhashableKeyDict()
|
||||
key = [1, 2, 3]
|
||||
key2 = [1, 2, 3]
|
||||
d[key] = "value"
|
||||
assert d[key] == "value"
|
||||
assert d[key2] == "value"
|
||||
|
||||
def test_set_get_numpy_array(self):
|
||||
d = UnhashableKeyDict()
|
||||
key = np.array([1, 2, 3])
|
||||
key2 = np.array([1, 2, 3])
|
||||
d[key] = "numpy value"
|
||||
assert d[key2] == "numpy value"
|
||||
|
||||
def test_overwrite(self):
|
||||
d = UnhashableKeyDict()
|
||||
d["key"] = "old"
|
||||
d["key"] = "new"
|
||||
assert d["key"] == "new"
|
||||
|
||||
def test_delete(self):
|
||||
d = UnhashableKeyDict()
|
||||
d["key"] = "value"
|
||||
del d["key"]
|
||||
assert len(d) == 0
|
||||
with pytest.raises(KeyError):
|
||||
d["key"]
|
||||
|
||||
def test_delete_nonexistent(self):
|
||||
d = UnhashableKeyDict()
|
||||
with pytest.raises(KeyError):
|
||||
del d["nonexistent"]
|
||||
|
||||
def test_len(self):
|
||||
d = UnhashableKeyDict()
|
||||
assert len(d) == 0
|
||||
d["a"] = 1
|
||||
d["b"] = 2
|
||||
assert len(d) == 2
|
||||
|
||||
def test_contains(self):
|
||||
d = UnhashableKeyDict()
|
||||
d["key"] = "value"
|
||||
assert "key" in d
|
||||
assert "nonexistent" not in d
|
||||
|
||||
def test_get_nonexistent(self):
|
||||
d = UnhashableKeyDict()
|
||||
with pytest.raises(KeyError):
|
||||
d["nonexistent"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user