gradio/test/test_networking.py

247 lines
10 KiB
Python
Raw Normal View History

2021-10-22 04:02:52 +08:00
from gradio import networking
2021-10-24 16:56:54 +08:00
import gradio as gr
2021-10-22 04:02:52 +08:00
import unittest
import unittest.mock as mock
2021-10-24 16:56:54 +08:00
import ipaddress
import requests
import warnings
import tempfile
from unittest.mock import ANY
import urllib.request
import os
# Set the analytics switch in the process environment for this test module.
# NOTE(review): presumably this must happen before gradio reads the variable;
# gradio is already imported above, so confirm the value is read lazily.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "True" # Enables analytics
2021-10-22 04:02:52 +08:00
2021-10-24 16:56:54 +08:00
class TestUser(unittest.TestCase):
    """Checks that user objects report the id they were created with."""

    def test_id(self):
        """A ``networking.User`` built from an id returns it from ``get_id``."""
        self.assertEqual(networking.User("test").get_id(), "test")

    def test_load_user(self):
        """``networking.load_user`` yields an object whose id is the argument."""
        loaded = networking.load_user("test")
        self.assertEqual(loaded.get_id(), "test")
class TestIPAddress(unittest.TestCase):
    """Checks for ``networking.get_local_ip_address``."""

    def test_get_ip(self):
        """The returned value must parse as an IP address."""
        address = networking.get_local_ip_address()
        try:
            # ip_address() raises ValueError for anything that is not an IP.
            ipaddress.ip_address(address)
        except ValueError:
            self.fail("Invalid IP address")

    @mock.patch("requests.get")
    def test_get_ip_without_internet(self, mock_get):
        """A connection error from ``requests.get`` yields the fallback text."""
        mock_get.side_effect = requests.ConnectionError()
        self.assertEqual(
            networking.get_local_ip_address(), "No internet connection"
        )
class TestPort(unittest.TestCase):
    """Checks for ``networking.get_first_available_port``."""

    def test_port_is_in_range(self):
        """The chosen port lies inside the requested inclusive range."""
        low, high = 7860, 7960
        try:
            chosen = networking.get_first_available_port(low, high)
        except OSError:
            # Raised when no port could be obtained; nothing to verify then.
            warnings.warn("Unable to test, no ports available")
        else:
            self.assertTrue(low <= chosen <= high)

    def test_same_port_is_returned(self):
        """Two consecutive lookups over the same range return the same port."""
        low, high = 7860, 7960
        try:
            first = networking.get_first_available_port(low, high)
            second = networking.get_first_available_port(low, high)
        except OSError:
            # Raised when no port could be obtained; nothing to verify then.
            warnings.warn("Unable to test, no ports available")
        else:
            self.assertEqual(first, second)
2021-10-26 00:06:36 +08:00
2021-10-24 16:56:54 +08:00
class TestFlaskRoutes(unittest.TestCase):
def setUp(self) -> None:
self.io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
2021-10-24 16:56:54 +08:00
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
self.client = self.app.test_client()
2021-10-26 00:06:36 +08:00
def test_get_main_route(self):
2021-10-24 16:56:54 +08:00
response = self.client.get('/')
self.assertEqual(response.status_code, 200)
2021-10-26 00:06:36 +08:00
2021-11-04 05:08:19 +08:00
def test_get_api_route(self):
response = self.client.get('/api/')
self.assertEqual(response.status_code, 200)
2021-10-26 00:06:36 +08:00
def test_get_config_route(self):
response = self.client.get('/config/')
2021-10-24 16:56:54 +08:00
self.assertEqual(response.status_code, 200)
2021-10-26 20:57:29 +08:00
def test_enable_sharing_route(self):
path = "www.gradio.app"
response = self.client.get('/enable_sharing/www.gradio.app')
self.assertEqual(response.status_code, 200)
self.assertEqual(self.io.config["share_url"], path)
def test_predict_route(self):
response = self.client.post('/api/predict/', json={"data": ["test"]})
self.assertEqual(response.status_code, 200)
output = dict(response.get_json())
self.assertEqual(output["data"], ["test"])
self.assertTrue("durations" in output)
self.assertTrue("avg_durations" in output)
2021-10-27 04:57:03 +08:00
def test_queue_push_route(self):
networking.queue.push = mock.MagicMock(return_value=(None, None))
response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
self.assertEqual(response.status_code, 200)
def test_queue_push_route(self):
networking.queue.get_status = mock.MagicMock(return_value=(None, None))
response = self.client.post('/api/queue/status/', json={"hash": "test"})
self.assertEqual(response.status_code, 200)
2021-10-27 04:57:03 +08:00
2021-10-24 16:56:54 +08:00
def tearDown(self) -> None:
self.io.close()
2021-10-26 00:06:36 +08:00
gr.reset_all()
class TestAuthenticatedFlaskRoutes(unittest.TestCase):
    """Login-route checks for an Interface launched with ``auth`` credentials."""

    def setUp(self) -> None:
        """Launch an identity Interface protected by a username and password."""
        self.io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
        launched = self.io.launch(
            auth=("test", "correct_password"), prevent_thread_lock=True
        )
        self.app, _, _ = launched
        self.client = self.app.test_client()

    def test_get_login_route(self):
        """GET /login responds with 200."""
        self.assertEqual(self.client.get('/login').status_code, 200)

    def test_post_login(self):
        """The right password answers 302; a wrong one answers 401."""
        # Order matters: the successful login is attempted first.
        attempts = (
            ("correct_password", 302),
            ("incorrect_password", 401),
        )
        for password, expected_status in attempts:
            reply = self.client.post(
                '/login', data={"username": "test", "password": password}
            )
            self.assertEqual(reply.status_code, expected_status)

    def tearDown(self) -> None:
        """Close the Interface and reset gradio's global state."""
        self.io.close()
        gr.reset_all()
