gradio/test/test_networking.py

138 lines
4.6 KiB
Python
Raw Normal View History

2022-01-06 05:12:58 +08:00
"""Contains tests for networking.py and app.py"""
2022-01-05 01:58:37 +08:00
import os
2021-10-22 04:02:52 +08:00
import unittest
import unittest.mock as mock
import urllib
2022-01-05 01:58:37 +08:00
import warnings
2021-11-13 14:33:59 +08:00
from fastapi.testclient import TestClient
2022-03-25 13:58:07 +08:00
import gradio as gr
from gradio import Interface, networking
2021-10-22 04:02:52 +08:00
2022-01-05 01:58:37 +08:00
# Disable Gradio analytics for every test in this module.
# NOTE(review): `gradio` is imported above this line; if it reads this
# variable at import time the setting may not take effect — confirm.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
class TestPort(unittest.TestCase):
    """Tests for ``networking.get_first_available_port``."""

    START = 7860
    END = 7960

    def _get_port(self):
        """Return ``networking.get_first_available_port(START, END)``.

        If the lookup raises ``OSError`` (no port available), the current
        test is reported as skipped instead of silently passing.
        """
        try:
            return networking.get_first_available_port(self.START, self.END)
        except OSError:
            self.skipTest("Unable to test, no ports available")

    def test_port_is_in_range(self):
        port = self._get_port()
        self.assertGreaterEqual(port, self.START)
        self.assertLessEqual(port, self.END)

    def test_same_port_is_returned(self):
        # Nothing binds the port between the two lookups, so both calls
        # are expected to agree.
        port1 = self._get_port()
        port2 = self._get_port()
        self.assertEqual(port1, port2)
class TestInterfaceErrors(unittest.TestCase):
    """Check how the launched app's ``/api/predict/`` route reports failures."""

    def _launch_client(self):
        """Launch a ``1 / x`` number interface and return a TestClient for it.

        The interface is closed through ``addCleanup`` so the server is shut
        down even when an assertion in the calling test fails.
        """
        io = Interface(lambda x: 1 / x, "number", "number")
        app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
        self.addCleanup(io.close)
        return TestClient(app)

    def test_processing_error(self):
        # x = 0 makes the wrapped function raise (division by zero); with
        # show_error=True the route answers 500 with an "error" key.
        client = self._launch_client()
        response = client.post("/api/predict/", json={"data": [0], "fn_index": 0})
        self.assertEqual(response.status_code, 500)
        self.assertIn("error", response.json())

    def test_validation_error(self):
        # The body omits "data" and sends a list for fn_index; the request is
        # expected to be rejected with 422 (presumably request validation).
        client = self._launch_client()
        response = client.post("/api/predict/", json={"fn_index": [0]})
        self.assertEqual(response.status_code, 422)
class TestStartServer(unittest.TestCase):
    """Tests for ``networking.start_server``."""

    def test_start_server(self):
        # `import urllib` alone does not guarantee the `urllib.parse`
        # submodule is loaded, so import it explicitly.
        from urllib.parse import urlparse

        io = Interface(lambda x: x, "number", "number")
        # start_server is called directly rather than through launch(), so the
        # interface is configured by hand (presumably mirroring what launch()
        # would set up — confirm if start_server's requirements change).
        io.favicon_path = None
        io.config = io.get_config_file()
        io.show_error = True
        io.flagging_callback.setup(gr.Number(), io.flagging_dir)
        io.auth = None
        io.enable_queue = False

        port = networking.get_first_available_port(
            networking.INITIAL_PORT_VALUE,
            networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
        )
        _, local_path, _, server = networking.start_server(io, server_port=port)
        # Close the server even if an assertion below fails.
        self.addCleanup(server.close)

        url = urlparse(local_path)
        # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
        self.assertEqual(url.scheme, "http")
        self.assertEqual(url.port, port)
# class TestFlagging(unittest.TestCase):
# def test_flagging_analytics(self):
# callback = flagging.CSVLogger()
# callback.flag = mock.MagicMock()
# aiohttp.ClientSession.post = mock.MagicMock()
# aiohttp.ClientSession.post.__aenter__ = None
# aiohttp.ClientSession.post.__aexit__ = None
# io = Interface(
# lambda x: x,
# "text",
# "text",
# analytics_enabled=True,
# flagging_callback=callback,
# )
# app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
# client = TestClient(app)
# response = client.post(
# "/api/flag/",
# json={"data": {"input_data": ["test"], "output_data": ["test"]}},
# )
# aiohttp.ClientSession.post.assert_called()
# callback.flag.assert_called_once()
# self.assertEqual(response.status_code, 200)
# io.close()
# class TestInterpretation(unittest.TestCase):
# def test_interpretation(self):
# io = Interface(
# lambda x: len(x),
# "text",
# "label",
# interpretation="default",
# analytics_enabled=True,
# )
# app, _, _ = io.launch(prevent_thread_lock=True)
# client = TestClient(app)
# aiohttp.ClientSession.post = mock.MagicMock()
# aiohttp.ClientSession.post.__aenter__ = None
# aiohttp.ClientSession.post.__aexit__ = None
# io.interpret = mock.MagicMock(return_value=(None, None))
# response = client.post("/api/interpret/", json={"data": ["test test"]})
# aiohttp.ClientSession.post.assert_called()
# self.assertEqual(response.status_code, 200)
# io.close()
2021-10-27 04:57:03 +08:00
class TestURLs(unittest.TestCase):
    """Tests for the tunnel and URL helpers in ``networking``."""

    def test_setup_tunnel(self):
        # Stub create_tunnel only for this call; a bare attribute assignment
        # would leave the MagicMock in place for every later test.
        with mock.patch.object(networking, "create_tunnel", return_value="test"):
            res = networking.setup_tunnel(None, None)
        self.assertEqual(res, "test")

    def test_url_ok(self):
        # NOTE(review): presumably performs a real request to gradio.app and
        # therefore needs network access — confirm.
        res = networking.url_ok("https://www.gradio.app")
        self.assertTrue(res)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()