# gradio/test/test_outputs.py
# Source view metadata: 80 lines, 14 KiB, Python.
# Blame timestamp: 2019-02-28 08:54:08 +08:00
import numpy as np
import unittest
import os
from gradio import outputs
# Blame timestamp: 2019-03-06 15:23:04 +08:00
import json
# Blame timestamp: 2019-02-28 08:54:08 +08:00
# Directory joined in front of the paths built from
# ``outputs.BASE_OUTPUT_INTERFACE_JS_PATH`` in the ``test_path_exists`` tests.
# The join is relative, so it resolves against the working directory the
# tests are run from.
PACKAGE_NAME = 'gradio'

# Sample base64-encoded image fed to ``postprocess`` in TestImage.
# NOTE(review): the original literal was truncated in this copy of the file
# (an unterminated string, i.e. a SyntaxError).  It is replaced here with a
# valid 1x1 PNG data URI.  Restore the original fixture if its exact content
# matters to any test.
BASE64_IMG = (
    "data:image/png;base64,"
    "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
)


class TestLabel(unittest.TestCase):
    """Tests for ``outputs.Label``: its front-end asset and ``postprocess``."""

    def test_path_exists(self):
        # The JS file for this output type (named after the lower-cased
        # class name) must exist under the package directory.
        out = outputs.Label()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(
            out.__class__.__name__.lower())
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing_string(self):
        # A plain string is expected to be wrapped as the label value.
        string = 'happy'
        out = outputs.Label()
        label = out.postprocess(string)
        self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})

    def test_postprocessing_dict(self):
        # A {label: confidence} mapping is expected to yield the top label
        # plus a list of {label, confidence} entries.  Here the expected list
        # is in descending-confidence order, which is also the input order.
        orig_label = {
            3: 0.7,
            1: 0.2,
            0: 0.1,
        }
        true_label = {
            outputs.Label.LABEL_KEY: 3,
            outputs.Label.CONFIDENCES_KEY: [
                {outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
                {outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
                {outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
            ],
        }
        out = outputs.Label()
        label = out.postprocess(orig_label)
        self.assertDictEqual(label, true_label)

    def test_postprocessing_array(self):
        # A raw numpy array is not an accepted label format.
        array = np.array([0.1, 0.2, 0, 0.7, 0])
        out = outputs.Label()
        with self.assertRaises(ValueError):
            out.postprocess(array)

    def test_postprocessing_int(self):
        # A bare int label is expected to come back as its string form.
        # NOTE(review): test_postprocessing_dict expects the top label to stay
        # an int (3, not '3').  Confirm that this asymmetry is intended.
        raw_label = 3
        true_label = {outputs.Label.LABEL_KEY: '3'}
        out = outputs.Label()
        label = out.postprocess(raw_label)
        self.assertDictEqual(label, true_label)


class TestTextbox(unittest.TestCase):
    """Tests for ``outputs.Textbox``: its front-end asset and ``postprocess``."""

    def test_path_exists(self):
        # The JS file for this output type (named after the lower-cased
        # class name) must exist under the package directory.
        out = outputs.Textbox()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(
            out.__class__.__name__.lower())
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing(self):
        # A plain string should pass through ``postprocess`` unchanged.
        # Keep the input and the result in separate names.  Comparing the
        # result with itself (as before) could never fail.
        string = 'happy'
        out = outputs.Textbox()
        processed = out.postprocess(string)
        self.assertEqual(processed, string)


class TestImage(unittest.TestCase):
    """Tests for ``outputs.Image``: its front-end asset and ``postprocess``."""

    def test_path_exists(self):
        # The JS file for this output type (named after the lower-cased
        # class name) must exist under the package directory.  ``__name__`` is
        # used for consistency with the sibling tests.  For a top-level class
        # it equals ``__qualname__``.
        out = outputs.Image()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(
            out.__class__.__name__.lower())
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing(self):
        # NOTE(review): despite living in TestImage, this exercises
        # ``outputs.Textbox``, not ``outputs.Image``.  Confirm which output was
        # intended before switching, since Image.postprocess may not accept a
        # base64 string.
        # The result is compared with the input.  Comparing it with itself
        # (as before) could never fail.
        string = BASE64_IMG
        out = outputs.Textbox()
        processed = out.postprocess(string)
        self.assertEqual(processed, string)


if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()