gradio/test/test_outputs.py

78 lines
15 KiB
Python
Raw Normal View History

2019-02-28 08:54:08 +08:00
import numpy as np
import unittest
import os
from gradio import outputs
2019-03-06 15:23:04 +08:00
import json
2019-02-28 08:54:08 +08:00
# Directory prefix joined with each interface's asset path in the
# test_path_exists checks. The resulting path is relative, so those checks
# resolve it against the current working directory.
PACKAGE_NAME = 'gradio'
2019-03-18 20:38:10 +08:00
BASE64_IMG = "
2019-02-28 08:54:08 +08:00
class TestLabel(unittest.TestCase):
    """Checks for the ``Label`` output interface."""

    def test_path_exists(self):
        """The asset path derived from the lower-cased class name must exist."""
        label_output = outputs.Label()
        interface_name = label_output.__class__.__name__.lower()
        relative_path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(interface_name)
        full_path = os.path.join(PACKAGE_NAME, relative_path)
        self.assertTrue(os.path.exists(full_path))

    # NOTE(review): the postprocessing tests below are disabled in the source;
    # the reason is not visible from this file.
    #
    # def test_postprocessing_string(self):
    #     string = 'happy'
    #     out = outputs.Label()
    #     label = json.loads(out.postprocess(string))
    #     self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
    #
    # def test_postprocessing_1D_array(self):
    #     array = np.array([0.1, 0.2, 0, 0.7, 0])
    #     true_label = {outputs.Label.LABEL_KEY: 3,
    #                   outputs.Label.CONFIDENCES_KEY: [
    #                       {outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
    #                       {outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
    #                       {outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
    #                   ]}
    #     out = outputs.Label()
    #     label = json.loads(out.postprocess(array))
    #     self.assertDictEqual(label, true_label)
    #
    # def test_postprocessing_1D_array_no_confidences(self):
    #     array = np.array([0.1, 0.2, 0, 0.7, 0])
    #     true_label = {outputs.Label.LABEL_KEY: 3}
    #     out = outputs.Label(show_confidences=False)
    #     label = json.loads(out.postprocess(array))
    #     self.assertDictEqual(label, true_label)
    #
    # def test_postprocessing_int(self):
    #     true_label_array = np.array([[[3]]])
    #     true_label = {outputs.Label.LABEL_KEY: 3}
    #     out = outputs.Label()
    #     label = json.loads(out.postprocess(true_label_array))
    #     self.assertDictEqual(label, true_label)
2019-02-28 08:54:08 +08:00
2019-06-19 04:13:50 +08:00
class TestTextbox(unittest.TestCase):
    """Checks for the ``Textbox`` output interface."""

    def test_path_exists(self):
        """The asset path derived from the lower-cased class name must exist."""
        out = outputs.Textbox()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing(self):
        """``postprocess`` is expected to hand a plain string back unchanged."""
        expected = 'happy'
        out = outputs.Textbox()
        result = out.postprocess(expected)
        # Keep the input and the result in separate names: the earlier version
        # rebound the input variable to the result and then compared it with
        # itself, so the assertion could never fail.
        self.assertEqual(result, expected)
2019-03-18 20:38:10 +08:00
class TestImage(unittest.TestCase):
    """Checks for the ``Image`` output interface."""

    def test_path_exists(self):
        """The asset path derived from the lower-cased class name must exist."""
        out = outputs.Image()
        # Use __name__ like the sibling test classes do; for a class referenced
        # as a module attribute this yields the same value as __qualname__.
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing(self):
        """The base64 image string should pass through ``postprocess`` unchanged."""
        expected = BASE64_IMG
        # NOTE(review): this instantiates outputs.Textbox, not outputs.Image,
        # even though the class under test is Image. Left as-is because the
        # input contract of Image.postprocess is not visible here -- confirm
        # and switch to outputs.Image with a suitable input if intended.
        out = outputs.Textbox()
        result = out.postprocess(expected)
        # Compare the result against the original input rather than against
        # itself; the earlier self-comparison could never fail.
        self.assertEqual(result, expected)
2019-02-28 08:54:08 +08:00
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()