Better processing of example data prior to creating dataset component (#2128)

* WIP commit

* Add method to components

* Validate image filepath

* Remove unused imports

* Fix validation

* Only display name for model3D

* Remove image validation

* Don't use abstractmethod - add tests

* Remove unused import

* Remove breakpoint
This commit is contained in:
Freddy Boulton 2022-08-31 15:46:43 -04:00 committed by GitHub
parent f2ab162b5d
commit 880c63e200
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 4 deletions

View File

@ -211,6 +211,10 @@ class IOComponent(Component, Serializable):
load_fn = None load_fn = None
return load_fn, initial_value return load_fn, initial_value
def as_example(self, input_data):
"""Return the input data in a way that can be displayed by the examples dataset component in the front-end."""
return input_data
class FormComponent: class FormComponent:
expected_parent = Form expected_parent = Form
@ -2162,6 +2166,9 @@ class File(Changeable, Clearable, IOComponent, FileSerializable):
rounded=rounded, rounded=rounded,
) )
def as_example(self, input_data):
return Path(input_data).name
@document("change", "style") @document("change", "style")
class Dataframe(Changeable, IOComponent, JSONSerializable): class Dataframe(Changeable, IOComponent, JSONSerializable):
@ -2169,7 +2176,7 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
Accepts or displays 2D input through a spreadsheet-like component for dataframes. Accepts or displays 2D input through a spreadsheet-like component for dataframes.
Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type` Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet. Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
Examples-format: a {str} filepath to a csv with data. Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data.
Demos: filter_records, matrix_transpose, tax_calculator Demos: filter_records, matrix_transpose, tax_calculator
""" """
@ -2416,6 +2423,13 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
rounded=rounded, rounded=rounded,
) )
def as_example(self, input_data):
if isinstance(input_data, pd.DataFrame):
return input_data.head(n=5).to_dict(orient="split")["data"]
elif isinstance(input_data, np.ndarray):
return input_data.tolist()
return input_data
@document("change", "style") @document("change", "style")
class Timeseries(Changeable, IOComponent, JSONSerializable): class Timeseries(Changeable, IOComponent, JSONSerializable):
@ -3608,6 +3622,9 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
rounded=rounded, rounded=rounded,
) )
def as_example(self, input_data):
return Path(input_data).name
@document("change", "clear") @document("change", "clear")
class Plot(Changeable, Clearable, IOComponent, JSONSerializable): class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
@ -3772,7 +3789,7 @@ class Dataset(Clickable, Component):
self, self,
*, *,
label: Optional[str] = None, label: Optional[str] = None,
components: List[Component] | List[str], components: List[IOComponent] | List[str],
samples: List[List[Any]], samples: List[List[Any]],
headers: Optional[List[str]] = None, headers: Optional[List[str]] = None,
type: str = "values", type: str = "values",
@ -3791,6 +3808,9 @@ class Dataset(Clickable, Component):
""" """
Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs) Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
self.components = [get_component_instance(c, render=False) for c in components] self.components = [get_component_instance(c, render=False) for c in components]
for example in samples:
for i, (component, ex) in enumerate(zip(self.components, example)):
example[i] = component.as_example(ex)
self.type = type self.type = type
self.label = label self.label = label
if headers is not None: if headers is not None:

View File

@ -6,10 +6,9 @@ from __future__ import annotations
import csv import csv
import inspect import inspect
import os import os
import shutil
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Callable, List, Optional
import anyio import anyio

View File

@ -1780,5 +1780,41 @@ class TestState:
assert result[0] == 2 assert result[0] == 2
def test_dataframe_as_example_converts_dataframes():
df_comp = gr.Dataframe()
assert df_comp.as_example(pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) == [
[1, 5],
[2, 6],
[3, 7],
[4, 8],
]
assert df_comp.as_example(np.array([[1, 2], [3, 4.0]])) == [[1.0, 2.0], [3.0, 4.0]]
@pytest.mark.parametrize("component", [gr.Model3D, gr.File])
def test_as_example_returns_file_basename(component):
component = component()
assert component.as_example("/home/freddy/sources/example.ext") == "example.ext"
@patch("gradio.components.IOComponent.as_example")
@patch("gradio.components.File.as_example")
@patch("gradio.components.Dataframe.as_example")
@patch("gradio.components.Model3D.as_example")
def test_dataset_calls_as_example(*mocks):
gr.Dataset(
components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D()],
samples=[
[
pd.DataFrame({"a": np.array([1, 2, 3])}),
"foo.png",
"bar.jpeg",
"duck.obj",
]
],
)
assert all([m.called for m in mocks])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()