Better processing of example data prior to creating dataset component (#2128)

* WIP commit

* Add method to components

* Validate image filepath

* Remove unused imports

* Fix validation

* Only display name for model3D

* Remove image validation

* Don't use abstractmethod - add tests

* Remove unused import

* Remove breakpoint
This commit is contained in:
Freddy Boulton 2022-08-31 15:46:43 -04:00 committed by GitHub
parent f2ab162b5d
commit 880c63e200
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 4 deletions

View File

@ -211,6 +211,10 @@ class IOComponent(Component, Serializable):
load_fn = None
return load_fn, initial_value
def as_example(self, input_data):
"""Return the input data in a way that can be displayed by the examples dataset component in the front-end."""
return input_data
class FormComponent:
expected_parent = Form
@ -2162,6 +2166,9 @@ class File(Changeable, Clearable, IOComponent, FileSerializable):
rounded=rounded,
)
def as_example(self, input_data):
return Path(input_data).name
@document("change", "style")
class Dataframe(Changeable, IOComponent, JSONSerializable):
@ -2169,7 +2176,7 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
Accepts or displays 2D input through a spreadsheet-like component for dataframes.
Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
Examples-format: a {str} filepath to a csv with data.
Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data.
Demos: filter_records, matrix_transpose, tax_calculator
"""
@ -2416,6 +2423,13 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
rounded=rounded,
)
def as_example(self, input_data):
if isinstance(input_data, pd.DataFrame):
return input_data.head(n=5).to_dict(orient="split")["data"]
elif isinstance(input_data, np.ndarray):
return input_data.tolist()
return input_data
@document("change", "style")
class Timeseries(Changeable, IOComponent, JSONSerializable):
@ -3608,6 +3622,9 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
rounded=rounded,
)
def as_example(self, input_data):
return Path(input_data).name
@document("change", "clear")
class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
@ -3772,7 +3789,7 @@ class Dataset(Clickable, Component):
self,
*,
label: Optional[str] = None,
components: List[Component] | List[str],
components: List[IOComponent] | List[str],
samples: List[List[Any]],
headers: Optional[List[str]] = None,
type: str = "values",
@ -3791,6 +3808,9 @@ class Dataset(Clickable, Component):
"""
Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
self.components = [get_component_instance(c, render=False) for c in components]
for example in samples:
for i, (component, ex) in enumerate(zip(self.components, example)):
example[i] = component.as_example(ex)
self.type = type
self.label = label
if headers is not None:

View File

@ -6,10 +6,9 @@ from __future__ import annotations
import csv
import inspect
import os
import shutil
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Optional
import anyio

View File

@ -1780,5 +1780,41 @@ class TestState:
assert result[0] == 2
def test_dataframe_as_example_converts_dataframes():
df_comp = gr.Dataframe()
assert df_comp.as_example(pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) == [
[1, 5],
[2, 6],
[3, 7],
[4, 8],
]
assert df_comp.as_example(np.array([[1, 2], [3, 4.0]])) == [[1.0, 2.0], [3.0, 4.0]]
@pytest.mark.parametrize("component", [gr.Model3D, gr.File])
def test_as_example_returns_file_basename(component):
component = component()
assert component.as_example("/home/freddy/sources/example.ext") == "example.ext"
@patch("gradio.components.IOComponent.as_example")
@patch("gradio.components.File.as_example")
@patch("gradio.components.Dataframe.as_example")
@patch("gradio.components.Model3D.as_example")
def test_dataset_calls_as_example(*mocks):
gr.Dataset(
components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D()],
samples=[
[
pd.DataFrame({"a": np.array([1, 2, 3])}),
"foo.png",
"bar.jpeg",
"duck.obj",
]
],
)
assert all([m.called for m in mocks])
if __name__ == "__main__":
unittest.main()