diff --git a/test/test_networking.py b/test/test_networking.py index e035a1ed22..261458686f 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -7,6 +7,7 @@ import requests import warnings import tempfile from unittest.mock import ANY +import urllib.request class TestUser(unittest.TestCase): @@ -90,6 +91,10 @@ class TestFlaskRoutes(unittest.TestCase): response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"}) self.assertEqual(response.status_code, 200) + def test_queue_push_route(self): + networking.queue.get_status = mock.MagicMock(return_value=(None, None)) + response = self.client.post('/api/queue/status/', json={"hash": "test"}) + self.assertEqual(response.status_code, 200) def tearDown(self) -> None: self.io.close() @@ -174,18 +179,37 @@ class TestInterpretation(unittest.TestCase): self.assertEqual(response.status_code, 200) io.close() -class TestInterpretation(unittest.TestCase): - def test_interpretation(self): - io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default") +class TestState(unittest.TestCase): + def test_state_initialization(self): + io = gr.Interface(lambda x: len(x), "text", "label") app, _, _ = io.launch(prevent_thread_lock=True) - client = app.test_client() - io.interpret = mock.MagicMock(return_value=(None, None)) - with mock.patch('requests.post') as mock_post: - response = client.post('/api/interpret/', json={"data": ["test test"]}) - mock_post.assert_called_once() - self.assertEqual(response.status_code, 200) - io.close() + with app.test_request_context(): + self.assertIsNone(networking.get_state()) + def test_state_value(self): + io = gr.Interface(lambda x: len(x), "text", "label") + io.launch(prevent_thread_lock=True) + app, _, _ = io.launch(prevent_thread_lock=True) + with app.test_request_context(): + networking.set_state("test") + client = app.test_client() + client.post('/api/predict/', json={"data": [0]}) + self.assertEquals(networking.get_state(), "test") + +class TestURLs(unittest.TestCase): + def test_url_ok(self): + urllib.request.urlopen = mock.MagicMock(return_value="test") + res = networking.url_request("http://www.gradio.app") + self.assertEquals(res, "test") + + def test_setup_tunnel(self): + networking.create_tunnel = mock.MagicMock(return_value="test") + res = networking.setup_tunnel(None, None) + self.assertEquals(res, "test") + + def test_url_ok(self): + res = networking.url_ok("https://www.gradio.app") + self.assertTrue(res) if __name__ == '__main__':