mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Better processing of example data prior to creating dataset component (#2128)
* WIP commit * Add method to components * Validate image filepath * Remove unused imports * Fix validation * Only display name for model3D * Remove image validation * Don't use abstractmethod - add tests * Remove unused import * Remove breakpoint
This commit is contained in:
parent
f2ab162b5d
commit
880c63e200
@ -211,6 +211,10 @@ class IOComponent(Component, Serializable):
|
||||
load_fn = None
|
||||
return load_fn, initial_value
|
||||
|
||||
def as_example(self, input_data):
    """
    Convert one example value into the form shown by the examples dataset component in the front-end.

    This default implementation applies no conversion; subclasses override it when the raw value needs to be adapted for display.

    Parameters:
        input_data: the raw example value supplied for this component.
    Returns:
        The same object that was passed in, unchanged.
    """
    return input_data
|
||||
|
||||
|
||||
class FormComponent:
|
||||
expected_parent = Form
|
||||
@ -2162,6 +2166,9 @@ class File(Changeable, Clearable, IOComponent, FileSerializable):
|
||||
rounded=rounded,
|
||||
)
|
||||
|
||||
def as_example(self, input_data):
    """
    Return the text shown for a file example in the examples dataset component.

    Only the final path component is kept, so a short file name is displayed instead of a full path.

    Parameters:
        input_data: path to the example file, a list/tuple of such paths (multi-file examples), or None for an empty example.
    Returns:
        The file name; the file names joined by ", " when several paths are given; "" when there is no file.
    """
    # Path(None) raises TypeError, so an empty example must be handled up front.
    if input_data is None:
        return ""
    # Path() cannot take a list either; show every file name on one line.
    if isinstance(input_data, (list, tuple)):
        return ", ".join(Path(file).name for file in input_data)
    return Path(input_data).name
|
||||
|
||||
|
||||
@document("change", "style")
|
||||
class Dataframe(Changeable, IOComponent, JSONSerializable):
|
||||
@ -2169,7 +2176,7 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
|
||||
Accepts or displays 2D input through a spreadsheet-like component for dataframes.
|
||||
Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
|
||||
Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
|
||||
Examples-format: a {str} filepath to a csv with data.
|
||||
Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data.
|
||||
Demos: filter_records, matrix_transpose, tax_calculator
|
||||
"""
|
||||
|
||||
@ -2416,6 +2423,13 @@ class Dataframe(Changeable, IOComponent, JSONSerializable):
|
||||
rounded=rounded,
|
||||
)
|
||||
|
||||
def as_example(self, input_data):
    """
    Convert a dataframe example into something the examples dataset component can display.

    Parameters:
        input_data: a pandas DataFrame, a numpy array, or any other value (e.g. a csv filepath or a list of lists).
    Returns:
        For a DataFrame, the first five rows as a list of row lists (headers excluded); for a numpy array, its nested-list form; any other value is returned unchanged.
    """
    if isinstance(input_data, np.ndarray):
        return input_data.tolist()
    if not isinstance(input_data, pd.DataFrame):
        # Filepaths, lists of lists, etc. are already displayable as-is.
        return input_data
    # Only a short preview is needed, so keep at most the first five rows.
    preview = input_data.head(n=5)
    return preview.to_dict(orient="split")["data"]
|
||||
|
||||
|
||||
@document("change", "style")
|
||||
class Timeseries(Changeable, IOComponent, JSONSerializable):
|
||||
@ -3608,6 +3622,9 @@ class Model3D(Changeable, Editable, Clearable, IOComponent, FileSerializable):
|
||||
rounded=rounded,
|
||||
)
|
||||
|
||||
def as_example(self, input_data):
    """
    Return the text shown for a 3D-model example in the examples dataset component.

    Only the file name is displayed, not the full path.

    Parameters:
        input_data: path to the example model file, or None / "" for an empty example.
    Returns:
        The final component of the path, or "" when no path is given.
    """
    # Path(None) raises TypeError, so an empty example is mapped to an empty label.
    if not input_data:
        return ""
    return Path(input_data).name
|
||||
|
||||
|
||||
@document("change", "clear")
|
||||
class Plot(Changeable, Clearable, IOComponent, JSONSerializable):
|
||||
@ -3772,7 +3789,7 @@ class Dataset(Clickable, Component):
|
||||
self,
|
||||
*,
|
||||
label: Optional[str] = None,
|
||||
components: List[Component] | List[str],
|
||||
components: List[IOComponent] | List[str],
|
||||
samples: List[List[Any]],
|
||||
headers: Optional[List[str]] = None,
|
||||
type: str = "values",
|
||||
@ -3791,6 +3808,9 @@ class Dataset(Clickable, Component):
|
||||
"""
|
||||
Component.__init__(self, visible=visible, elem_id=elem_id, **kwargs)
|
||||
self.components = [get_component_instance(c, render=False) for c in components]
|
||||
for example in samples:
|
||||
for i, (component, ex) in enumerate(zip(self.components, example)):
|
||||
example[i] = component.as_example(ex)
|
||||
self.type = type
|
||||
self.label = label
|
||||
if headers is not None:
|
||||
|
@ -6,10 +6,9 @@ from __future__ import annotations
|
||||
import csv
|
||||
import inspect
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional
|
||||
|
||||
import anyio
|
||||
|
||||
|
@ -1780,5 +1780,41 @@ class TestState:
|
||||
assert result[0] == 2
|
||||
|
||||
|
||||
def test_dataframe_as_example_converts_dataframes():
    """Dataframe.as_example turns pandas and numpy inputs into plain nested lists."""
    df_comp = gr.Dataframe()

    # A DataFrame becomes a list of rows, without headers.
    frame = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})
    expected_rows = [[1, 5], [2, 6], [3, 7], [4, 8]]
    assert df_comp.as_example(frame) == expected_rows

    # A numpy array becomes its nested-list equivalent.
    array = np.array([[1, 2], [3, 4.0]])
    assert df_comp.as_example(array) == [[1.0, 2.0], [3.0, 4.0]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("component", [gr.Model3D, gr.File])
def test_as_example_returns_file_basename(component):
    """File-like components display only the file name of an example path."""
    instance = component()
    displayed = instance.as_example("/home/freddy/sources/example.ext")
    assert displayed == "example.ext"
|
||||
|
||||
|
||||
@patch("gradio.components.IOComponent.as_example")
@patch("gradio.components.File.as_example")
@patch("gradio.components.Dataframe.as_example")
@patch("gradio.components.Model3D.as_example")
def test_dataset_calls_as_example(*mocks):
    """Building a Dataset calls as_example on each of its components."""
    # One example value per component, in the same order as `components` below.
    sample_row = [
        pd.DataFrame({"a": np.array([1, 2, 3])}),
        "foo.png",
        "bar.jpeg",
        "duck.obj",
    ]
    gr.Dataset(
        components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D()],
        samples=[sample_row],
    )
    # gr.Image has no patch of its own here; the IOComponent patch is presumably the one it triggers.
    assert all(m.called for m in mocks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user