2019-02-28 08:54:08 +08:00
|
|
|
import numpy as np
|
|
|
|
import unittest
|
|
|
|
import os
|
|
|
|
from gradio import outputs
|
2019-03-06 15:23:04 +08:00
|
|
|
import json
|
2019-02-28 08:54:08 +08:00
|
|
|
|
|
|
|
# Top-level package directory; interface asset paths are joined onto it.
PACKAGE_NAME = "gradio"
|
2019-03-18 20:38:10 +08:00
|
|
|
# Base64-encoded image string used as postprocess input in TestImage.
# NOTE(review): this literal appears truncated/unterminated in the current view;
# verify the full payload and closing quote in the repository copy.
BASE64_IMG = "
|
2019-02-28 08:54:08 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TestLabel(unittest.TestCase):
    """Tests for the ``outputs.Label`` output interface."""

    def test_path_exists(self):
        """The front-end JS file for ``Label`` should exist in the package.

        The file path is built from ``BASE_OUTPUT_INTERFACE_JS_PATH`` formatted
        with the interface's class name and joined onto ``PACKAGE_NAME``.
        The result is relative to the current working directory, so this
        presumably expects the suite to run from the repository root --
        TODO confirm.
        """
        out = outputs.Label()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(
            out.__class__.__name__)
        full_path = os.path.join(PACKAGE_NAME, path)
        # Report the checked path so a failure says which file is missing.
        self.assertTrue(os.path.exists(full_path), full_path)

    # TODO: restore coverage for ``Label.postprocess``. Previously commented-out
    # tests (removed as dead code) covered four cases:
    #   - a plain string label;
    #   - a 1-D confidence array, with and without ``show_confidences``;
    #   - a nested int array.
    # They relied on ``LABEL_KEY`` / ``CONFIDENCES_KEY`` / ``CONFIDENCE_KEY``
    # class attributes and a JSON-string return value that may no longer match
    # the current ``Label`` API -- verify against outputs.Label before rewriting.
|
2019-02-28 08:54:08 +08:00
|
|
|
|
|
|
|
|
2019-06-19 04:13:50 +08:00
|
|
|
class TestTextbox(unittest.TestCase):
    """Tests for the ``outputs.Textbox`` output interface."""

    def test_path_exists(self):
        """The front-end JS file for ``Textbox`` should exist in the package.

        The path is relative to the current working directory; presumably the
        suite is run from the repository root -- TODO confirm.
        """
        out = outputs.Textbox()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(
            out.__class__.__name__)
        full_path = os.path.join(PACKAGE_NAME, path)
        # Report the checked path so a failure says which file is missing.
        self.assertTrue(os.path.exists(full_path), full_path)

    def test_postprocessing(self):
        """A plain string should come back from ``postprocess`` unchanged."""
        string = 'happy'
        out = outputs.Textbox()
        # Keep input and result under separate names: rebinding the input and
        # comparing it with itself would make the assertion always pass.
        result = out.postprocess(string)
        self.assertEqual(result, string)
|
|
|
|
|
|
|
|
|
2019-03-18 20:38:10 +08:00
|
|
|
class TestImage(unittest.TestCase):
    """Tests for the ``outputs.Image`` output interface."""

    def test_path_exists(self):
        """The front-end JS file for ``Image`` should exist in the package.

        The path is relative to the current working directory; presumably the
        suite is run from the repository root -- TODO confirm.
        """
        out = outputs.Image()
        # ``__name__`` matches the sibling tests. For a top-level class it is
        # identical to the ``__qualname__`` used previously.
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(
            out.__class__.__name__)
        full_path = os.path.join(PACKAGE_NAME, path)
        # Report the checked path so a failure says which file is missing.
        self.assertTrue(os.path.exists(full_path), full_path)

    def test_postprocessing(self):
        """A base64 image string should come back from ``postprocess`` unchanged.

        NOTE(review): despite this class's name, the test exercises
        ``outputs.Textbox``, not ``outputs.Image``. Switching to
        ``outputs.Image`` requires an input matching what
        ``Image.postprocess`` accepts -- confirm before changing.
        """
        string = BASE64_IMG
        out = outputs.Textbox()
        # Separate names so the assertion actually compares output to input
        # (comparing a variable with itself can never fail).
        result = out.postprocess(string)
        self.assertEqual(result, string)
|
|
|
|
|
|
|
|
|
2019-02-28 08:54:08 +08:00
|
|
|
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase defined above.
    unittest.main()
|