mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Always return headers from postprocess (#1893)
* always return data/headers from dataframe postprocess * add demo back * handle case of no new headers but different length list * fix tests * change * change * added unit tests * formatting Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
bdced314bc
commit
5d03174e44
@ -2462,7 +2462,9 @@ class Dataframe(Changeable, IOComponent):
|
||||
|
||||
self.__validate_headers(headers, self.col_count[0])
|
||||
|
||||
self.headers = headers
|
||||
self.headers = (
|
||||
headers if headers is not None else list(range(1, self.col_count[0] + 1))
|
||||
)
|
||||
self.datatype = (
|
||||
datatype if isinstance(datatype, list) else [datatype] * self.col_count[0]
|
||||
)
|
||||
@ -2482,8 +2484,11 @@ class Dataframe(Changeable, IOComponent):
|
||||
[values[c] for c in column_dtypes] for _ in range(self.row_count[0])
|
||||
]
|
||||
|
||||
self.value = value if value is not None else self.test_input
|
||||
self.value = self.__process_markdown(self.value, datatype)
|
||||
self.value = (
|
||||
self.postprocess(value)
|
||||
if value is not None
|
||||
else self.postprocess(self.test_input)
|
||||
)
|
||||
|
||||
self.max_rows = max_rows
|
||||
self.max_cols = max_cols
|
||||
@ -2596,7 +2601,19 @@ class Dataframe(Changeable, IOComponent):
|
||||
if isinstance(y, (np.ndarray, list)):
|
||||
if isinstance(y, np.ndarray):
|
||||
y = y.tolist()
|
||||
|
||||
_headers = self.headers
|
||||
|
||||
if len(self.headers) < len(y[0]):
|
||||
_headers = [
|
||||
*self.headers,
|
||||
*list(range(len(self.headers) + 1, len(y[0]) + 1)),
|
||||
]
|
||||
elif len(self.headers) > len(y[0]):
|
||||
_headers = self.headers[0 : len(y[0])]
|
||||
|
||||
return {
|
||||
"headers": _headers,
|
||||
"data": Dataframe.__process_markdown(y, self.datatype),
|
||||
}
|
||||
raise ValueError("Cannot process value as a Dataframe")
|
||||
|
@ -10,6 +10,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import pytest
|
||||
from requests import head
|
||||
|
||||
import gradio as gr
|
||||
from gradio import media_data
|
||||
@ -1038,11 +1039,14 @@ class TestDataframe(unittest.TestCase):
|
||||
"datatype": ["str", "str", "str"],
|
||||
"row_count": (3, "dynamic"),
|
||||
"col_count": (3, "dynamic"),
|
||||
"value": [
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
],
|
||||
"value": {
|
||||
"data": [
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
],
|
||||
"headers": ["Name", "Age", "Member"],
|
||||
},
|
||||
"name": "dataframe",
|
||||
"show_label": True,
|
||||
"label": "Dataframe Input",
|
||||
@ -1063,26 +1067,11 @@ class TestDataframe(unittest.TestCase):
|
||||
wrong_type = gr.Dataframe(type="unknown")
|
||||
wrong_type.preprocess(x_data)
|
||||
|
||||
# Output functionalities
|
||||
dataframe_output = gr.Dataframe()
|
||||
output = dataframe_output.postprocess(np.zeros((2, 2)))
|
||||
self.assertDictEqual(output, {"data": [[0, 0], [0, 0]]})
|
||||
output = dataframe_output.postprocess([[1, 3, 5]])
|
||||
self.assertDictEqual(output, {"data": [[1, 3, 5]]})
|
||||
output = dataframe_output.postprocess(
|
||||
pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
|
||||
)
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{
|
||||
"headers": ["num", "prime"],
|
||||
"data": [[2, True], [3, True], [4, False]],
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
dataframe_output.get_config(),
|
||||
{
|
||||
"headers": None,
|
||||
"headers": [1, 2, 3],
|
||||
"max_rows": 20,
|
||||
"max_cols": None,
|
||||
"overflow_row_behaviour": "paginate",
|
||||
@ -1095,15 +1084,38 @@ class TestDataframe(unittest.TestCase):
|
||||
"datatype": ["str", "str", "str"],
|
||||
"row_count": (3, "dynamic"),
|
||||
"col_count": (3, "dynamic"),
|
||||
"value": [
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
],
|
||||
"value": {
|
||||
"data": [
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
["", "", ""],
|
||||
],
|
||||
"headers": [1, 2, 3],
|
||||
},
|
||||
"interactive": None,
|
||||
"wrap": False,
|
||||
},
|
||||
)
|
||||
|
||||
def test_postprocess(self):
|
||||
"""
|
||||
postprocess
|
||||
"""
|
||||
dataframe_output = gr.Dataframe()
|
||||
output = dataframe_output.postprocess(np.zeros((2, 2)))
|
||||
self.assertDictEqual(output, {"data": [[0, 0], [0, 0]], "headers": [1, 2]})
|
||||
output = dataframe_output.postprocess([[1, 3, 5]])
|
||||
self.assertDictEqual(output, {"data": [[1, 3, 5]], "headers": [1, 2, 3]})
|
||||
output = dataframe_output.postprocess(
|
||||
pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
|
||||
)
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{
|
||||
"headers": ["num", "prime"],
|
||||
"data": [[2, True], [3, True], [4, False]],
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.Dataframe(type="unknown")
|
||||
wrong_type.postprocess(0)
|
||||
@ -1128,6 +1140,26 @@ class TestDataframe(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
# When the headers don't match the data
|
||||
dataframe_output = gr.Dataframe(headers=["one", "two", "three"])
|
||||
output = dataframe_output.postprocess([[2, True], [3, True]])
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{
|
||||
"headers": ["one", "two"],
|
||||
"data": [[2, True], [3, True]],
|
||||
},
|
||||
)
|
||||
dataframe_output = gr.Dataframe(headers=["one", "two", "three"])
|
||||
output = dataframe_output.postprocess([[2, True, "ab", 4], [3, True, "cd", 5]])
|
||||
self.assertDictEqual(
|
||||
output,
|
||||
{
|
||||
"headers": ["one", "two", "three", 4],
|
||||
"data": [[2, True, "ab", 4], [3, True, "cd", 5]],
|
||||
},
|
||||
)
|
||||
|
||||
def test_in_interface_as_input(self):
|
||||
"""
|
||||
Interface, process,
|
||||
@ -1135,7 +1167,7 @@ class TestDataframe(unittest.TestCase):
|
||||
x_data = {"data": [[1, 2, 3], [4, 5, 6]]}
|
||||
iface = gr.Interface(np.max, "numpy", "number")
|
||||
self.assertEqual(iface.process([x_data]), [6])
|
||||
x_data = {"data": [["Tim"], ["Jon"], ["Sal"]]}
|
||||
x_data = {"data": [["Tim"], ["Jon"], ["Sal"]], "headers": [1, 2, 3]}
|
||||
|
||||
def get_last(my_list):
|
||||
return my_list[-1][-1]
|
||||
@ -1153,7 +1185,8 @@ class TestDataframe(unittest.TestCase):
|
||||
|
||||
iface = gr.Interface(check_odd, "numpy", "numpy")
|
||||
self.assertEqual(
|
||||
iface.process([{"data": [[2, 3, 4]]}])[0], {"data": [[True, False, True]]}
|
||||
iface.process([{"data": [[2, 3, 4]]}])[0],
|
||||
{"data": [[True, False, True]], "headers": [1, 2, 3]},
|
||||
)
|
||||
|
||||
|
||||
|
@ -12,7 +12,10 @@
|
||||
export let headers: Headers = [];
|
||||
export let elem_id: string = "";
|
||||
export let visible: boolean = true;
|
||||
export let value: Data | { data: Data; headers: Headers } = [["", "", ""]];
|
||||
export let value: { data: Data; headers: Headers } = {
|
||||
data: [["", "", ""]],
|
||||
headers: ["1", "2", "3"]
|
||||
};
|
||||
export let mode: "static" | "dynamic";
|
||||
export let col_count: [number, "fixed" | "dynamic"];
|
||||
export let row_count: [number, "fixed" | "dynamic"];
|
||||
|
Loading…
x
Reference in New Issue
Block a user