2019-02-28 08:54:08 +08:00
|
|
|
import numpy as np
|
|
|
|
import unittest
|
|
|
|
import os
|
|
|
|
from gradio import outputs
|
2019-03-06 15:23:04 +08:00
|
|
|
import json
|
2019-02-28 08:54:08 +08:00
|
|
|
|
|
|
|
# Directory name of the package; template paths are joined onto it.
PACKAGE_NAME = "gradio"
|
|
|
|
|
|
|
|
|
|
|
|
class TestLabel(unittest.TestCase):
    """Tests for the ``outputs.Label`` output interface."""

    @staticmethod
    def _decode(interface, value):
        """Pass ``value`` through ``interface.postprocess`` and parse the JSON it returns."""
        return json.loads(interface.postprocess(value))

    def test_path_exists(self):
        template = outputs.Label().get_template_path()
        # os.path.join(PACKAGE_NAME, ...) yields a relative path, so the check is
        # resolved against the current working directory (presumably the repo root).
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, template)))

    def test_postprocessing_string(self):
        text = 'happy'
        decoded = self._decode(outputs.Label(), text)
        self.assertDictEqual(decoded, {outputs.Label.LABEL_KEY: text})

    def test_postprocessing_1D_array(self):
        scores = np.array([0.1, 0.2, 0, 0.7, 0])
        key = outputs.Label.LABEL_KEY
        conf_key = outputs.Label.CONFIDENCE_KEY
        # (class index, confidence) pairs expected in the output, highest first.
        ranked = [(3, 0.7), (1, 0.2), (0, 0.1)]
        expected = {
            key: 3,
            outputs.Label.CONFIDENCES_KEY: [
                {key: index, conf_key: confidence} for index, confidence in ranked
            ],
        }
        decoded = self._decode(outputs.Label(), scores)
        self.assertDictEqual(decoded, expected)

    def test_postprocessing_1D_array_no_confidences(self):
        scores = np.array([0.1, 0.2, 0, 0.7, 0])
        expected = {outputs.Label.LABEL_KEY: 3}
        decoded = self._decode(outputs.Label(show_confidences=False), scores)
        self.assertDictEqual(decoded, expected)

    def test_postprocessing_int(self):
        nested = np.array([[[3]]])
        expected = {outputs.Label.LABEL_KEY: 3}
        decoded = self._decode(outputs.Label(), nested)
        self.assertDictEqual(decoded, expected)
|
2019-02-28 08:54:08 +08:00
|
|
|
|
|
|
|
|
|
|
|
class TestTextbox(unittest.TestCase):
    """Tests for the ``outputs.Textbox`` output interface."""

    def test_path_exists(self):
        # NOTE(review): unlike the Label test, the os.path.exists(...) check on the
        # template path is disabled for Textbox. This is therefore only a smoke
        # test that get_template_path() does not raise.
        # TODO: confirm why the check was disabled and re-enable it.
        outputs.Textbox().get_template_path()

    def test_postprocessing(self):
        string = 'happy'
        out = outputs.Textbox()
        # Keep input and output under separate names so the comparison can fail.
        # Rebinding `string` to the result would compare a value with itself.
        processed = out.postprocess(string)
        self.assertEqual(processed, string)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()
|