fix dataset update (#8581)

* fix dataset update

* revert

* add changeset

* add test

* add changeset

* changes

* add template

* add changeset

* fix docstring

* test postprocessing

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-06-19 19:00:57 -04:00 committed by GitHub
parent 2b0c1577b2
commit a1c21cb69a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 102 additions and 37 deletions

View File

@ -0,0 +1,7 @@
---
"@gradio/dataset": patch
"gradio": patch
"website": patch
---
fix:fix dataset update

View File

@ -1723,12 +1723,12 @@ Received outputs:
) from err
if block.stateful:
if not utils.is_update(predictions[i]):
if not utils.is_prop_update(predictions[i]):
state[block._id] = predictions[i]
output.append(None)
else:
prediction_value = predictions[i]
if utils.is_update(
if utils.is_prop_update(
prediction_value
): # if update is passed directly (deprecated), remove Nones
prediction_value = utils.delete_none(
@ -1738,7 +1738,7 @@ Received outputs:
if isinstance(prediction_value, Block):
prediction_value = prediction_value.constructor_args.copy()
prediction_value["__type__"] = "update"
if utils.is_update(prediction_value):
if utils.is_prop_update(prediction_value):
kwargs = state[block._id].constructor_args.copy()
kwargs.update(prediction_value)
kwargs.pop("value", None)

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import warnings
from typing import Any, Literal
from gradio_client.documentation import document
@ -17,7 +18,8 @@ from gradio.events import Events
@document()
class Dataset(Component):
"""
Creates a gallery or table to display data samples. This component is designed for internal use to display examples.
Creates a gallery or table to display data samples. This component is primarily designed for internal use to display examples.
However, it can also be used directly to display a dataset and let users select examples.
"""
EVENTS = [Events.click, Events.select]
@ -26,7 +28,7 @@ class Dataset(Component):
self,
*,
label: str | None = None,
components: list[Component] | list[str],
components: list[Component] | list[str] | None = None,
component_props: list[dict[str, Any]] | None = None,
samples: list[list[Any]] | None = None,
headers: list[str] | None = None,
@ -70,7 +72,7 @@ class Dataset(Component):
self.container = container
self.scale = scale
self.min_width = min_width
self._components = [get_component_instance(c) for c in components]
self._components = [get_component_instance(c) for c in components or []]
if component_props is None:
self.component_props = [
component.recover_kwargs(
@ -131,29 +133,39 @@ class Dataset(Component):
return config
def preprocess(self, payload: int) -> int | list | None:
def preprocess(self, payload: int | None) -> int | list | None:
"""
Parameters:
payload: the index of the selected example in the dataset
Returns:
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "values") or as an `int` index (if `type` is "index")
"""
if payload is None:
return None
if self.type == "index":
return payload
elif self.type == "values":
return self.samples[payload]
def postprocess(self, samples: list[list]) -> dict:
def postprocess(self, sample: int | list | None) -> int | None:
"""
Parameters:
samples: Expects a `list[list]` corresponding to the dataset data, can be used to update the dataset.
sample: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found.
Returns:
Returns the updated dataset data as a `dict` with the key "samples".
Returns the index of the sample in the dataset.
"""
return {
"samples": samples,
"__type__": "update",
}
if sample is None or isinstance(sample, int):
return sample
if isinstance(sample, list):
try:
index = self.samples.index(sample)
except ValueError:
index = None
warnings.warn(
"The `Dataset` component does not support updating the dataset data by providing "
"a set of list values. Instead, you should return a new Dataset(samples=...) object."
)
return index
def example_payload(self) -> Any:
return 0

View File

@ -164,7 +164,7 @@ class CSVLogger(FlaggingCallback):
) / client_utils.strip_invalid_filename_characters(
getattr(component, "label", None) or f"component {idx}"
)
if utils.is_update(sample):
if utils.is_prop_update(sample):
csv_data.append(str(sample))
else:
data = (

View File

@ -544,7 +544,7 @@ class Examples:
component, components.File
):
value_to_use = value_as_dict
if not utils.is_update(value_as_dict):
if not utils.is_prop_update(value_as_dict):
raise TypeError("value wasn't an update") # caught below
output.append(value_as_dict)
except (ValueError, TypeError, SyntaxError):

View File

@ -737,7 +737,7 @@ def validate_url(possible_url: str) -> bool:
return False
def is_update(val):
def is_prop_update(val):
return isinstance(val, dict) and "update" in val.get("__type__", "")

View File

@ -86,6 +86,40 @@ def predict(···) -> list[list]
<DemosSection demos={obj.demos} />
{/if}
### Examples
**Updating a Dataset**
In this example, we display a text dataset using `gr.Dataset` and then update it when the user clicks a button:
```py
import gradio as gr
philosophy_quotes = [
["I think therefore I am."],
["The unexamined life is not worth living."]
]
startup_quotes = [
["Ideas are easy. Implementation is hard"],
["Make mistakes faster."]
]
def show_startup_quotes():
return gr.Dataset(samples=startup_quotes)
with gr.Blocks() as demo:
textbox = gr.Textbox()
dataset = gr.Dataset(components=[textbox], samples=philosophy_quotes)
button = gr.Button()
button.click(show_startup_quotes, None, dataset)
demo.launch()
```
{#if obj.fns && obj.fns.length > 0}
<!--- Event Listeners -->
### Event Listeners
@ -97,3 +131,4 @@ def predict(···) -> list[list]
### Guides
<GuidesSection guides={obj.guides}/>
{/if}

View File

@ -12,7 +12,7 @@
>;
export let label = "Examples";
export let headers: string[];
export let samples: any[][];
export let samples: any[][] | null = null;
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
@ -34,7 +34,7 @@
: `${root}/file=`;
let page = 0;
$: gallery = components.length < 2;
let paginate = samples.length > samples_per_page;
let paginate = samples ? samples.length > samples_per_page : false;
let selected_samples: any[][];
let page_count: number;
@ -51,6 +51,7 @@
}
$: {
samples = samples || [];
paginate = samples.length > samples_per_page;
if (paginate) {
visible_pages = [];

View File

@ -43,27 +43,10 @@ class TestDataset:
assert dataset.samples == [["value 1"], ["value 2"]]
def test_postprocessing(self):
test_file_dir = Path(Path(__file__).parent, "test_files")
bus = Path(test_file_dir, "bus.png")
dataset = gr.Dataset(
components=["number", "textbox", "image", "html", "markdown"], type="index"
)
output = dataset.postprocess(
samples=[
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
],
)
assert output == {
"samples": [
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
],
"__type__": "update",
}
assert dataset.postprocess(1) == 1
@patch(

View File

@ -732,6 +732,33 @@ class TestBlocksPostprocessing:
):
await demo.postprocess_data(demo.fns[0], predictions=(1, 2), state=None)
@pytest.mark.asyncio
async def test_dataset_is_updated(self):
def update(value):
return value, gr.Dataset(samples=[["New A"], ["New B"]])
with gr.Blocks() as demo:
with gr.Row():
textbox = gr.Textbox()
dataset = gr.Dataset(
components=["text"], samples=[["Original"]], label="Saved Prompts"
)
dataset.click(update, inputs=[dataset], outputs=[textbox, dataset])
app, _, _ = demo.launch(prevent_thread_lock=True)
client = TestClient(app)
session_1 = client.post(
"/api/predict/",
json={"data": [0], "session_hash": "1", "fn_index": 0},
)
assert "Original" in session_1.json()["data"][0]
session_2 = client.post(
"/api/predict/",
json={"data": [0], "session_hash": "1", "fn_index": 0},
)
assert "New" in session_2.json()["data"][0]
class TestStateHolder:
@pytest.mark.asyncio