gradio/test/test_outputs.py

78 lines
15 KiB
Python
Raw Normal View History

2019-02-28 08:54:08 +08:00
import numpy as np
import unittest
import os
from gradio import outputs
2019-03-06 15:23:04 +08:00
import json
2019-02-28 08:54:08 +08:00
# Directory name of the package under test. The ``test_path_exists`` tests
# join it with each interface's JS path before calling ``os.path.exists``.
# NOTE(review): the joined path is relative, so the tests presumably have to
# be run from the repository root -- confirm.
PACKAGE_NAME = 'gradio'
2019-03-18 20:38:10 +08:00
BASE64_IMG = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAYEBQYFBAYGBQYHBwYIChAKCgkJChQODwwQFxQYGBcUFhYaHSUfGhsjHBYWICwgIyYnKSopGR8tMC0oMCUoKSj/2wBDAQcHBwoIChMKChMoGhYaKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCj/wgARCADIAMgDASIAAhEBAxEB/8QAHAAAAQUBAQEAAAAAAAAAAAAABQACAwQGBwEI/8QAGgEAAwEBAQEAAAAAAAAAAAAAAQIDBAAFBv/aAAwDAQACEAMQAAAB09a1OqDQ98RNiD50j1JoWICkgVvMXUcqCNvqPJeNIQ++Pr0KuXmGfJPfPrQvUA6AQpUg3Y6OqBCz2SJZ7N5Zm+SPp0Escx6s70WQRK41/dsKw2yVmDkAUiVmgur0womMbrYuzSqrUxSYuMvBGByxWsSerPHXBJz0n2aZ9bZMc8C3cKvk3alA87zHa8Y84KNW2JzFKJlKV/S8LvkZSYpc70xBb1K88NddNM7i2kEpht+eWSLVZsMlNmsWaSpmvjLxGiHz01bl10e/Tl6IaFl3kUC2IVtWyW9wCJOqyAukhJrPR7/GHrftey1JfYIqcsz0gOpyxw2wPzIwXeDOpcwoxEUd7Rn6cWpvpH2KKdNA/Nkxxz2kLStNpc5pM95JLUnAfa8usaflqiwJ5keBTT0Gnnx3NptDzyUd0jOvyQnZNS31sUhtxb/JF21YlTMiSNEKKU6cm9KF0OTVIvYxK3XdHZYxRd3NjrwXRi5qvk5zRwq1kk4uOD6V82vkgs5wfhmh9LNWkjizPmXtuHqCDKnbsyNI47vv86o7cvSPMdRou0pi6C9eEKston5CGdjQuMv3P2Lz03yZPG622LXVq3tx6+KWfZWdliNMYrCr2/H6QNE4+PXM1wAacdIGNi4iRXB9Pym6WmZyG3FYPRE3ZDPd7qo/NmHk2twZX0Mmyrtp56Yracl3sbEbFa3mbJJyqNUBiJvmviLvudcnJrruk5gux25BfX+adLzaBgUaKhr2xfmG0eJsYXxFI4G/SdugXmz5AiS9GTi8o65mcGqVaZTEJfThKJUkD0tea8KZS0glVuDubps1yTz9OIwHauLz9GXqVXS1wj+c9K4nXPG9kulJ3wp0sWKDDx2/k2jtWs0s79nw+85rnvD5Qv78rYLFfuvVpSqPv3+jPO15jAlq8/Z6oUjir4mIwNuDdmkmjgqrpmI82L2FH9fA1G9TFNvoTmSUWz5pLbmUKTCbTpSptMml5/oYeZJPX65iEtHz3NJktkPWJN0kaRMUKUX9YkD6kgf/xAAoEAACAgIBBAEEAwEBAAAAAAABAgMEABESBRMhIjEGEBQyFSMzQST/2gAIAQEAAQUC6aU7dqNhKH0qyA2HJmfsesXws/DA4OM+8aQ65eN/b4zlhPLJm44p3jKTi5H+sSPlvayLrOlx+tw+0/6Rpt1GslPonhXXlmmA7jKysrDiMcYqnXFjnE7dGAkTQqKGCqCHraEAK4LSoLjd5j8UyBDcPt8gLrPOa2AngrrEXbPGCTB5ZCAgO/HCEjTN/Y8g1YOU2HbGcvTyrFdiVtNG3JksquSzBzvyPBkIwHwh8yYPBU7Mk4TJbgz8thkd1hiXYTiMrtNoCwdZQK8H1iTKMsnkVk9ZPbIV0eIwt/ZDHyzWskYgxnA3lnBC+MVJLT2KC8T0ksR0ldfwyZb6ZYQRzyxTRv3UsKdxfqvwoyROS7ZHVdtJ6jZ0x967euTD1TzgzWaLvGojRmzebznm861CgWm4V7TDdROS9s6jq+CvFrsfrHJkntkv614/7O3rOPLCvkR4kZ28XEdM20uH51hAw5yy+BJVRyuRS8zU8VoAZFDACQ7lsIe2o/sLAZOchHgONRkZKPIIUK3tcsqK9DcFO31G8uVOqd3
HtcRJ1tFav1eGQvkv6Wxxnh85BGWqo/ZiUlhEQJbcw7QPvY+JlJMI9OJJO8GGLliR5rtYsj871iTnRSSc3awXGkRJIp6ki1QsaWGXs3X3NCwD1VkEYTakMMX9rs+LiAHEbkR4CucOK2b8bxyVURjjLUrk1olQ2Ryw1EmxenZ2eC3aUccejPIKkyioxFaabWCfEOzb/wBUydzuqfb/AIozWAeQmBfOve1ZCYLRnc9SBb+TQN+RqSGYFSwbOqShK3T64hjeqFswj0nXOHlY8vD2A1k/70U8gbPb0oX14edaCZO3bzrAkkjpyKglocneoFk7xzp9s9xJ86pN3JmrtJGg4LH8S5/0t63H3Lkylmox+sUXElfsIhxdcVfHUE3Wik51On0UErWNB7ALdSqpOIIkjWW2AKftZrdSgsR/s0X6y/IHnLH+6x5yXKXk4cIOd3iGV2ztSZ+MTluBqVii5/Dl6hIskdl5D1Kz2oZZ27Y2cWHhTSJe2I2Bj+JPknWMx4yDch/zkbT0oRx44I/I4gSzxIzXYFx+pQrn8shfqx70Ud3sw/njY6llyfvSM2dOr91rf9cXf4ZBd5FW8SHyck8IP9JW4xu2zVTwzeLd94p/zJGntQR8addHwQRszrDFP+UGez055InSRG85onK1MsasXET1PyOnxK0jQwBG/wCnE0Ws+iQ/sRyWcANSmUR3bfGawwafZLNXltSV+lxpGaUYSbpMbg12hnAEYvQCRTW9krjEjAypAZCo0HP43U4nWY8fM54x2LrNK1zvVoB669J/3FKumT1ImnnrCLI+3l67ErR9VrgSdUgORdSgcWyBPbiLry5h18/GU6pfFAUO4VZX7lirK0Zisqw6hLxrO23ijIqw/pJ/nMNPL1eM5/IMs/8ALI8f5ntKrTSpUY4vTGyCKGvlfcvXcs1hIWhkDQUwp+19u3Vjxc34STYl6bBNh5144uOpLQaa0p5kiWRYUzsriw7MdFhIWSMCXutZcRw/T8A758C91otLJ1aZQPqCSIUOqV7iZ9SS6h3gmCAWXYq2I+JNk0azJVry1793tFYumSjEoEDsopMixCSflklkDK+ybv8AdPRhWKOaNpRe6eiC2R3Ol0Tcnjpx8Fh451+buXN4MUDPjOWLJizYlnxNwnXOrT9mt3Mkk3jHea8w+Mptz6qn65Y4rDMsZbotdYqIGW5hXryuXfW8UEYx4hfAJzecsLknvcBn1A/nnv7n5HiPosfO1H+pOdYu7QfPThqmTrPqKz9hn/F9mOE/Zji+i739uv8Amau248ON+7/4fTcf9X/OoSdqo42OGRLwSVwiW5zPOv2kOyPAw/YeWdtn7dacPPRPr9j+0/8Ah9Nj/wAWdek0j5TXnbPx9SWO3BgOO2gg19icJxjnwv2//8QAIxEAAgICAwABBQEAAAAAAAAAAAECEQMQEiExQQQTIjJRIP/aAAgBAwEBPwGP7El2cKRxsrVnIstCHO+tL2yy+jkyT09LwXpQtqBwoUBoWpR+RCoi4y6Jx4vUPR5Bv8bObs9EmWVZ4y+7MPyyU+WsaOBVqj7VkFXRb+CjteFN9jkY58kxagvxs5jdIU3YuhRY3fTJdIVQgSMbpPbfwUes6TIytEYR/pKaRycnZKblqHW5DySZcjjL+mGXVCkynIyLhAUqLExanks530WXZgffZGiUlBGTLyLE0xRtC81ZY5nJsx+EFkroycr/ACJeaaojmkiGSMlQ2iUUlbOWoiXR9K24NGVuUif8/wAcRNoy+D1FEVbURYlCNIm+9e6SK1m+CfusZ9N3kRm/Vk9fAhC1/8QAJBEAAgIBBAIDAQEBAAAAAAAAAAECEQMSITFBECIEMlETFGH/2gAIAQIBAT8Bl9CMtiWQTdml9FvsXAsdiVFdj/6NJKyjJk9KIy/Ta6FGhcGixY6Konuy+jkklXjSmh0iL7ExD8KXQ4dn87dsaobdCs1dGhMUN6FBV4tFF0XaOif4ONCJCnsQka6Gylz
45e5aWxGJKNE2JGRCikiKR0cjN1uR9mTfZDgyy0tDle4ib3NRxHYSZWljZV8FKKHNSdCPkK5IoR9p7kcUInokPJFbUZP0aR
2019-02-28 08:54:08 +08:00
class TestLabel(unittest.TestCase):
    """Tests for the ``outputs.Label`` output interface."""

    def test_path_exists(self):
        """The JS file named after the lowercased class name exists on disk."""
        out = outputs.Label()
        interface_name = out.__class__.__name__.lower()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(interface_name)
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    # NOTE(review): the postprocessing tests below were already disabled in
    # the original file. They are kept verbatim (only re-indented) as a record
    # of the previously expected ``Label.postprocess`` output; re-enable or
    # delete them once the current contract of ``Label.postprocess`` is
    # confirmed.
    #
    # def test_postprocessing_string(self):
    #     string = 'happy'
    #     out = outputs.Label()
    #     label = json.loads(out.postprocess(string))
    #     self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
    #
    # def test_postprocessing_1D_array(self):
    #     array = np.array([0.1, 0.2, 0, 0.7, 0])
    #     true_label = {outputs.Label.LABEL_KEY: 3,
    #                   outputs.Label.CONFIDENCES_KEY: [
    #                       {outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
    #                       {outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
    #                       {outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
    #                   ]}
    #     out = outputs.Label()
    #     label = json.loads(out.postprocess(array))
    #     self.assertDictEqual(label, true_label)
    #
    # def test_postprocessing_1D_array_no_confidences(self):
    #     array = np.array([0.1, 0.2, 0, 0.7, 0])
    #     true_label = {outputs.Label.LABEL_KEY: 3}
    #     out = outputs.Label(show_confidences=False)
    #     label = json.loads(out.postprocess(array))
    #     self.assertDictEqual(label, true_label)
    #
    # def test_postprocessing_int(self):
    #     true_label_array = np.array([[[3]]])
    #     true_label = {outputs.Label.LABEL_KEY: 3}
    #     out = outputs.Label()
    #     label = json.loads(out.postprocess(true_label_array))
    #     self.assertDictEqual(label, true_label)
