allow data frame to change number of columns (#1716)

* allow data frame to change number of columns

* allow data frame to change number of columns

* fix types

* fix tests

* handle case when headers are not present

* fix tests

* fix tests finally

* reinstate demo

* address review comments

* tweak test for fix
This commit is contained in:
pngwn 2022-07-12 11:35:20 +01:00 committed by GitHub
parent 4c26f412a8
commit a18c7ddf04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 47 deletions

View File

@ -16,7 +16,7 @@ import tempfile
import warnings
from copy import deepcopy
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict
import matplotlib.figure
import numpy as np
@ -2394,7 +2394,11 @@ class File(Changeable, Clearable, IOComponent):
)
@document()
class DataframeData(TypedDict):
headers: List[str]
data: List[List[str | int | bool]]
class Dataframe(Changeable, IOComponent):
"""
Accepts or displays 2D input through a spreadsheet-like component for dataframes.
@ -2526,26 +2530,22 @@ class Dataframe(Changeable, IOComponent):
}
return IOComponent.add_interactive_to_config(updated_config, interactive)
def preprocess(
self, x: List[List[str | Number | bool]]
) -> pd.DataFrame | np.ndarray | List[List[str | float | bool]]:
def preprocess(self, x: DataframeData):
"""
Parameters:
x: 2D array of str, numeric, or bool data
x (Dict[headers: List[str], data: List[List[str | int | bool]]]): 2D array of str, numeric, or bool data
Returns:
Dataframe in requested format
"""
if self.type == "pandas":
if self.headers:
return pd.DataFrame(x, columns=self.headers)
if x.get("headers") is not None:
return pd.DataFrame(x["data"], columns=x.get("headers"))
else:
return pd.DataFrame(x)
if self.col_count[0] == 1:
x = [row[0] for row in x]
return pd.DataFrame(x["data"])
if self.type == "numpy":
return np.array(x)
return np.array(x["data"])
elif self.type == "array":
return x
return x["data"]
else:
raise ValueError(
"Unknown type: "
@ -2592,8 +2592,6 @@ class Dataframe(Changeable, IOComponent):
if isinstance(y, (np.ndarray, list)):
if isinstance(y, np.ndarray):
y = y.tolist()
if len(y) == 0 or not isinstance(y[0], list):
y = [y]
return {
"data": Dataframe.__process_markdown(y, self.datatype),
}

View File

@ -999,7 +999,10 @@ class TestDataframe(unittest.TestCase):
"""
Preprocess, serialize, save_flagged, restore_flagged, generate_sample, get_config
"""
x_data = [["Tim", 12, False], ["Jan", 24, True]]
x_data = {
"data": [["Tim", 12, False], ["Jan", 24, True]],
"headers": ["Name", "Age", "Member"],
}
dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"])
output = dataframe_input.preprocess(x_data)
self.assertEqual(output["Age"][1], 24)
@ -1046,7 +1049,7 @@ class TestDataframe(unittest.TestCase):
)
dataframe_input = gr.Dataframe()
output = dataframe_input.preprocess(x_data)
self.assertEqual(output[1][1], 24)
self.assertEqual(output["Age"][1], 24)
with self.assertRaises(ValueError):
wrong_type = gr.Dataframe(type="unknown")
wrong_type.preprocess(x_data)
@ -1120,13 +1123,13 @@ class TestDataframe(unittest.TestCase):
"""
Interface, process,
"""
x_data = [[1, 2, 3], [4, 5, 6]]
x_data = {"data": [[1, 2, 3], [4, 5, 6]]}
iface = gr.Interface(np.max, "numpy", "number")
self.assertEqual(iface.process([x_data]), [6])
x_data = [["Tim"], ["Jon"], ["Sal"]]
x_data = {"data": [["Tim"], ["Jon"], ["Sal"]]}
def get_last(my_list):
return my_list[-1]
return my_list[-1][-1]
iface = gr.Interface(get_last, "list", "text")
self.assertEqual(iface.process([x_data]), ["Sal"])
@ -1140,7 +1143,9 @@ class TestDataframe(unittest.TestCase):
return array % 2 == 0
iface = gr.Interface(check_odd, "numpy", "numpy")
self.assertEqual(iface.process([[2, 3, 4]])[0], {"data": [[True, False, True]]})
self.assertEqual(
iface.process([{"data": [[2, 3, 4]]}])[0], {"data": [[True, False, True]]}
)
class TestVideo(unittest.TestCase):
@ -1356,14 +1361,16 @@ class TestTimeseries(unittest.TestCase):
"""
timeseries_output = gr.Timeseries(x="time", y=["retail", "food", "other"])
iface = gr.Interface(lambda x: x, "dataframe", timeseries_output)
df = pd.DataFrame(
{
"time": [1, 2, 3, 4],
"retail": [1, 2, 3, 2],
"food": [1, 2, 3, 2],
"other": [1, 2, 4, 2],
}
)
df = {
"data": pd.DataFrame(
{
"time": [1, 2, 3, 4],
"retail": [1, 2, 3, 2],
"food": [1, 2, 3, 2],
"other": [1, 2, 4, 2],
}
)
}
self.assertEqual(
iface.process([df]),
[
@ -1585,6 +1592,7 @@ class TestJSON(unittest.TestCase):
"""
def get_avg_age_per_gender(data):
print(data)
return {
"M": int(data[data["gender"] == "M"].mean()),
"F": int(data[data["gender"] == "F"].mean()),
@ -1603,7 +1611,10 @@ class TestJSON(unittest.TestCase):
["O", 20],
["F", 30],
]
self.assertDictEqual(iface.process([y_data])[0], {"M": 35, "F": 25, "O": 20})
self.assertDictEqual(
iface.process([{"data": y_data, "headers": ["gender", "age"]}])[0],
{"M": 35, "F": 25, "O": 20},
)
class TestHTML(unittest.TestCase):

View File

@ -22,18 +22,6 @@
export let wrap: boolean;
export let datatype: Datatype | Array<Datatype>;
$: {
if (value && !Array.isArray(value)) {
if (Array.isArray(value.headers)) headers = value.headers;
value =
value.data.length === 0 ? [Array(headers.length).fill("")] : value.data;
} else if (value === null) {
value = [Array(headers.length).fill("")];
} else {
value = value;
}
}
const dispatch = createEventDispatcher();
export let loading_status: LoadingStatus;

View File

@ -22,7 +22,23 @@
export let style: Styles = {};
export let wrap: boolean = false;
const dispatch = createEventDispatcher<{ change: typeof values }>();
$: {
if (values && !Array.isArray(values)) {
if (Array.isArray(values.headers)) headers = values.headers;
values =
values.data.length === 0
? [Array(headers.length).fill("")]
: values.data;
} else if (values === null) {
values = [Array(headers.length).fill("")];
} else {
values = values;
}
}
const dispatch = createEventDispatcher<{
change: { data: typeof values; headers: typeof headers };
}>();
let editing: boolean | string = false;
let selected: boolean | string = false;
@ -114,10 +130,10 @@
let old_val: undefined | Array<Array<string | number>> = undefined;
$: _headers &&
dispatch(
"change",
data.map((r) => r.map(({ value }) => value))
);
dispatch("change", {
data: data.map((r) => r.map(({ value }) => value)),
headers: _headers.map((h) => h.value)
});
function get_sort_status(
name: string,