From 880c63e200f523f8600fe1dd2eb74c4880bcc1fc Mon Sep 17 00:00:00 2001
From: Freddy Boulton
Date: Wed, 31 Aug 2022 15:46:43 -0400
Subject: [PATCH] Better processing of example data prior to creating dataset component (#2128)

* WIP commit
* Add method to components
* Validate image filepath
* Remove unused imports
* Fix validation
* Only display name for model3D
* Remove image validation
* Don't use abstractmethod - add tests
* Remove unused import
* Remove breakpoint
---
 gradio/components.py    | 24 ++++++++++++++++++++++--
 gradio/examples.py      |  3 +--
 test/test_components.py | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/gradio/components.py b/gradio/components.py
index 6c87cbb76f..e6de494522 100644
--- a/gradio/components.py
+++ b/gradio/components.py
@@ -211,6 +211,10 @@ class IOComponent(Component, Serializable):
             load_fn = None
         return load_fn, initial_value
 
+    def as_example(self, input_data):
+        """Return the input data in a way that can be displayed by the examples dataset component in the front-end."""
+        return input_data
+
 
 class FormComponent:
     expected_parent = Form
@@ -2162,6 +2166,9 @@ class File(Changeable, Clearable, IOComponent, FileSerializable):
             rounded=rounded,
         )
 
+    def as_example(self, input_data):
+        return Path(input_data).name
+
 
 @document("change", "style")
 class Dataframe(Changeable, IOComponent, JSONSerializable):
@@ -2169,7 +2176,7 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
     Accepts or displays 2D input through a spreadsheet-like component for dataframes.
     Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
     Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
-    Examples-format: a {str} filepath to a csv with data.
+    Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data.
     Demos: filter_records, matrix_transpose, tax_calculator
     """
 
@@ -2416,6 +2423,13 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
             rounded=rounded,
         )
 
+    def as_example(self, input_data):
+        if isinstance(input_data, pd.DataFrame):
+            return input_data.head(n=5).to_dict(orient="split")["data"]
+        elif isinstance(input_data, np.ndarray):
+            return input_data.tolist()
+        return input_data
+
 
 @document("change", "style")
 class Timeseries(Changeable, IOComponent, JSONSerializable):
@@ -3608,6 +3622,9 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
             rounded=rounded,
         )
 
+    def as_example(self, input_data):
+        return Path(input_data).name
+
 
 @document("change", "clear")
 class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
@@ -3772,7 +3789,7 @@ class Dataset(Clickable, Component):
         self,
         *,
         label: Optional[str] = None,
-        components: List[Component] | List[str],
+        components: List[IOComponent] | List[str],
         samples: List[List[Any]],
         headers: Optional[List[str]] = None,
         type: str = "values",
@@ -3791,6 +3808,9 @@ class Dataset(Clickable, Component):
         """
         Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
         self.components = [get_component_instance(c, render=False) for c in components]
+        for example in samples:
+            for i, (component, ex) in enumerate(zip(self.components, example)):
+                example[i] = component.as_example(ex)
         self.type = type
         self.label = label
         if headers is not None:
diff --git a/gradio/examples.py b/gradio/examples.py
index cf5f88d123..9e0bc6bd8b 100644
--- a/gradio/examples.py
+++ b/gradio/examples.py
@@ -6,10 +6,9 @@ from __future__ import annotations
 import csv
 import inspect
 import os
-import shutil
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, List, Optional
 
 import anyio
 
diff --git a/test/test_components.py b/test/test_components.py
index 9d015577cd..6f179b6e31 100644
--- a/test/test_components.py
+++ b/test/test_components.py
@@ -1780,5 +1780,41 @@ class TestState:
         assert result[0] == 2
 
 
+def test_dataframe_as_example_converts_dataframes():
+    df_comp = gr.Dataframe()
+    assert df_comp.as_example(pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) == [
+        [1, 5],
+        [2, 6],
+        [3, 7],
+        [4, 8],
+    ]
+    assert df_comp.as_example(np.array([[1, 2], [3, 4.0]])) == [[1.0, 2.0], [3.0, 4.0]]
+
+
+@pytest.mark.parametrize("component", [gr.Model3D, gr.File])
+def test_as_example_returns_file_basename(component):
+    component = component()
+    assert component.as_example("/home/freddy/sources/example.ext") == "example.ext"
+
+
+@patch("gradio.components.IOComponent.as_example")
+@patch("gradio.components.File.as_example")
+@patch("gradio.components.Dataframe.as_example")
+@patch("gradio.components.Model3D.as_example")
+def test_dataset_calls_as_example(*mocks):
+    gr.Dataset(
+        components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D()],
+        samples=[
+            [
+                pd.DataFrame({"a": np.array([1, 2, 3])}),
+                "foo.png",
+                "bar.jpeg",
+                "duck.obj",
+            ]
+        ],
+    )
+    assert all([m.called for m in mocks])
+
+
 if __name__ == "__main__":
     unittest.main()