"""Unit tests for the ``postprocess`` methods of gradio output interfaces."""

import numpy as np
import unittest
import os
from gradio import outputs
import json

PACKAGE_NAME = 'gradio'
# Placeholder: currently empty, so tests using it do not exercise real image data.
BASE64_IMG = ""

# Where to find the static resources associated with each template.
BASE_OUTPUT_INTERFACE_JS_PATH = 'static/js/interfaces/output/{}.js'


class TestLabel(unittest.TestCase):
    """Tests for ``outputs.Label.postprocess``."""

    def test_postprocessing_string(self):
        """A string prediction is wrapped under ``LABEL_KEY``."""
        string = 'happy'
        out = outputs.Label()
        label = out.postprocess(string)
        self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})

    def test_postprocessing_dict(self):
        """A label->confidence dict yields the top label plus the confidences in order."""
        orig_label = {
            3: 0.7,
            1: 0.2,
            0: 0.1,
        }
        true_label = {
            outputs.Label.LABEL_KEY: 3,
            outputs.Label.CONFIDENCES_KEY: [
                {outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
                {outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
                {outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
            ],
        }
        out = outputs.Label()
        label = out.postprocess(orig_label)
        self.assertDictEqual(label, true_label)

    def test_postprocessing_array(self):
        """A numpy array prediction is rejected with ``ValueError``."""
        array = np.array([0.1, 0.2, 0, 0.7, 0])
        out = outputs.Label()
        self.assertRaises(ValueError, out.postprocess, array)

    def test_postprocessing_int(self):
        """An int prediction is converted to its string form under ``LABEL_KEY``."""
        raw_label = 3
        true_label = {outputs.Label.LABEL_KEY: '3'}
        out = outputs.Label()
        label = out.postprocess(raw_label)
        self.assertDictEqual(label, true_label)


class TestTextbox(unittest.TestCase):
    """Tests for ``outputs.Textbox.postprocess``."""

    def test_postprocessing(self):
        """A string prediction is returned unchanged."""
        string = 'happy'
        out = outputs.Textbox()
        processed = out.postprocess(string)
        # Compare against the original input. Previously the result was
        # compared with itself, which could never fail.
        self.assertEqual(processed, string)


class TestImage(unittest.TestCase):
    """Postprocessing of a base64 image string."""

    def test_postprocessing(self):
        """A base64 string is returned unchanged by postprocessing."""
        # NOTE(review): this builds outputs.Textbox, not an image output, and
        # BASE64_IMG is empty. TODO: confirm whether outputs.Image was
        # intended, and supply a real base64 payload.
        string = BASE64_IMG
        out = outputs.Textbox()
        processed = out.postprocess(string)
        # Compare against the input instead of a self-comparison.
        self.assertEqual(processed, string)


if __name__ == '__main__':
    unittest.main()