added first round of tests for networking.py

This commit is contained in:
Abubakar Abid 2021-10-26 16:36:11 -05:00
parent bdf866de88
commit 14f2e46b19

View File

@ -7,6 +7,7 @@ import requests
import warnings
import tempfile
from unittest.mock import ANY
import urllib.request
class TestUser(unittest.TestCase):
@ -90,6 +91,10 @@ class TestFlaskRoutes(unittest.TestCase):
response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
self.assertEqual(response.status_code, 200)
def test_queue_push_route(self):
networking.queue.get_status = mock.MagicMock(return_value=(None, None))
response = self.client.post('/api/queue/status/', json={"hash": "test"})
self.assertEqual(response.status_code, 200)
def tearDown(self) -> None:
self.io.close()
@ -174,18 +179,37 @@ class TestInterpretation(unittest.TestCase):
self.assertEqual(response.status_code, 200)
io.close()
class TestInterpretation(unittest.TestCase):
def test_interpretation(self):
io = gr.Interface(lambda x: len(x), "text", "label", interpretation="default")
class TestState(unittest.TestCase):
def test_state_initialization(self):
io = gr.Interface(lambda x: len(x), "text", "label")
app, _, _ = io.launch(prevent_thread_lock=True)
client = app.test_client()
io.interpret = mock.MagicMock(return_value=(None, None))
with mock.patch('requests.post') as mock_post:
response = client.post('/api/interpret/', json={"data": ["test test"]})
mock_post.assert_called_once()
self.assertEqual(response.status_code, 200)
io.close()
with app.test_request_context():
self.assertIsNone(networking.get_state())
def test_state_value(self):
io = gr.Interface(lambda x: len(x), "text", "label")
io.launch(prevent_thread_lock=True)
app, _, _ = io.launch(prevent_thread_lock=True)
with app.test_request_context():
networking.set_state("test")
client = app.test_client()
client.post('/api/predict/', json={"data": [0]})
self.assertEquals(networking.get_state(), "test")
class TestURLs(unittest.TestCase):
def test_url_ok(self):
urllib.request.urlopen = mock.MagicMock(return_value="test")
res = networking.url_request("http://www.gradio.app")
self.assertEquals(res, "test")
def test_setup_tunnel(self):
networking.create_tunnel = mock.MagicMock(return_value="test")
res = networking.setup_tunnel(None, None)
self.assertEquals(res, "test")
def test_url_ok(self):
res = networking.url_ok("https://www.gradio.app")
self.assertTrue(res)
if __name__ == '__main__':