mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
renamed app.py to routes.py, separated tests
This commit is contained in:
parent
6debebecd9
commit
8ccd82187c
@ -129,15 +129,17 @@ def static_resource(path: str):
|
||||
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
def file(path):
|
||||
if app.interface.encrypt and isinstance(
|
||||
app.interface.examples, str) and path.startswith(
|
||||
app.interface.examples):
|
||||
if (
|
||||
app.interface.encrypt
|
||||
and isinstance(app.interface.examples, str)
|
||||
and path.startswith(app.interface.examples)
|
||||
):
|
||||
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
|
||||
encrypted_data = encrypted_file.read()
|
||||
file_data = encryptor.decrypt(
|
||||
app.interface.encryption_key, encrypted_data)
|
||||
file_data = encryptor.decrypt(app.interface.encryption_key, encrypted_data)
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path))
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
|
||||
)
|
||||
else:
|
||||
return FileResponse(safe_join(app.cwd, path))
|
||||
|
@ -35,77 +35,6 @@ class TestPort(unittest.TestCase):
|
||||
warnings.warn("Unable to test, no ports available")
|
||||
|
||||
|
||||
class TestRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_get_main_route(self):
|
||||
response = self.client.get("/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_get_api_route(self):
|
||||
response = self.client.get("/api/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_static_files_served_safely(self):
|
||||
# Make sure things outside the static folder are not accessible
|
||||
response = self.client.get(r"/static/..%2findex.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_get_config_route(self):
|
||||
response = self.client.get("/config/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_predict_route(self):
|
||||
response = self.client.post("/api/predict/", json={"data": ["test"]})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["test"])
|
||||
self.assertTrue("durations" in output)
|
||||
self.assertTrue("avg_durations" in output)
|
||||
|
||||
def test_queue_push_route(self):
|
||||
queueing.push = mock.MagicMock(return_value=(None, None))
|
||||
response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_queue_push_route(self):
|
||||
queueing.get_status = mock.MagicMock(return_value=(None, None))
|
||||
response = self.client.post('/api/queue/status/', json={"hash": "test"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(
|
||||
auth=("test", "correct_password"), prevent_thread_lock=True
|
||||
)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_post_login(self):
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="correct_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="incorrect_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
def test_show_error(self):
|
||||
io = Interface(lambda x: 1 / x, "number", "number")
|
||||
|
88
test/test_routes.py
Normal file
88
test/test_routes.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""Contains tests for networking.py and app.py"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from gradio import Interface, flagging, networking, queueing, reset_all
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_get_main_route(self):
|
||||
response = self.client.get("/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_get_api_route(self):
|
||||
response = self.client.get("/api/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_static_files_served_safely(self):
|
||||
# Make sure things outside the static folder are not accessible
|
||||
response = self.client.get(r"/static/..%2findex.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_get_config_route(self):
|
||||
response = self.client.get("/config/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_predict_route(self):
|
||||
response = self.client.post("/api/predict/", json={"data": ["test"]})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["test"])
|
||||
self.assertTrue("durations" in output)
|
||||
self.assertTrue("avg_durations" in output)
|
||||
|
||||
def test_queue_push_route(self):
|
||||
queueing.push = mock.MagicMock(return_value=(None, None))
|
||||
response = self.client.post(
|
||||
"/api/queue/push/", json={"data": "test", "action": "test"}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_queue_push_route(self):
|
||||
queueing.get_status = mock.MagicMock(return_value=(None, None))
|
||||
response = self.client.post("/api/queue/status/", json={"hash": "test"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(
|
||||
auth=("test", "correct_password"), prevent_thread_lock=True
|
||||
)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_post_login(self):
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="correct_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="incorrect_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user