check for sample inputs func and add unit test

This commit is contained in:
dawoodkhan82 2019-07-23 14:40:05 -07:00
parent c3f79b0937
commit 53b194f755
3 changed files with 19 additions and 1 deletions

View File

@ -139,7 +139,7 @@ class Interface:
self.output_interface.__class__.__name__.lower(),
)
if self.input_interface.__class__.__name__.lower() == "sketchpad" or self.input_interface.__class__.__name__.lower() == "textbox":
if hasattr(self.input_interface, 'get_sample_inputs'):
networking.set_sample_data_in_config_file(
output_directory,
self.input_interface.get_sample_inputs()

View File

@ -21,5 +21,7 @@ setup(
'psutil',
'paramiko',
'scipy',
'IPython',
'scikit-image',
],
)

View File

@ -1,8 +1,11 @@
import unittest
from gradio import networking
from gradio import inputs
from gradio import outputs
import socket
import tempfile
import os
import json
LOCALHOST_NAME = 'localhost'
@ -27,6 +30,19 @@ class TestGetAvailablePort(unittest.TestCase):
s.close()
self.assertFalse(port==new_port)
class TestSetSampleData(unittest.TestCase):
def test_set_sample_data(self):
test_array = ["test1", "test2", "test3"]
temp_dir = tempfile.mkdtemp()
inp = inputs.Sketchpad()
out = outputs.Label()
networking.build_template(temp_dir, inp, out)
networking.set_sample_data_in_config_file(temp_dir, test_array)
config_file = os.path.join(temp_dir, 'static/config.json')
with open(config_file) as json_file:
data = json.load(json_file)
self.assertFalse(test_array == data["sample_inputs"])
# class TestCopyFiles(unittest.TestCase):
# def test_copy_files(self):
# filename = "a.txt"