gradio/test/test_routes.py

120 lines
4.1 KiB
Python
Raw Normal View History

"""Contains tests for networking.py and app.py"""
import os
import unittest
import unittest.mock as mock
from fastapi.testclient import TestClient
from gradio import Interface, close_all, queueing
# Set GRADIO_ANALYTICS_ENABLED to "False" for this test run so analytics are
# not reported (presumably read by gradio when an Interface is created or
# launched -- not visible in this file; verify if the import order matters).
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
class TestRoutes(unittest.TestCase):
    """HTTP route tests against an app launched from a text-doubling Interface."""

    def setUp(self) -> None:
        # The wrapped function doubles its input, so "test" -> "testtest".
        self.io = Interface(lambda x: x + x, "text", "text")
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_get_main_route(self):
        response = self.client.get("/")
        self.assertEqual(response.status_code, 200)

    def test_static_files_served_safely(self):
        # Make sure things outside the static folder are not accessible via
        # URL-encoded "../" path segments.
        response = self.client.get(r"/static/..%2findex.html")
        self.assertEqual(response.status_code, 404)
        response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
        self.assertEqual(response.status_code, 404)

    def test_get_config_route(self):
        response = self.client.get("/config/")
        self.assertEqual(response.status_code, 200)

    def test_predict_route(self):
        response = self.client.post(
            "/api/predict/", json={"data": ["test"], "fn_index": 0}
        )
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest"])

    def test_predict_route_without_fn_index(self):
        # Same request as test_predict_route but with "fn_index" omitted;
        # the expected output is the same.
        response = self.client.post("/api/predict/", json={"data": ["test"]})
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest"])

    def test_state(self):
        """State persists across calls that share a session_hash.

        The client sends None for the state input on both calls, yet the
        second call sees the history from the first. The state output slot
        comes back as None in the response.
        """

        def predict(text, history):
            if history is None:
                history = ""
            history += text
            return history, history

        # This Interface is not self.io; it is presumably shut down by
        # close_all() in tearDown (verify against gradio's close_all).
        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)
        payload = {"data": ["test", None], "fn_index": 0, "session_hash": "_"}

        response = client.post("/api/predict/", json=payload)
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["test", None])

        response = client.post("/api/predict/", json=payload)
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest", None])

    def test_queue_push_route(self):
        # Patch (rather than assign) so the real queueing.push is restored
        # when the block exits and the MagicMock cannot leak into other tests.
        with mock.patch.object(queueing, "push", return_value=(None, None)):
            response = self.client.post(
                "/api/queue/push/",
                json={
                    "data": "test",
                    "action": "test",
                    "fn_index": 0,
                    "session_hash": "-",
                },
            )
        self.assertEqual(response.status_code, 200)

    def test_queue_push_route_2(self):
        # Despite the name, this exercises the queue *status* route.
        # Scoped patch for the same reason as in test_queue_push_route.
        with mock.patch.object(
            queueing, "get_status", return_value=(None, None)
        ):
            response = self.client.post(
                "/api/queue/status/", json={"hash": "test"}
            )
        self.assertEqual(response.status_code, 200)

    def tearDown(self) -> None:
        self.io.close()
        # Catch-all shutdown; presumably also closes Interfaces created inside
        # individual tests (e.g. test_state's) -- verify in gradio.
        close_all()
class TestAuthenticatedRoutes(unittest.TestCase):
    """Login-route tests for an app launched with one username/password pair."""

    def setUp(self) -> None:
        # Identity Interface; only the auth configuration matters here.
        self.io = Interface(lambda x: x, "text", "text")
        self.app, _, _ = self.io.launch(
            auth=("test", "correct_password"),
            prevent_thread_lock=True,
        )
        self.client = TestClient(self.app)

    def _attempt_login(self, password):
        """POST the login form for user "test" with *password*; return the response."""
        form = {"username": "test", "password": password}
        return self.client.post("/login", data=form)

    def test_post_login(self):
        # The correct password yields status 302; a wrong one yields 400.
        accepted = self._attempt_login("correct_password")
        self.assertEqual(accepted.status_code, 302)

        rejected = self._attempt_login("incorrect_password")
        self.assertEqual(rejected.status_code, 400)

    def tearDown(self) -> None:
        # Close this test's Interface, then call close_all() as a catch-all.
        self.io.close()
        close_all()
if __name__ == "__main__":
    # Lets the file be run directly; when imported (e.g. by a test runner)
    # __name__ is not "__main__" and this branch is skipped.
    unittest.main(module="__main__")