gradio/test/test_routes.py
Abubakar Abid 871c9713b4
Restored /api/predict/ endpoint for Interfaces (#1199)
* updated PyPi version to 2.9b25

* added /api/predict reverse compatibility

* fixed flagging

* formatting

* fixed networking tests

* added queue false
2022-05-09 18:05:30 -07:00

120 lines
4.1 KiB
Python

"""Contains tests for networking.py and app.py"""
import os
import unittest
import unittest.mock as mock
from fastapi.testclient import TestClient
from gradio import Interface, close_all, queueing
# Presumably opts these tests out of gradio's analytics reporting.
# NOTE(review): this runs after `gradio` is imported above, so it only takes
# effect if gradio reads GRADIO_ANALYTICS_ENABLED lazily (e.g. when an
# Interface is constructed) rather than at import time — confirm.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")
class TestRoutes(unittest.TestCase):
    """Exercise the HTTP routes served by a launched ``Interface`` app.

    Each test gets a fresh text -> text Interface whose function doubles its
    input (``"test"`` -> ``"testtest"``), launched with
    ``prevent_thread_lock=True`` and wrapped in a FastAPI ``TestClient``.
    """

    def setUp(self) -> None:
        self.io = Interface(lambda x: x + x, "text", "text")
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def test_get_main_route(self):
        response = self.client.get("/")
        self.assertEqual(response.status_code, 200)

    # NOTE(review): disabled check for the "/api/" route, kept for reference.
    # Re-enable (or delete) once it is decided whether that route is served.
    # def test_get_api_route(self):
    #     response = self.client.get("/api/")
    #     self.assertEqual(response.status_code, 200)

    def test_static_files_served_safely(self):
        # Make sure things outside the static folder are not accessible:
        # URL-encoded "../" segments ("%2f" is "/") must not escape /static/.
        response = self.client.get(r"/static/..%2findex.html")
        self.assertEqual(response.status_code, 404)
        response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
        self.assertEqual(response.status_code, 404)

    def test_get_config_route(self):
        response = self.client.get("/config/")
        self.assertEqual(response.status_code, 200)

    def test_predict_route(self):
        response = self.client.post(
            "/api/predict/", json={"data": ["test"], "fn_index": 0}
        )
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest"])

    def test_predict_route_without_fn_index(self):
        # /api/predict/ must also accept a payload that omits "fn_index".
        response = self.client.post("/api/predict/", json={"data": ["test"]})
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest"])

    def test_state(self):
        """Two calls sharing a session_hash see the accumulated state."""

        def predict(text, history):
            # `history` arrives as None on the first call of a session.
            if history is None:
                history = ""
            history += text
            return history, history

        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)
        payload = {"data": ["test", None], "fn_index": 0, "session_hash": "_"}

        response = client.post("/api/predict/", json=payload)
        # Check the status first so a server error fails with a clear message
        # instead of a KeyError on the missing "data" field.
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["test", None])

        # Same session_hash: the stored history from the first call is reused.
        response = client.post("/api/predict/", json=payload)
        self.assertEqual(response.status_code, 200)
        output = dict(response.json())
        self.assertEqual(output["data"], ["testtest", None])

    def test_queue_push_route(self):
        # Patch only for the duration of the request. Assigning a MagicMock
        # directly to queueing.push would leak the mock into every later test.
        with mock.patch.object(queueing, "push", return_value=(None, None)):
            response = self.client.post(
                "/api/queue/push/",
                json={
                    "data": "test",
                    "action": "test",
                    "fn_index": 0,
                    "session_hash": "-",
                },
            )
        self.assertEqual(response.status_code, 200)

    def test_queue_push_route_2(self):
        """Queue *status* route responds 200 (queueing.get_status is mocked)."""
        # Scoped patch so the real get_status is restored afterwards.
        with mock.patch.object(
            queueing, "get_status", return_value=(None, None)
        ):
            response = self.client.post(
                "/api/queue/status/", json={"hash": "test"}
            )
        self.assertEqual(response.status_code, 200)

    def tearDown(self) -> None:
        self.io.close()
        close_all()
class TestAuthenticatedRoutes(unittest.TestCase):
    """Login behaviour of an ``Interface`` launched with username/password auth.

    The app is launched with the single credential pair
    ``("test", "correct_password")``.
    """

    def setUp(self) -> None:
        self.io = Interface(lambda x: x, "text", "text")
        credentials = ("test", "correct_password")
        self.app, _, _ = self.io.launch(
            auth=credentials, prevent_thread_lock=True
        )
        self.client = TestClient(self.app)

    def _login(self, password):
        """POST the login form for user ``test`` with *password*; return the response."""
        form = {"username": "test", "password": password}
        return self.client.post("/login", data=form)

    def test_post_login(self):
        # Right password -> 302 (redirect); wrong password -> 400.
        accepted = self._login("correct_password")
        self.assertEqual(accepted.status_code, 302)
        rejected = self._login("incorrect_password")
        self.assertEqual(rejected.status_code, 400)

    def tearDown(self) -> None:
        self.io.close()
        close_all()
if __name__ == "__main__":
    # Allow running this module directly as a script; discovers and runs
    # the TestCase classes defined in this module.
    unittest.main(module="__main__")