fixed outputs tests

This commit is contained in:
Abubakar Abid 2022-03-23 15:49:37 -07:00
parent fea2c32716
commit fa27aa5876
3 changed files with 33 additions and 25 deletions

View File

@ -2097,6 +2097,8 @@ class Dataframe(Component):
dtype = "numpy"
elif isinstance(y, list):
dtype = "array"
else:
raise ValueError("Cannot determine the type of DataFrame output.")
else:
dtype = self.output_type
if dtype == "pandas":

View File

@ -18,8 +18,8 @@ from markdown_it import MarkdownIt
from mdit_py_plugins.footnote import footnote_plugin
from gradio import interpretation, utils
from gradio.blocks import BlockContext, Column, Row
from gradio.components import Button, Component, Markdown, get_component_instance
from gradio.components import Component, get_component_instance, Markdown, Button
from gradio.blocks import BlockContext, Row, Column
from gradio.external import load_from_pipeline, load_interface # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.inputs import State as i_State # type: ignore

View File

@ -2,6 +2,7 @@ import os
import tempfile
import unittest
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
@ -12,18 +13,13 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestTextbox(unittest.TestCase):
def test_as_component(self):
with self.assertRaises(ValueError):
wrong_type = gr.outputs.Textbox(type="unknown")
wrong_type.postprocess(0)
def test_in_interface(self):
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
self.assertEqual(iface.process(["Hello"])[0], ["o"])
iface = gr.Interface(
lambda x: x / 2, "number", gr.outputs.Textbox(type="number")
lambda x: x / 2, "number", gr.outputs.Textbox()
)
self.assertEqual(iface.process([10])[0], [5])
self.assertEqual(iface.process([10])[0], ['5.0'])
class TestLabel(unittest.TestCase):
@ -79,9 +75,6 @@ class TestLabel(unittest.TestCase):
],
},
)
with self.assertRaises(ValueError):
label_output = gr.outputs.Label(type="unknown")
label_output.deserialize([1, 2, 3])
def test_in_interface(self):
x_img = gr.test_data.BASE64_IMAGE
@ -239,21 +232,21 @@ class TestAudio(unittest.TestCase):
)
)
self.assertEqual(
audio_output.get_template_context(), {"name": "audio", "label": None}
audio_output.get_template_context(), {"name": "audio",
"label": None,
"source": "upload",
"css": {}}
)
with self.assertRaises(ValueError):
wrong_type = gr.outputs.Audio(type="unknown")
wrong_type.postprocess(y_audio.name)
self.assertTrue(
audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav")
)
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = audio_output.save_flagged(
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO, None
)
self.assertEqual("audio_output/0.wav", to_save)
to_save = audio_output.save_flagged(
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO, None
)
self.assertEqual("audio_output/1.wav", to_save)
@ -331,11 +324,11 @@ class TestFile(unittest.TestCase):
file_output = gr.outputs.File()
with tempfile.TemporaryDirectory() as tmpdirname:
to_save = file_output.save_flagged(
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
tmpdirname, "file_output", [gr.test_data.BASE64_FILE], None
)
self.assertEqual("file_output/0", to_save)
to_save = file_output.save_flagged(
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
tmpdirname, "file_output", [gr.test_data.BASE64_FILE], None
)
self.assertEqual("file_output/1", to_save)
@ -363,7 +356,14 @@ class TestDataframe(unittest.TestCase):
"overflow_row_behaviour": "paginate",
"name": "dataframe",
"label": None,
"css": {}
"css": {},
"datatype": "str",
"row_count": 3,
"col_count": 3,
"col_width": None,
"default": [[None, None, None], [None, None, None], [None, None, None]],
"name": "dataframe",
},
)
with self.assertRaises(ValueError):
@ -373,11 +373,14 @@ class TestDataframe(unittest.TestCase):
to_save = dataframe_output.save_flagged(
tmpdirname, "dataframe_output", output, None
)
self.assertEqual(to_save, "[[2, true], [3, true], [4, false]]")
self.assertEqual(to_save, json.dumps({
"headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]]
}))
self.assertEqual(
dataframe_output.restore_flagged(tmpdirname, to_save, None),
{"data": [[2, True], [3, True], [4, False]]},
)
{"headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]]})
def test_in_interface(self):
def check_odd(array):
@ -413,7 +416,10 @@ class TestCarousel(unittest.TestCase):
self.assertEqual(
carousel_output.get_template_context(),
{
"components": [{"name": "textbox", "label": None}],
"components": [
{"name": "textbox", "label": None, "default": "", "lines": 1,
"css": {}, 'placeholder': None}
],
"name": "carousel",
"label": "Disease",
"css": {}