2019-02-28 08:54:08 +08:00
2019-06-19 04:13:50 +08:00
class TestTextbox(unittest.TestCase):
    """Tests for the ``outputs.Textbox`` output interface."""

    def test_path_exists(self):
        """The JS file named after the lowercased class name exists on disk."""
        out = outputs.Textbox()
        interface_name = out.__class__.__name__.lower()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(interface_name)
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing(self):
        """Postprocessing a plain string yields an equal string.

        NOTE(review): pass-through is the contract the original test appears
        to intend -- confirm against ``outputs.Textbox.postprocess``.
        """
        expected = 'happy'
        out = outputs.Textbox()
        # Keep the input and the result under separate names. The original
        # rebound ``string`` to the result and then asserted
        # ``assertEqual(string, string)``, which could never fail.
        actual = out.postprocess(expected)
        self.assertEqual(actual, expected)
2019-03-18 20:38:10 +08:00
class TestImage(unittest.TestCase):
    """Tests for the ``outputs.Image`` output interface."""

    def test_path_exists(self):
        """The JS file named after the lowercased class name exists on disk."""
        out = outputs.Image()
        # ``__name__`` (rather than ``__qualname__``) matches the other test
        # classes in this file; the two are equal for a module-level class.
        interface_name = out.__class__.__name__.lower()
        path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(interface_name)
        self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))

    def test_postprocessing(self):
        """A base64 image data URI passes through postprocessing unchanged."""
        expected = BASE64_IMG
        # NOTE(review): this builds a ``Textbox`` even though the class tests
        # ``Image``. Kept as in the original because the input/output contract
        # of ``outputs.Image.postprocess`` is not visible here -- confirm and
        # switch to ``outputs.Image()`` with a suitable fixture if intended.
        out = outputs.Textbox()
        # Compare the result against the untouched input. The original rebound
        # ``string`` to the result and compared it with itself, so the
        # assertion was vacuous.
        actual = out.postprocess(expected)
        self.assertEqual(actual, expected)
2019-02-28 08:54:08 +08:00
# Run the whole suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()