gradio/test/test_outputs.py
2019-02-27 16:54:08 -08:00

51 lines
1.4 KiB
Python

import numpy as np
import unittest
import os
from gradio import outputs
# Directory name that the path-existence tests below join in front of each
# template path before checking it on disk.
# NOTE(review): the joined path is relative, so the tests presumably have to
# be run from the directory that contains this folder -- confirm.
PACKAGE_NAME = 'gradio'
class TestLabel(unittest.TestCase):
    """Checks the template path and post-processing of ``outputs.Label``."""

    def test_path_exists(self):
        """The template path reported by Label exists under PACKAGE_NAME."""
        label_output = outputs.Label()
        relative_path = label_output._get_template_path()
        full_path = os.path.join(PACKAGE_NAME, relative_path)
        self.assertTrue(os.path.exists(full_path))

    def test_postprocessing_string(self):
        """A plain string is expected to come back unchanged."""
        expected = 'happy'
        label_output = outputs.Label()
        actual = label_output._postprocess(expected)
        self.assertEqual(actual, expected)

    def test_postprocessing_one_hot(self):
        """A vector with a single 1 at index 3 is expected to yield 3."""
        prediction = np.array([0, 0, 0, 1, 0])
        expected = 3
        label_output = outputs.Label()
        actual = label_output._postprocess(prediction)
        self.assertEqual(actual, expected)

    def test_postprocessing_int(self):
        """A nested single-element array holding 3 is expected to yield 3."""
        prediction = np.array([[[3]]])
        expected = 3
        label_output = outputs.Label()
        actual = label_output._postprocess(prediction)
        self.assertEqual(actual, expected)
class TestTextbox(unittest.TestCase):
    """Checks the template path and post-processing of ``outputs.Textbox``."""

    def test_path_exists(self):
        """The template path reported by Textbox exists under PACKAGE_NAME."""
        out = outputs.Textbox()
        path = out._get_template_path()
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing(self):
        """Post-processing a plain string is expected to return it unchanged."""
        string = 'happy'
        out = outputs.Textbox()
        # Keep the result under its own name: rebinding ``string`` here would
        # make the assertion compare a value with itself and never fail.
        result = out._postprocess(string)
        self.assertEqual(result, string)
if __name__ == '__main__':
    # Executed directly (not imported): hand control to unittest's runner,
    # which discovers and runs the TestCase classes defined in this module.
    unittest.main()