2021-10-26 20:57:29 +08:00
class TestInterfaceCustomParameters(unittest.TestCase):
def test_show_error(self):
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=False)
2021-10-26 20:57:29 +08:00
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = app.test_client()
response = client.post('/api/predict/', json={"data": [0]})
self.assertEqual(response.status_code, 500)
self.assertTrue("error" in response.get_json())
io.close()
def test_feature_logging(self):
with mock.patch('requests.post') as mock_post:
io = gr.Interface(lambda x: 1/x, "number", "number")
io.launch(show_error=True, prevent_thread_lock=True)
networking.log_feature_analytics("test_feature")
mock_post.assert_called_with(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
io.close()
2021-10-26 20:57:29 +08:00
io = gr.Interface(lambda x: 1/x, "number", "number", analytics_enabled=False)
print(io.analytics_enabled)
2021-10-26 20:57:29 +08:00
io.launch(show_error=True, prevent_thread_lock=True)
with mock.patch('requests.post') as mock_post:
networking.log_feature_analytics("test_feature")
mock_post.assert_not_called()
io.close()
class TestFlagging(unittest.TestCase):
def test_num_rows_written(self):
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
io.launch(prevent_thread_lock=True)
with tempfile.TemporaryDirectory() as tmpdirname:
row_count = networking.flag_data(["test"], ["test"], flag_path=tmpdirname)
self.assertEquals(row_count, 1) # 2 rows written including header
row_count = networking.flag_data("test", "test", flag_path=tmpdirname)
self.assertEquals(row_count, 2) # 3 rows written including header
io.close()
@mock.patch("requests.post")
@mock.patch("gradio.networking.flag_data")
def test_flagging_analytics(self, mock_flag, mock_post):
io = gr.Interface(lambda x: x, "text", "text")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = app.test_client()
response = client.post('/api/flag/', json={"data": {"input_data": ["test"], "output_data": ["test"]}})
mock_post.assert_any_call(networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY)
mock_flag.assert_called_once()
self.assertEqual(response.status_code, 200)
io.close()
@mock.patch("requests.post")
class TestInterpretation(unittest.TestCase):
    """Checks the ``/api/interpret/`` route with ``requests.post`` patched."""

    def test_interpretation(self, mock_post):
        """The interpret route answers 200 and reports feature analytics."""
        interface = gr.Interface(
            lambda x: len(x), "text", "label", interpretation="default"
        )
        flask_app, _, _ = interface.launch(prevent_thread_lock=True)
        test_client = flask_app.test_client()
        # Stub the interpretation itself; only the route handling is exercised.
        interface.interpret = mock.MagicMock(return_value=(None, None))
        reply = test_client.post('/api/interpret/', json={"data": ["test test"]})
        mock_post.assert_any_call(
            networking.GRADIO_FEATURE_ANALYTICS_URL, data=ANY, timeout=ANY
        )
        self.assertEqual(reply.status_code, 200)
        interface.close()
class TestState(unittest.TestCase):
def test_state_initialization(self):
io = gr.Interface(lambda x: len(x), "text", "label", analytics_enabled=False)
2021-10-27 04:57:03 +08:00
app, _, _ = io.launch(prevent_thread_lock=True)
with app.test_request_context():
self.assertIsNone(networking.get_state())
def test_state_value(self):
io = gr.Interface(lambda x: len(x), "text", "label", analytics_enabled=False)
io.launch(prevent_thread_lock=True)
app, _, _ = io.launch(prevent_thread_lock=True)
with app.test_request_context():
networking.set_state("test")
client = app.test_client()
client.post('/api/predict/', json={"data": [0]})
self.assertEquals(networking.get_state(), "test")
class TestURLs(unittest.TestCase):
def test_url_ok(self):
urllib.request.urlopen = mock.MagicMock(return_value="test")
res = networking.url_request("http://www.gradio.app")
self.assertEquals(res, "test")
def test_setup_tunnel(self):
networking.create_tunnel = mock.MagicMock(return_value="test")
res = networking.setup_tunnel(None, None)
self.assertEquals(res, "test")
def test_url_ok(self):
res = networking.url_ok("https://www.gradio.app")
self.assertTrue(res)
2021-10-26 20:57:29 +08:00
2021-10-22 04:02:52 +08:00
2021-11-04 05:08:19 +08:00
class TestQueuing(unittest.TestCase):
def test_queueing(self):
io = gr.Interface(lambda x: x, "text", "text", analytics_enabled=False)
2021-11-04 05:08:19 +08:00
app, _, _ = io.launch(prevent_thread_lock=True)
client = app.test_client()
# mock queue methods and post method
networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict'))
networking.queue.pass_job = mock.MagicMock(return_value=(None, None))
networking.queue.fail_job = mock.MagicMock(return_value=(None, None))
networking.queue.start_job = mock.MagicMock(return_value=None)
requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200))
# execute queue action successfully
networking.queue_thread('test_path', test_mode=True)
networking.queue.pass_job.assert_called_once()
# execute queue action unsuccessfully
requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500))
networking.queue_thread('test_path', test_mode=True)
networking.queue.fail_job.assert_called_once()
# no more things on the queue so methods shouldn't be called any more times
networking.queue.pop = mock.MagicMock(return_value=None)
networking.queue.pass_job.assert_called_once()
networking.queue.fail_job.assert_called_once()
2021-10-22 04:02:52 +08:00
if __name__ == "__main__":
    # Run every test case in this module when the file is executed directly.
    unittest.main()