gradio/test/test_networking.py

130 lines
4.2 KiB
Python
Raw Normal View History

2022-01-06 05:12:58 +08:00
"""Contains tests for networking.py and app.py"""
2022-01-05 01:58:37 +08:00
import os
2021-10-22 04:02:52 +08:00
import unittest
import unittest.mock as mock
import urllib
2022-01-05 01:58:37 +08:00
import warnings
2021-11-13 14:33:59 +08:00
import aiohttp
from fastapi.testclient import TestClient
2022-03-25 13:58:07 +08:00
import gradio as gr
from gradio import Interface, flagging, networking
2021-10-22 04:02:52 +08:00
2022-01-05 01:58:37 +08:00
# Set the GRADIO_ANALYTICS_ENABLED environment variable to "False" for the
# whole test process. NOTE(review): presumably gradio reads this to turn off
# telemetry by default; it is set after `import gradio`, so confirm that
# gradio reads it lazily rather than at import time.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
2021-10-22 04:02:52 +08:00
2021-10-24 16:56:54 +08:00
class TestPort(unittest.TestCase):
    """Tests for ``networking.get_first_available_port``."""

    def test_port_is_in_range(self):
        """The port that is returned lies within the requested bounds."""
        lowest, highest = 7860, 7960
        try:
            chosen = networking.get_first_available_port(lowest, highest)
        except OSError:
            # The lookup raised OSError: only warn, do not fail the test.
            warnings.warn("Unable to test, no ports available")
        else:
            self.assertTrue(lowest <= chosen <= highest)

    def test_same_port_is_returned(self):
        """Two consecutive lookups over the same range give the same port."""
        lowest, highest = 7860, 7960
        try:
            first = networking.get_first_available_port(lowest, highest)
            second = networking.get_first_available_port(lowest, highest)
        except OSError:
            # The lookup raised OSError: only warn, do not fail the test.
            warnings.warn("Unable to test, no ports available")
        else:
            self.assertEqual(first, second)
2021-10-26 00:06:36 +08:00
2021-10-26 20:57:29 +08:00
class TestInterfaceCustomParameters(unittest.TestCase):
    """Tests for launch-time parameters of ``Interface``."""

    def test_show_error(self):
        """A failing prediction returns HTTP 500 with an ``error`` key.

        The wrapped function divides by its input, so posting ``0`` makes it
        raise; the interface is launched with ``show_error=True``.
        """
        io = Interface(lambda x: 1 / x, "number", "number")
        app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
        try:
            client = TestClient(app)
            response = client.post("/api/predict/", json={"data": [0]})
            self.assertEqual(response.status_code, 500)
            # assertIn reports the actual payload when the key is missing.
            self.assertIn("error", response.json())
        finally:
            # Always shut the interface down, even when an assertion fails,
            # so a failing test does not leave a server running.
            io.close()
2022-02-28 05:47:12 +08:00
class TestStartServer(unittest.TestCase):
    """Tests for ``networking.start_server``."""

    def test_start_server(self):
        """Start a server on a free port and check the returned local URL.

        The ``Interface`` attributes are assigned by hand instead of calling
        ``launch()``; the path returned by ``start_server`` must use the
        ``http`` scheme and the port that was requested.
        """
        # Import the submodule explicitly: a bare ``import urllib`` does not
        # guarantee that ``urllib.parse`` has been loaded.
        import urllib.parse

        io = Interface(lambda x: x, "number", "number")
        # NOTE(review): these attributes are presumably what launch() would
        # normally populate before starting the server -- confirm.
        io.favicon_path = None
        io.config = io.get_config_file()
        io.show_error = True
        io.flagging_callback.setup(gr.Number(), io.flagging_dir)
        io.auth = None

        port = networking.get_first_available_port(
            networking.INITIAL_PORT_VALUE,
            networking.INITIAL_PORT_VALUE + networking.TRY_NUM_PORTS,
        )
        _, local_path, _, server = networking.start_server(io, server_port=port)
        try:
            url = urllib.parse.urlparse(local_path)
            # assertEqual replaces the deprecated assertEquals alias, which
            # was removed from unittest in Python 3.12.
            self.assertEqual(url.scheme, "http")
            self.assertEqual(url.port, port)
        finally:
            # Close the server even if an assertion above fails.
            server.close()
