gradio/test/test_networking.py

192 lines
7.4 KiB
Python
Raw Normal View History

2022-01-06 05:12:58 +08:00
"""Contains tests for networking.py and app.py"""
2022-01-09 03:17:18 +08:00
import aiohttp
2022-01-05 01:58:37 +08:00
from fastapi.testclient import TestClient
import os
2021-10-22 04:02:52 +08:00
import unittest
import unittest.mock as mock
import urllib.request
2022-01-05 01:58:37 +08:00
import warnings
2021-11-13 14:33:59 +08:00
2022-01-05 01:58:37 +08:00
from gradio import flagging, Interface, networking, reset_all, utils
2021-10-22 04:02:52 +08:00
2022-01-05 01:58:37 +08:00
# Turn analytics off by default for every test in this module (presumably read
# by gradio when deciding whether to send analytics — confirm). Tests that
# exercise analytics opt back in explicitly with ``analytics_enabled=True``.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
2021-10-22 04:02:52 +08:00
2021-10-24 16:56:54 +08:00
class TestPort(unittest.TestCase):
    """Tests for ``networking.get_first_available_port``."""

    def _get_port_or_skip(self, start, end):
        """Return the first available port in the range ``start``..``end``.

        If ``get_first_available_port`` raises ``OSError`` (no port available),
        the calling test is marked as skipped instead of silently passing with
        only a warning, so the missing coverage is visible in the test report.
        """
        try:
            return networking.get_first_available_port(start, end)
        except OSError:
            self.skipTest("Unable to test, no ports available")

    def test_port_is_in_range(self):
        start, end = 7860, 7960
        port = self._get_port_or_skip(start, end)
        self.assertTrue(start <= port <= end)

    def test_same_port_is_returned(self):
        # The test never binds the returned port, so asking again for the same
        # range is expected to yield the same port.
        start, end = 7860, 7960
        port1 = self._get_port_or_skip(start, end)
        port2 = self._get_port_or_skip(start, end)
        self.assertEqual(port1, port2)
2021-10-26 00:06:36 +08:00
2022-01-05 01:58:37 +08:00
class TestRoutes(unittest.TestCase):
    """Exercises the HTTP routes of a launched identity (text -> text) Interface."""

    def setUp(self) -> None:
        self.io = Interface(lambda x: x, "text", "text")
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_get_main_route(self):
        response = self.client.get('/')
        self.assertEqual(response.status_code, 200)

    def test_get_api_route(self):
        response = self.client.get('/api/')
        self.assertEqual(response.status_code, 200)

    def test_static_files_served_safely(self):
        # Make sure things outside the static folder are not accessible:
        # URL-encoded "../" segments (%2f) must not escape the static directory.
        response = self.client.get(r'/static/..%2findex.html')
        self.assertEqual(response.status_code, 404)
        response = self.client.get(r'/static/..%2f..%2fapi_docs.html')
        self.assertEqual(response.status_code, 404)

    def test_get_config_route(self):
        response = self.client.get('/config/')
        self.assertEqual(response.status_code, 200)

    def test_predict_route(self):
        response = self.client.post('/api/predict/', json={"data": ["test"]})
        self.assertEqual(response.status_code, 200)
        output = response.json()
        # The identity function echoes its input back unchanged.
        self.assertEqual(output["data"], ["test"])
        # assertIn reports the actual keys on failure, unlike assertTrue(x in y).
        self.assertIn("durations", output)
        self.assertIn("avg_durations", output)

    # Disabled queue tests, kept for reference. The second one was previously
    # also named test_queue_push_route, which would shadow the first if both
    # were re-enabled.
    # def test_queue_push_route(self):
    #     networking.queue.push = mock.MagicMock(return_value=(None, None))
    #     response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
    #     self.assertEqual(response.status_code, 200)

    # def test_queue_status_route(self):
    #     networking.queue.get_status = mock.MagicMock(return_value=(None, None))
    #     response = self.client.post('/api/queue/status/', json={"hash": "test"})
    #     self.assertEqual(response.status_code, 200)

    def tearDown(self) -> None:
        self.io.close()
        reset_all()
