mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Better processing of example data prior to creating dataset component (#2128)
* WIP commit * Add method to components * Validate image filepath * Remove unused imports * Fix validation * Only display name for model3D * Remove image validation * Don't use abstractmethod - add tests * Remove unused import * Remove breakpoint
This commit is contained in:
parent
f2ab162b5d
commit
880c63e200
@ -211,6 +211,10 @@ class IOComponent(Component, Serializable):
|
|||||||
load_fn = None
|
load_fn = None
|
||||||
return load_fn, initial_value
|
return load_fn, initial_value
|
||||||
|
|
||||||
|
def as_example(self, input_data):
|
||||||
|
"""Return the input data in a way that can be displayed by the examples dataset component in the front-end."""
|
||||||
|
return input_data
|
||||||
|
|
||||||
|
|
||||||
class FormComponent:
|
class FormComponent:
|
||||||
expected_parent = Form
|
expected_parent = Form
|
||||||
@ -2162,6 +2166,9 @@ class File(Changeable, Clearable, IOComponent, FileSerializable):
|
|||||||
rounded=rounded,
|
rounded=rounded,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def as_example(self, input_data):
|
||||||
|
return Path(input_data).name
|
||||||
|
|
||||||
|
|
||||||
@document("change", "style")
|
@document("change", "style")
|
||||||
class Dataframe(Changeable, IOComponent, JSONSerializable):
|
class Dataframe(Changeable, IOComponent, JSONSerializable):
|
||||||
@ -2169,7 +2176,7 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
|
|||||||
Accepts or displays 2D input through a spreadsheet-like component for dataframes.
|
Accepts or displays 2D input through a spreadsheet-like component for dataframes.
|
||||||
Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
|
Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
|
||||||
Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
|
Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
|
||||||
Examples-format: a {str} filepath to a csv with data.
|
Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data.
|
||||||
Demos: filter_records, matrix_transpose, tax_calculator
|
Demos: filter_records, matrix_transpose, tax_calculator
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -2416,6 +2423,13 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
|
|||||||
rounded=rounded,
|
rounded=rounded,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def as_example(self, input_data):
|
||||||
|
if isinstance(input_data, pd.DataFrame):
|
||||||
|
return input_data.head(n=5).to_dict(orient="split")["data"]
|
||||||
|
elif isinstance(input_data, np.ndarray):
|
||||||
|
return input_data.tolist()
|
||||||
|
return input_data
|
||||||
|
|
||||||
|
|
||||||
@document("change", "style")
|
@document("change", "style")
|
||||||
class Timeseries(Changeable, IOComponent, JSONSerializable):
|
class Timeseries(Changeable, IOComponent, JSONSerializable):
|
||||||
@ -3608,6 +3622,9 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
|||||||
rounded=rounded,
|
rounded=rounded,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def as_example(self, input_data):
|
||||||
|
return Path(input_data).name
|
||||||
|
|
||||||
|
|
||||||
@document("change", "clear")
|
@document("change", "clear")
|
||||||
class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
|
class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
|
||||||
@ -3772,7 +3789,7 @@ class Dataset(Clickable, Component):
|
|||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
label: Optional[str] = None,
|
label: Optional[str] = None,
|
||||||
components: List[Component] | List[str],
|
components: List[IOComponent] | List[str],
|
||||||
samples: List[List[Any]],
|
samples: List[List[Any]],
|
||||||
headers: Optional[List[str]] = None,
|
headers: Optional[List[str]] = None,
|
||||||
type: str = "values",
|
type: str = "values",
|
||||||
@ -3791,6 +3808,9 @@ class Dataset(Clickable, Component):
|
|||||||
"""
|
"""
|
||||||
Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
|
Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
|
||||||
self.components = [get_component_instance(c, render=False) for c in components]
|
self.components = [get_component_instance(c, render=False) for c in components]
|
||||||
|
for example in samples:
|
||||||
|
for i, (component, ex) in enumerate(zip(self.components, example)):
|
||||||
|
example[i] = component.as_example(ex)
|
||||||
self.type = type
|
self.type = type
|
||||||
self.label = label
|
self.label = label
|
||||||
if headers is not None:
|
if headers is not None:
|
||||||
|
@ -6,10 +6,9 @@ from __future__ import annotations
|
|||||||
import csv
|
import csv
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Callable, List, Optional
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
|
@ -1780,5 +1780,41 @@ class TestState:
|
|||||||
assert result[0] == 2
|
assert result[0] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataframe_as_example_converts_dataframes():
|
||||||
|
df_comp = gr.Dataframe()
|
||||||
|
assert df_comp.as_example(pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) == [
|
||||||
|
[1, 5],
|
||||||
|
[2, 6],
|
||||||
|
[3, 7],
|
||||||
|
[4, 8],
|
||||||
|
]
|
||||||
|
assert df_comp.as_example(np.array([[1, 2], [3, 4.0]])) == [[1.0, 2.0], [3.0, 4.0]]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("component", [gr.Model3D, gr.File])
|
||||||
|
def test_as_example_returns_file_basename(component):
|
||||||
|
component = component()
|
||||||
|
assert component.as_example("/home/freddy/sources/example.ext") == "example.ext"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("gradio.components.IOComponent.as_example")
|
||||||
|
@patch("gradio.components.File.as_example")
|
||||||
|
@patch("gradio.components.Dataframe.as_example")
|
||||||
|
@patch("gradio.components.Model3D.as_example")
|
||||||
|
def test_dataset_calls_as_example(*mocks):
|
||||||
|
gr.Dataset(
|
||||||
|
components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D()],
|
||||||
|
samples=[
|
||||||
|
[
|
||||||
|
pd.DataFrame({"a": np.array([1, 2, 3])}),
|
||||||
|
"foo.png",
|
||||||
|
"bar.jpeg",
|
||||||
|
"duck.obj",
|
||||||
|
]
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert all([m.called for m in mocks])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user