mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
fix dataset update (#8581)
* fix dataset update * revert' * add changeset * add test * add changeset * changes * add template * add changeset * fix docstring * test postprocessing --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
2b0c1577b2
commit
a1c21cb69a
7
.changeset/soft-worms-remain.md
Normal file
7
.changeset/soft-worms-remain.md
Normal file
@ -0,0 +1,7 @@
|
||||
---
|
||||
"@gradio/dataset": patch
|
||||
"gradio": patch
|
||||
"website": patch
|
||||
---
|
||||
|
||||
fix:fix dataset update
|
@ -1723,12 +1723,12 @@ Received outputs:
|
||||
) from err
|
||||
|
||||
if block.stateful:
|
||||
if not utils.is_update(predictions[i]):
|
||||
if not utils.is_prop_update(predictions[i]):
|
||||
state[block._id] = predictions[i]
|
||||
output.append(None)
|
||||
else:
|
||||
prediction_value = predictions[i]
|
||||
if utils.is_update(
|
||||
if utils.is_prop_update(
|
||||
prediction_value
|
||||
): # if update is passed directly (deprecated), remove Nones
|
||||
prediction_value = utils.delete_none(
|
||||
@ -1738,7 +1738,7 @@ Received outputs:
|
||||
if isinstance(prediction_value, Block):
|
||||
prediction_value = prediction_value.constructor_args.copy()
|
||||
prediction_value["__type__"] = "update"
|
||||
if utils.is_update(prediction_value):
|
||||
if utils.is_prop_update(prediction_value):
|
||||
kwargs = state[block._id].constructor_args.copy()
|
||||
kwargs.update(prediction_value)
|
||||
kwargs.pop("value", None)
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Any, Literal
|
||||
|
||||
from gradio_client.documentation import document
|
||||
@ -17,7 +18,8 @@ from gradio.events import Events
|
||||
@document()
|
||||
class Dataset(Component):
|
||||
"""
|
||||
Creates a gallery or table to display data samples. This component is designed for internal use to display examples.
|
||||
Creates a gallery or table to display data samples. This component is primarily designed for internal use to display examples.
|
||||
However, it can also be used directly to display a dataset and let users select examples.
|
||||
"""
|
||||
|
||||
EVENTS = [Events.click, Events.select]
|
||||
@ -26,7 +28,7 @@ class Dataset(Component):
|
||||
self,
|
||||
*,
|
||||
label: str | None = None,
|
||||
components: list[Component] | list[str],
|
||||
components: list[Component] | list[str] | None = None,
|
||||
component_props: list[dict[str, Any]] | None = None,
|
||||
samples: list[list[Any]] | None = None,
|
||||
headers: list[str] | None = None,
|
||||
@ -70,7 +72,7 @@ class Dataset(Component):
|
||||
self.container = container
|
||||
self.scale = scale
|
||||
self.min_width = min_width
|
||||
self._components = [get_component_instance(c) for c in components]
|
||||
self._components = [get_component_instance(c) for c in components or []]
|
||||
if component_props is None:
|
||||
self.component_props = [
|
||||
component.recover_kwargs(
|
||||
@ -131,29 +133,39 @@ class Dataset(Component):
|
||||
|
||||
return config
|
||||
|
||||
def preprocess(self, payload: int) -> int | list | None:
|
||||
def preprocess(self, payload: int | None) -> int | list | None:
|
||||
"""
|
||||
Parameters:
|
||||
payload: the index of the selected example in the dataset
|
||||
Returns:
|
||||
Passes the selected sample either as a `list` of data corresponding to each input component (if `type` is "value") or as an `int` index (if `type` is "index")
|
||||
"""
|
||||
if payload is None:
|
||||
return None
|
||||
if self.type == "index":
|
||||
return payload
|
||||
elif self.type == "values":
|
||||
return self.samples[payload]
|
||||
|
||||
def postprocess(self, samples: list[list]) -> dict:
|
||||
def postprocess(self, sample: int | list | None) -> int | None:
|
||||
"""
|
||||
Parameters:
|
||||
samples: Expects a `list[list]` corresponding to the dataset data, can be used to update the dataset.
|
||||
sample: Expects an `int` index or `list` of sample data. Returns the index of the sample in the dataset or `None` if the sample is not found.
|
||||
Returns:
|
||||
Returns the updated dataset data as a `dict` with the key "samples".
|
||||
Returns the index of the sample in the dataset.
|
||||
"""
|
||||
return {
|
||||
"samples": samples,
|
||||
"__type__": "update",
|
||||
}
|
||||
if sample is None or isinstance(sample, int):
|
||||
return sample
|
||||
if isinstance(sample, list):
|
||||
try:
|
||||
index = self.samples.index(sample)
|
||||
except ValueError:
|
||||
index = None
|
||||
warnings.warn(
|
||||
"The `Dataset` component does not support updating the dataset data by providing "
|
||||
"a set of list values. Instead, you should return a new Dataset(samples=...) object."
|
||||
)
|
||||
return index
|
||||
|
||||
def example_payload(self) -> Any:
|
||||
return 0
|
||||
|
@ -164,7 +164,7 @@ class CSVLogger(FlaggingCallback):
|
||||
) / client_utils.strip_invalid_filename_characters(
|
||||
getattr(component, "label", None) or f"component {idx}"
|
||||
)
|
||||
if utils.is_update(sample):
|
||||
if utils.is_prop_update(sample):
|
||||
csv_data.append(str(sample))
|
||||
else:
|
||||
data = (
|
||||
|
@ -544,7 +544,7 @@ class Examples:
|
||||
component, components.File
|
||||
):
|
||||
value_to_use = value_as_dict
|
||||
if not utils.is_update(value_as_dict):
|
||||
if not utils.is_prop_update(value_as_dict):
|
||||
raise TypeError("value wasn't an update") # caught below
|
||||
output.append(value_as_dict)
|
||||
except (ValueError, TypeError, SyntaxError):
|
||||
|
@ -737,7 +737,7 @@ def validate_url(possible_url: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def is_update(val):
|
||||
def is_prop_update(val):
|
||||
return isinstance(val, dict) and "update" in val.get("__type__", "")
|
||||
|
||||
|
||||
|
@ -86,6 +86,40 @@ def predict(···) -> list[list]
|
||||
<DemosSection demos={obj.demos} />
|
||||
{/if}
|
||||
|
||||
### Examples
|
||||
|
||||
**Updating a Dataset**
|
||||
|
||||
In this example, we display a text dataset using `gr.Dataset` and then update it when the user clicks a button:
|
||||
|
||||
```py
|
||||
import gradio as gr
|
||||
|
||||
philosophy_quotes = [
|
||||
["I think therefore I am."],
|
||||
["The unexamined life is not worth living."]
|
||||
]
|
||||
|
||||
startup_quotes = [
|
||||
["Ideas are easy. Implementation is hard"],
|
||||
["Make mistakes faster."]
|
||||
]
|
||||
|
||||
def show_startup_quotes():
|
||||
return gr.Dataset(samples=startup_quotes)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
textbox = gr.Textbox()
|
||||
dataset = gr.Dataset(components=[textbox], samples=philosophy_quotes)
|
||||
button = gr.Button()
|
||||
|
||||
button.click(show_startup_quotes, None, dataset)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
|
||||
|
||||
{#if obj.fns && obj.fns.length > 0}
|
||||
<!--- Event Listeners -->
|
||||
### Event Listeners
|
||||
@ -97,3 +131,4 @@ def predict(···) -> list[list]
|
||||
### Guides
|
||||
<GuidesSection guides={obj.guides}/>
|
||||
{/if}
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
>;
|
||||
export let label = "Examples";
|
||||
export let headers: string[];
|
||||
export let samples: any[][];
|
||||
export let samples: any[][] | null = null;
|
||||
export let elem_id = "";
|
||||
export let elem_classes: string[] = [];
|
||||
export let visible = true;
|
||||
@ -34,7 +34,7 @@
|
||||
: `${root}/file=`;
|
||||
let page = 0;
|
||||
$: gallery = components.length < 2;
|
||||
let paginate = samples.length > samples_per_page;
|
||||
let paginate = samples ? samples.length > samples_per_page : false;
|
||||
|
||||
let selected_samples: any[][];
|
||||
let page_count: number;
|
||||
@ -51,6 +51,7 @@
|
||||
}
|
||||
|
||||
$: {
|
||||
samples = samples || [];
|
||||
paginate = samples.length > samples_per_page;
|
||||
if (paginate) {
|
||||
visible_pages = [];
|
||||
|
@ -43,27 +43,10 @@ class TestDataset:
|
||||
assert dataset.samples == [["value 1"], ["value 2"]]
|
||||
|
||||
def test_postprocessing(self):
|
||||
test_file_dir = Path(Path(__file__).parent, "test_files")
|
||||
bus = Path(test_file_dir, "bus.png")
|
||||
|
||||
dataset = gr.Dataset(
|
||||
components=["number", "textbox", "image", "html", "markdown"], type="index"
|
||||
)
|
||||
|
||||
output = dataset.postprocess(
|
||||
samples=[
|
||||
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
|
||||
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
|
||||
],
|
||||
)
|
||||
|
||||
assert output == {
|
||||
"samples": [
|
||||
[5, "hello", bus, "<b>Bold</b>", "**Bold**"],
|
||||
[15, "hi", bus, "<i>Italics</i>", "*Italics*"],
|
||||
],
|
||||
"__type__": "update",
|
||||
}
|
||||
assert dataset.postprocess(1) == 1
|
||||
|
||||
|
||||
@patch(
|
||||
|
@ -732,6 +732,33 @@ class TestBlocksPostprocessing:
|
||||
):
|
||||
await demo.postprocess_data(demo.fns[0], predictions=(1, 2), state=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dataset_is_updated(self):
|
||||
def update(value):
|
||||
return value, gr.Dataset(samples=[["New A"], ["New B"]])
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
with gr.Row():
|
||||
textbox = gr.Textbox()
|
||||
dataset = gr.Dataset(
|
||||
components=["text"], samples=[["Original"]], label="Saved Prompts"
|
||||
)
|
||||
dataset.click(update, inputs=[dataset], outputs=[textbox, dataset])
|
||||
app, _, _ = demo.launch(prevent_thread_lock=True)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
session_1 = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": [0], "session_hash": "1", "fn_index": 0},
|
||||
)
|
||||
assert "Original" in session_1.json()["data"][0]
|
||||
session_2 = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": [0], "session_hash": "1", "fn_index": 0},
|
||||
)
|
||||
assert "New" in session_2.json()["data"][0]
|
||||
|
||||
|
||||
class TestStateHolder:
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user