Always return headers from postprocess (#1893)

* always return both data and headers from Dataframe postprocess

* add demo back

* handle the case where no new headers are supplied but the data has a different number of columns than the existing headers

* fix tests

* change

* change

* added unit tests

* formatting

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
pngwn 2022-08-01 15:57:31 -07:00 committed by GitHub
parent bdced314bc
commit 5d03174e44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 32 deletions

View File

@ -2462,7 +2462,9 @@ class Dataframe(Changeable, IOComponent):
self.__validate_headers(headers, self.col_count[0])
self.headers = headers
self.headers = (
headers if headers is not None else list(range(1, self.col_count[0] + 1))
)
self.datatype = (
datatype if isinstance(datatype, list) else [datatype] * self.col_count[0]
)
@ -2482,8 +2484,11 @@ class Dataframe(Changeable, IOComponent):
[values[c] for c in column_dtypes] for _ in range(self.row_count[0])
]
self.value = value if value is not None else self.test_input
self.value = self.__process_markdown(self.value, datatype)
self.value = (
self.postprocess(value)
if value is not None
else self.postprocess(self.test_input)
)
self.max_rows = max_rows
self.max_cols = max_cols
@ -2596,7 +2601,19 @@ class Dataframe(Changeable, IOComponent):
if isinstance(y, (np.ndarray, list)):
if isinstance(y, np.ndarray):
y = y.tolist()
_headers = self.headers
if len(self.headers) < len(y[0]):
_headers = [
*self.headers,
*list(range(len(self.headers) + 1, len(y[0]) + 1)),
]
elif len(self.headers) > len(y[0]):
_headers = self.headers[0 : len(y[0])]
return {
"headers": _headers,
"data": Dataframe.__process_markdown(y, self.datatype),
}
raise ValueError("Cannot process value as a Dataframe")

View File

@ -10,6 +10,7 @@ import numpy as np
import pandas as pd
import PIL
import pytest
from requests import head
import gradio as gr
from gradio import media_data
@ -1038,11 +1039,14 @@ class TestDataframe(unittest.TestCase):
"datatype": ["str", "str", "str"],
"row_count": (3, "dynamic"),
"col_count": (3, "dynamic"),
"value": [
["", "", ""],
["", "", ""],
["", "", ""],
],
"value": {
"data": [
["", "", ""],
["", "", ""],
["", "", ""],
],
"headers": ["Name", "Age", "Member"],
},
"name": "dataframe",
"show_label": True,
"label": "Dataframe Input",
@ -1063,26 +1067,11 @@ class TestDataframe(unittest.TestCase):
wrong_type = gr.Dataframe(type="unknown")
wrong_type.preprocess(x_data)
# Output functionalities
dataframe_output = gr.Dataframe()
output = dataframe_output.postprocess(np.zeros((2, 2)))
self.assertDictEqual(output, {"data": [[0, 0], [0, 0]]})
output = dataframe_output.postprocess([[1, 3, 5]])
self.assertDictEqual(output, {"data": [[1, 3, 5]]})
output = dataframe_output.postprocess(
pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
)
self.assertDictEqual(
output,
{
"headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]],
},
)
self.assertEqual(
dataframe_output.get_config(),
{
"headers": None,
"headers": [1, 2, 3],
"max_rows": 20,
"max_cols": None,
"overflow_row_behaviour": "paginate",
@ -1095,15 +1084,38 @@ class TestDataframe(unittest.TestCase):
"datatype": ["str", "str", "str"],
"row_count": (3, "dynamic"),
"col_count": (3, "dynamic"),
"value": [
["", "", ""],
["", "", ""],
["", "", ""],
],
"value": {
"data": [
["", "", ""],
["", "", ""],
["", "", ""],
],
"headers": [1, 2, 3],
},
"interactive": None,
"wrap": False,
},
)
def test_postprocess(self):
"""
postprocess
"""
dataframe_output = gr.Dataframe()
output = dataframe_output.postprocess(np.zeros((2, 2)))
self.assertDictEqual(output, {"data": [[0, 0], [0, 0]], "headers": [1, 2]})
output = dataframe_output.postprocess([[1, 3, 5]])
self.assertDictEqual(output, {"data": [[1, 3, 5]], "headers": [1, 2, 3]})
output = dataframe_output.postprocess(
pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
)
self.assertDictEqual(
output,
{
"headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]],
},
)
with self.assertRaises(ValueError):
wrong_type = gr.Dataframe(type="unknown")
wrong_type.postprocess(0)
@ -1128,6 +1140,26 @@ class TestDataframe(unittest.TestCase):
},
)
# When the headers don't match the data
dataframe_output = gr.Dataframe(headers=["one", "two", "three"])
output = dataframe_output.postprocess([[2, True], [3, True]])
self.assertDictEqual(
output,
{
"headers": ["one", "two"],
"data": [[2, True], [3, True]],
},
)
dataframe_output = gr.Dataframe(headers=["one", "two", "three"])
output = dataframe_output.postprocess([[2, True, "ab", 4], [3, True, "cd", 5]])
self.assertDictEqual(
output,
{
"headers": ["one", "two", "three", 4],
"data": [[2, True, "ab", 4], [3, True, "cd", 5]],
},
)
def test_in_interface_as_input(self):
"""
Interface, process,
@ -1135,7 +1167,7 @@ class TestDataframe(unittest.TestCase):
x_data = {"data": [[1, 2, 3], [4, 5, 6]]}
iface = gr.Interface(np.max, "numpy", "number")
self.assertEqual(iface.process([x_data]), [6])
x_data = {"data": [["Tim"], ["Jon"], ["Sal"]]}
x_data = {"data": [["Tim"], ["Jon"], ["Sal"]], "headers": [1, 2, 3]}
def get_last(my_list):
return my_list[-1][-1]
@ -1153,7 +1185,8 @@ class TestDataframe(unittest.TestCase):
iface = gr.Interface(check_odd, "numpy", "numpy")
self.assertEqual(
iface.process([{"data": [[2, 3, 4]]}])[0], {"data": [[True, False, True]]}
iface.process([{"data": [[2, 3, 4]]}])[0],
{"data": [[True, False, True]], "headers": [1, 2, 3]},
)

View File

@ -12,7 +12,10 @@
export let headers: Headers = [];
export let elem_id: string = "";
export let visible: boolean = true;
export let value: Data | { data: Data; headers: Headers } = [["", "", ""]];
export let value: { data: Data; headers: Headers } = {
data: [["", "", ""]],
headers: ["1", "2", "3"]
};
export let mode: "static" | "dynamic";
export let col_count: [number, "fixed" | "dynamic"];
export let row_count: [number, "fixed" | "dynamic"];