2021-10-26 20:57:29 +08:00
class TestFlagging(unittest.TestCase):
    """Tests for the ``/api/flag/`` endpoint."""

    def test_flagging_analytics(self):
        """Flagging calls the flagging callback once and posts via aiohttp.

        ``aiohttp.ClientSession.post`` is replaced by a mock only for the
        duration of this test, so the patch does not leak into other tests.
        """
        callback = flagging.CSVLogger()
        callback.flag = mock.MagicMock()

        post_mock = mock.MagicMock()
        # Same mock configuration as before. NOTE(review): presumably this
        # stops the mock from acting as an async context manager -- confirm.
        post_mock.__aenter__ = None
        post_mock.__aexit__ = None

        with mock.patch.object(aiohttp.ClientSession, "post", post_mock):
            io = Interface(
                lambda x: x,
                "text",
                "text",
                analytics_enabled=True,
                flagging_callback=callback,
            )
            app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
            try:
                client = TestClient(app)
                response = client.post(
                    "/api/flag/",
                    json={"data": {"input_data": ["test"], "output_data": ["test"]}},
                )
                post_mock.assert_called()
                callback.flag.assert_called_once()
                self.assertEqual(response.status_code, 200)
            finally:
                # Close while the patch is still active, and do so even when
                # an assertion fails.
                io.close()
2021-11-13 02:10:55 +08:00
2021-10-27 04:57:03 +08:00
class TestInterpretation(unittest.TestCase):
    """Tests for the ``/api/interpret/`` endpoint."""

    def test_interpretation(self):
        """Interpreting returns HTTP 200 and posts via aiohttp.

        ``io.interpret`` is stubbed to return ``(None, None)``, so the test
        exercises only the endpoint. ``aiohttp.ClientSession.post`` is patched
        after launch, as before, and restored when the test finishes.
        """
        io = Interface(
            lambda x: len(x),
            "text",
            "label",
            interpretation="default",
            analytics_enabled=True,
        )
        app, _, _ = io.launch(prevent_thread_lock=True)

        post_mock = mock.MagicMock()
        # Same mock configuration as before. NOTE(review): presumably this
        # stops the mock from acting as an async context manager -- confirm.
        post_mock.__aenter__ = None
        post_mock.__aexit__ = None

        with mock.patch.object(aiohttp.ClientSession, "post", post_mock):
            try:
                client = TestClient(app)
                io.interpret = mock.MagicMock(return_value=(None, None))
                response = client.post(
                    "/api/interpret/", json={"data": ["test test"]}
                )
                post_mock.assert_called()
                self.assertEqual(response.status_code, 200)
            finally:
                # Close while the patch is still active, and do so even when
                # an assertion fails.
                io.close()
class TestURLs(unittest.TestCase):
    """Tests for the URL and tunnel helpers in ``networking``."""

    def test_setup_tunnel(self):
        """``setup_tunnel`` returns the value produced by ``create_tunnel``.

        ``networking.create_tunnel`` is stubbed only inside the ``with``
        block, so the real function is restored afterwards instead of staying
        mocked for the rest of the test run.
        """
        stub = mock.MagicMock(return_value="test")
        with mock.patch.object(networking, "create_tunnel", stub):
            res = networking.setup_tunnel(None, None)
        self.assertEqual(res, "test")

    def test_url_ok(self):
        """``url_ok`` is truthy for the gradio website.

        NOTE(review): this presumably performs a real HTTP request and so
        needs network access -- confirm.
        """
        res = networking.url_ok("https://www.gradio.app")
        self.assertTrue(res)
2021-10-26 20:57:29 +08:00
2021-10-22 04:02:52 +08:00
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()