fixed broken tests

This commit is contained in:
Abubakar Abid 2020-07-08 17:02:17 -05:00
parent 78cdc465d9
commit 6d85108529
3 changed files with 5 additions and 5 deletions

View File

@ -55,7 +55,7 @@ class Label(AbstractOutput):
def postprocess(self, prediction): def postprocess(self, prediction):
if isinstance(prediction, str) or isinstance(prediction, Number): if isinstance(prediction, str) or isinstance(prediction, Number):
return {"label": prediction} return {"label": str(prediction)}
elif isinstance(prediction, dict): elif isinstance(prediction, dict):
sorted_pred = sorted( sorted_pred = sorted(
prediction.items(), prediction.items(),

View File

@ -7,13 +7,13 @@ import gradio.outputs
class TestInterface(unittest.TestCase): class TestInterface(unittest.TestCase):
def test_input_output_mapping(self): def test_input_output_mapping(self):
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=lambda x: x) io = gr.Interface(inputs='SketCHPad', outputs='TexT', fn=lambda x: x)
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad) self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad)
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox) self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
def test_input_interface_is_instance(self): def test_input_interface_is_instance(self):
inp = gradio.inputs.Image() inp = gradio.inputs.Image()
io = gr.Interface(inputs=inp, outputs='textBOX', fn=lambda x: x) io = gr.Interface(inputs=inp, outputs='teXT', fn=lambda x: x)
self.assertEqual(io.input_interfaces[0], inp) self.assertEqual(io.input_interfaces[0], inp)
def test_output_interface_is_instance(self): def test_output_interface_is_instance(self):
@ -24,7 +24,7 @@ class TestInterface(unittest.TestCase):
def test_prediction(self): def test_prediction(self):
def model(x): def model(x):
return 2*x return 2*x
io = gr.Interface(inputs='textbox', outputs='textBOX', fn=model) io = gr.Interface(inputs='textbox', outputs='TEXT', fn=model)
self.assertEqual(io.predict[0](11), 22) self.assertEqual(io.predict[0](11), 22)

View File

@ -27,7 +27,7 @@ class TestGetAvailablePort(unittest.TestCase):
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
new_port = networking.get_first_available_port(initial, final) new_port = networking.get_first_available_port(initial, final)
s.close() s.close()
self.assertFalse(port==new_port) self.assertFalse(port == new_port)
# class TestSetSampleData(unittest.TestCase): # class TestSetSampleData(unittest.TestCase):