2021-10-26 00:06:36 +08:00
2022-01-05 01:58:37 +08:00
class TestAuthenticatedRoutes(unittest.TestCase):
    """Checks the ``/login`` route of an Interface launched with auth credentials."""

    def setUp(self) -> None:
        self.io = Interface(lambda x: x, "text", "text")
        credentials = ("test", "correct_password")
        self.app, _, _ = self.io.launch(auth=credentials, prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_post_login(self):
        # Matching username/password: the login route answers with a 302 redirect.
        good_form = {"username": "test", "password": "correct_password"}
        response = self.client.post('/login', data=good_form)
        self.assertEqual(response.status_code, 302)

        # Wrong password: the login route rejects the request with a 400.
        bad_form = {"username": "test", "password": "incorrect_password"}
        response = self.client.post('/login', data=bad_form)
        self.assertEqual(response.status_code, 400)

    def tearDown(self) -> None:
        self.io.close()
        reset_all()
2021-10-26 00:06:36 +08:00
2021-11-13 02:10:55 +08:00
2021-10-26 20:57:29 +08:00
class TestInterfaceCustomParameters(unittest.TestCase):
    """Tests for optional ``Interface.launch`` parameters."""

    def test_show_error(self):
        # 1/x raises for x == 0; with show_error=True the predict route is
        # expected to answer 500 with an "error" key in the JSON body.
        io = Interface(lambda x: 1/x, "number", "number")
        app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
        # Registered as a cleanup so the server is closed even when an
        # assertion below fails (a trailing io.close() would be skipped).
        self.addCleanup(io.close)
        client = TestClient(app)
        response = client.post('/api/predict/', json={"data": [0]})
        self.assertEqual(response.status_code, 500)
        self.assertIn("error", response.json())
class TestFlagging(unittest.TestCase):
    """Tests for the ``/api/flag/`` route with analytics enabled."""

    def test_flagging_analytics(self):
        callback = flagging.CSVLogger()
        # Replace the logger's flag method so the test only records the call.
        callback.flag = mock.MagicMock()

        # Patch ClientSession.post only for this test and restore the real
        # method afterwards; assigning the class attribute directly leaked the
        # mock into every test that ran later.
        patcher = mock.patch.object(aiohttp.ClientSession, "post")
        mock_post = patcher.start()
        self.addCleanup(patcher.stop)
        # NOTE(review): carried over unchanged from the original setup;
        # presumably related to how the app uses ``post`` as an async context
        # manager — confirm whether still needed.
        mock_post.__aenter__ = None
        mock_post.__aexit__ = None

        io = Interface(
            lambda x: x, "text", "text",
            analytics_enabled=True, flagging_callback=callback)
        app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
        # Cleanups run LIFO: the server is closed before the patch is removed,
        # and it is closed even if an assertion below fails.
        self.addCleanup(io.close)

        client = TestClient(app)
        response = client.post(
            '/api/flag/',
            json={"data": {"input_data": ["test"], "output_data": ["test"]}})
        mock_post.assert_called()
        callback.flag.assert_called_once()
        self.assertEqual(response.status_code, 200)
2021-11-13 02:10:55 +08:00
2021-10-27 04:57:03 +08:00
class TestInterpretation(unittest.TestCase):
    """Tests for the ``/api/interpret/`` route with analytics enabled."""

    def test_interpretation(self):
        io = Interface(
            lambda x: len(x), "text", "label",
            interpretation="default", analytics_enabled=True)
        app, _, _ = io.launch(prevent_thread_lock=True)
        # Close the server even if a later step or assertion fails.
        self.addCleanup(io.close)
        client = TestClient(app)

        # Patch ClientSession.post only for the rest of this test (it is
        # restored by the cleanup); the previous direct class-attribute
        # assignment was never undone and leaked into later tests.
        patcher = mock.patch.object(aiohttp.ClientSession, "post")
        mock_post = patcher.start()
        self.addCleanup(patcher.stop)
        # NOTE(review): carried over unchanged from the original setup;
        # presumably related to async-context-manager use of ``post`` — confirm.
        mock_post.__aenter__ = None
        mock_post.__aexit__ = None

        # Stub the instance's interpret so the route is tested in isolation.
        io.interpret = mock.MagicMock(return_value=(None, None))
        response = client.post(
            '/api/interpret/', json={"data": ["test test"]})
        mock_post.assert_called()
        self.assertEqual(response.status_code, 200)
class TestURLs(unittest.TestCase):
    """Tests for networking's URL helpers."""

    def test_url_request(self):
        # Previously also named test_url_ok, so the later test_url_ok
        # definition shadowed it and this test never ran.
        # urlopen is patched only inside the ``with`` block (then restored),
        # so url_request should not reach the network — assuming it calls
        # urllib.request.urlopen, which this test relies on.
        with mock.patch.object(urllib.request, "urlopen", return_value="test"):
            res = networking.url_request("http://www.gradio.app")
        self.assertEqual(res, "test")

    def test_setup_tunnel(self):
        # Only create_tunnel is mocked (and restored afterwards); any request
        # setup_tunnel makes before calling it is not mocked here.
        with mock.patch.object(networking, "create_tunnel", return_value="test"):
            res = networking.setup_tunnel(None, None)
        self.assertEqual(res, "test")

    def test_url_ok(self):
        # Not mocked: presumably performs a real request to the URL.
        res = networking.url_ok("https://www.gradio.app")
        self.assertTrue(res)
2021-10-26 20:57:29 +08:00
2021-10-22 04:02:52 +08:00
2022-01-05 01:58:37 +08:00
# class TestQueuing(unittest.TestCase):
# def test_queueing(self):
# # mock queue methods and post method
# networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict'))
# networking.queue.pass_job = mock.MagicMock(return_value=(None, None))
# networking.queue.fail_job = mock.MagicMock(return_value=(None, None))
# networking.queue.start_job = mock.MagicMock(return_value=None)
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200))
# # execute queue action successfully
# networking.queue_thread('test_path', test_mode=True)
# networking.queue.pass_job.assert_called_once()
# # execute queue action unsuccessfully
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500))
# networking.queue_thread('test_path', test_mode=True)
# networking.queue.fail_job.assert_called_once()
# # no more things on the queue so methods shouldn't be called any more times
# networking.queue.pop = mock.MagicMock(return_value=None)
# networking.queue.pass_job.assert_called_once()
# networking.queue.fail_job.assert_called_once()
2021-11-04 05:08:19 +08:00
2021-10-22 04:02:52 +08:00
# Allow running this module directly; unittest.main() discovers and runs the
# TestCase classes defined above.
if __name__ == '__main__':
    unittest.main()