mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
fixed outputs tests
This commit is contained in:
parent
fea2c32716
commit
fa27aa5876
@ -2097,6 +2097,8 @@ class Dataframe(Component):
|
||||
dtype = "numpy"
|
||||
elif isinstance(y, list):
|
||||
dtype = "array"
|
||||
else:
|
||||
raise ValueError("Cannot determine the type of DataFrame output.")
|
||||
else:
|
||||
dtype = self.output_type
|
||||
if dtype == "pandas":
|
||||
|
@ -18,8 +18,8 @@ from markdown_it import MarkdownIt
|
||||
from mdit_py_plugins.footnote import footnote_plugin
|
||||
|
||||
from gradio import interpretation, utils
|
||||
from gradio.blocks import BlockContext, Column, Row
|
||||
from gradio.components import Button, Component, Markdown, get_component_instance
|
||||
from gradio.components import Component, get_component_instance, Markdown, Button
|
||||
from gradio.blocks import BlockContext, Row, Column
|
||||
from gradio.external import load_from_pipeline, load_interface # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.inputs import State as i_State # type: ignore
|
||||
|
@ -2,6 +2,7 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -12,18 +13,13 @@ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_as_component(self):
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.outputs.Textbox(type="unknown")
|
||||
wrong_type.postprocess(0)
|
||||
|
||||
def test_in_interface(self):
|
||||
iface = gr.Interface(lambda x: x[-1], "textbox", gr.outputs.Textbox())
|
||||
self.assertEqual(iface.process(["Hello"])[0], ["o"])
|
||||
iface = gr.Interface(
|
||||
lambda x: x / 2, "number", gr.outputs.Textbox(type="number")
|
||||
lambda x: x / 2, "number", gr.outputs.Textbox()
|
||||
)
|
||||
self.assertEqual(iface.process([10])[0], [5])
|
||||
self.assertEqual(iface.process([10])[0], ['5.0'])
|
||||
|
||||
|
||||
class TestLabel(unittest.TestCase):
|
||||
@ -79,9 +75,6 @@ class TestLabel(unittest.TestCase):
|
||||
],
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
label_output = gr.outputs.Label(type="unknown")
|
||||
label_output.deserialize([1, 2, 3])
|
||||
|
||||
def test_in_interface(self):
|
||||
x_img = gr.test_data.BASE64_IMAGE
|
||||
@ -239,21 +232,21 @@ class TestAudio(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
audio_output.get_template_context(), {"name": "audio", "label": None}
|
||||
audio_output.get_template_context(), {"name": "audio",
|
||||
"label": None,
|
||||
"source": "upload",
|
||||
"css": {}}
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
wrong_type = gr.outputs.Audio(type="unknown")
|
||||
wrong_type.postprocess(y_audio.name)
|
||||
self.assertTrue(
|
||||
audio_output.deserialize(gr.test_data.BASE64_AUDIO["data"]).endswith(".wav")
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = audio_output.save_flagged(
|
||||
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
|
||||
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO, None
|
||||
)
|
||||
self.assertEqual("audio_output/0.wav", to_save)
|
||||
to_save = audio_output.save_flagged(
|
||||
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO["data"], None
|
||||
tmpdirname, "audio_output", gr.test_data.BASE64_AUDIO, None
|
||||
)
|
||||
self.assertEqual("audio_output/1.wav", to_save)
|
||||
|
||||
@ -331,11 +324,11 @@ class TestFile(unittest.TestCase):
|
||||
file_output = gr.outputs.File()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
to_save = file_output.save_flagged(
|
||||
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
|
||||
tmpdirname, "file_output", [gr.test_data.BASE64_FILE], None
|
||||
)
|
||||
self.assertEqual("file_output/0", to_save)
|
||||
to_save = file_output.save_flagged(
|
||||
tmpdirname, "file_output", gr.test_data.BASE64_FILE, None
|
||||
tmpdirname, "file_output", [gr.test_data.BASE64_FILE], None
|
||||
)
|
||||
self.assertEqual("file_output/1", to_save)
|
||||
|
||||
@ -363,7 +356,14 @@ class TestDataframe(unittest.TestCase):
|
||||
"overflow_row_behaviour": "paginate",
|
||||
"name": "dataframe",
|
||||
"label": None,
|
||||
"css": {}
|
||||
"css": {},
|
||||
"datatype": "str",
|
||||
"row_count": 3,
|
||||
"col_count": 3,
|
||||
"col_width": None,
|
||||
"default": [[None, None, None], [None, None, None], [None, None, None]],
|
||||
"name": "dataframe",
|
||||
|
||||
},
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
@ -373,11 +373,14 @@ class TestDataframe(unittest.TestCase):
|
||||
to_save = dataframe_output.save_flagged(
|
||||
tmpdirname, "dataframe_output", output, None
|
||||
)
|
||||
self.assertEqual(to_save, "[[2, true], [3, true], [4, false]]")
|
||||
self.assertEqual(to_save, json.dumps({
|
||||
"headers": ["num", "prime"],
|
||||
"data": [[2, True], [3, True], [4, False]]
|
||||
}))
|
||||
self.assertEqual(
|
||||
dataframe_output.restore_flagged(tmpdirname, to_save, None),
|
||||
{"data": [[2, True], [3, True], [4, False]]},
|
||||
)
|
||||
{"headers": ["num", "prime"],
|
||||
"data": [[2, True], [3, True], [4, False]]})
|
||||
|
||||
def test_in_interface(self):
|
||||
def check_odd(array):
|
||||
@ -413,7 +416,10 @@ class TestCarousel(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
carousel_output.get_template_context(),
|
||||
{
|
||||
"components": [{"name": "textbox", "label": None}],
|
||||
"components": [
|
||||
{"name": "textbox", "label": None, "default": "", "lines": 1,
|
||||
"css": {}, 'placeholder': None}
|
||||
],
|
||||
"name": "carousel",
|
||||
"label": "Disease",
|
||||
"css": {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user