mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
check for sample inputs func and add unit test
This commit is contained in:
parent
c3f79b0937
commit
53b194f755
@ -139,7 +139,7 @@ class Interface:
|
||||
self.output_interface.__class__.__name__.lower(),
|
||||
)
|
||||
|
||||
if self.input_interface.__class__.__name__.lower() == "sketchpad" or self.input_interface.__class__.__name__.lower() == "textbox":
|
||||
if hasattr(self.input_interface, 'get_sample_inputs'):
|
||||
networking.set_sample_data_in_config_file(
|
||||
output_directory,
|
||||
self.input_interface.get_sample_inputs()
|
||||
|
2
setup.py
2
setup.py
@ -21,5 +21,7 @@ setup(
|
||||
'psutil',
|
||||
'paramiko',
|
||||
'scipy',
|
||||
'IPython',
|
||||
'scikit-image',
|
||||
],
|
||||
)
|
||||
|
@ -1,8 +1,11 @@
|
||||
import unittest
|
||||
from gradio import networking
|
||||
from gradio import inputs
|
||||
from gradio import outputs
|
||||
import socket
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
LOCALHOST_NAME = 'localhost'
|
||||
|
||||
|
||||
@ -27,6 +30,19 @@ class TestGetAvailablePort(unittest.TestCase):
|
||||
s.close()
|
||||
self.assertFalse(port==new_port)
|
||||
|
||||
class TestSetSampleData(unittest.TestCase):
|
||||
def test_set_sample_data(self):
|
||||
test_array = ["test1", "test2", "test3"]
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
inp = inputs.Sketchpad()
|
||||
out = outputs.Label()
|
||||
networking.build_template(temp_dir, inp, out)
|
||||
networking.set_sample_data_in_config_file(temp_dir, test_array)
|
||||
config_file = os.path.join(temp_dir, 'static/config.json')
|
||||
with open(config_file) as json_file:
|
||||
data = json.load(json_file)
|
||||
self.assertFalse(test_array == data["sample_inputs"])
|
||||
|
||||
# class TestCopyFiles(unittest.TestCase):
|
||||
# def test_copy_files(self):
|
||||
# filename = "a.txt"
|
||||
|
Loading…
Reference in New Issue
Block a user