gradio/test/test_routes.py

113 lines
3.7 KiB
Python
Raw Normal View History

"""Contains tests for networking.py and app.py"""
import os
import unittest
import unittest.mock as mock
from fastapi import FastAPI
from fastapi.testclient import TestClient
Release new queue beta (#1969) * queue-refactor-backend (#1489) * queue-refactor-backend - create a template for the new design * queue-refactor-backend - clean after the old queue * queue-refactor-backend - add basic test to websocket endpoint * queue-refactor-backend - small fix * queue-refactor-backend - debugs&fixes&finalizations - test the flow with postman * queue-refactor-backend - tweaks on websocket closing * queue-refactor-backend - cleanup * queue-refactor-backend - cleanup & tweaks * queue-refactor-backend - cleanup & tweaks * queue-refactor-backend - cleanup & tweaks - correct the exception handling * queue-refactor-backend - add websockets dependency * queue-refactor-backend - reformat * queue-refactor-backend - add single event test * queue-refactor-backend - tweaks - remove outdated tests * queue-refactor-backend - reformat * queue-refactor-backend - reformat * queue-refactor-backend - reformat * queue-refactor-backend - add Queue configurations to Blocks.launch() - add live_queue_update to send estimations whenever a job gets fetched from the Queue * queue-refactor-backend - add Queue configurations to Blocks.launch() - add live_queue_update to send estimations whenever a job gets fetched from the Queue * queue-refactor-backend - tweaks * queue-refactor-backend - make SLEEP_WHEN_FREE shorter Co-authored-by: Ali Abid <aabid94@gmail.com> * Add estimation parameters to queue (#1889) * - tweaks on Estimation * version * Revert "version" This reverts commit bd1f4d7bfe3658a4967b93126859a62a511a70e2. * some fix and tweaks * implement queue frontend (#1950) * implement queue frontend * fix types * fix ws endpoint in build mode * cleanup * Queue tweaks (#1909) * tweaks on estimation payload * Queue keep ws connections open (#1910) * 1. keep ws connections open after the event process is completed 2. do not send estimations periodically if live queue updates is open * fix calculation * 1. 
tweaks on event_queue * fix issue - create new ws for each request * format * fix * fix tests * fix tests * tets * test * changes * changes * changes * change' * wtf * changes * changes * file perms * Release queue beta v1 (#1971) * - release the new queue * - bypass the issue in the tests - rewrite the lost part in the codebase * - add concurrent queue example (#1978) * rank_eta calc * Queue fixes (#1981) * change * format * - comment out queue tests as they dont work well * - reformat * Update gradio/event_queue.py Co-authored-by: Ömer Faruk Özdemir <farukozderim@gmail.com> * changes * changes * change * weird fix Co-authored-by: Ömer Faruk Özdemir <farukozderim@gmail.com> * release-queue-v3 (#1988) * Fix frontend queuing to target secure WSS (#1996) * change * format * changes * queue-concurrency-tweaks (#2002) 1. make gather_data and broadcast_estimation sequential instead of concurrent because they were deleting elements at the same time and raising expections which was lowering the performance * Update Queue API, documentation (#2026) * changes * changes * fixes * changes * change * fix Co-authored-by: Ömer Faruk Özdemir <farukozderim@gmail.com> Co-authored-by: pngwn <hello@pngwn.io>
2022-08-18 02:17:56 +08:00
from gradio import Interface, close_all, routes
# Set GRADIO_ANALYTICS_ENABLED to "False" for the whole test module.
# NOTE(review): presumably this stops Gradio from sending analytics while the
# tests run -- confirm against how gradio reads this variable.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestRoutes(unittest.TestCase):
    """Route-level checks against a launched text-doubling ``Interface``."""

    def setUp(self) -> None:
        """Launch an interface that returns ``x + x`` and wrap its app in a test client."""
        self.io = Interface(lambda x: x + x, "text", "text")
        # Only the first element of the launch result (the app) is needed here.
        self.app, _, _ = self.io.launch(prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def tearDown(self) -> None:
        """Close this test's interface, then every other interface still open."""
        self.io.close()
        close_all()

    def test_get_main_route(self):
        """The root page answers with HTTP 200."""
        reply = self.client.get("/")
        self.assertEqual(reply.status_code, 200)

    # Disabled test, kept for reference:
    # def test_get_api_route(self):
    #     response = self.client.get("/api/")
    #     self.assertEqual(response.status_code, 200)

    def test_static_files_served_safely(self):
        """Encoded ``../`` segments under ``/static`` must not escape the static folder."""
        traversal_paths = (
            r"/static/..%2findex.html",
            r"/static/..%2f..%2fapi_docs.html",
        )
        for path in traversal_paths:
            reply = self.client.get(path)
            self.assertEqual(reply.status_code, 404)

    def test_get_config_route(self):
        """The config endpoint answers with HTTP 200."""
        reply = self.client.get("/config/")
        self.assertEqual(reply.status_code, 200)

    def test_predict_route(self):
        """Posting with an explicit ``fn_index`` returns the doubled text."""
        request_body = {"data": ["test"], "fn_index": 0}
        reply = self.client.post("/api/predict/", json=request_body)
        self.assertEqual(reply.status_code, 200)
        result = dict(reply.json())
        self.assertEqual(result["data"], ["testtest"])

    def test_predict_route_without_fn_index(self):
        """Posting without ``fn_index`` still returns the doubled text."""
        request_body = {"data": ["test"]}
        reply = self.client.post("/api/predict/", json=request_body)
        self.assertEqual(reply.status_code, 200)
        result = dict(reply.json())
        self.assertEqual(result["data"], ["testtest"])

    def test_state(self):
        """Two identical requests sharing a ``session_hash`` accumulate text.

        The first response carries ``"test"`` and the second ``"testtest"``;
        the state slot of the returned data is ``None`` in both.
        """

        def predict(input, history):
            history = "" if history is None else history
            history += input
            return history, history

        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)

        request_body = {
            "data": ["test", None],
            "fn_index": 0,
            "session_hash": "_",
        }
        for expected_text in ("test", "testtest"):
            reply = client.post("/api/predict/", json=request_body)
            result = dict(reply.json())
            self.assertEqual(result["data"], [expected_text, None])
class TestApp:
    """Checks for building the app object directly, without launching."""

    def test_create_app(self):
        """``routes.App.create_app`` returns a ``FastAPI`` instance."""
        identity_interface = Interface(lambda x: x, "text", "text")
        created = routes.App.create_app(identity_interface)
        assert isinstance(created, FastAPI)
class TestAuthenticatedRoutes(unittest.TestCase):
    """Login-route checks against an interface launched with ``auth``."""

    def setUp(self) -> None:
        """Launch an identity interface protected by a single username/password pair."""
        credentials = ("test", "correct_password")
        self.io = Interface(lambda x: x, "text", "text")
        self.app, _, _ = self.io.launch(auth=credentials, prevent_thread_lock=True)
        self.client = TestClient(self.app)

    def tearDown(self) -> None:
        """Close this test's interface, then every other interface still open."""
        self.io.close()
        close_all()

    def test_post_login(self):
        """The right password yields HTTP 302; a wrong one yields HTTP 400."""
        attempts = (
            ("correct_password", 302),
            ("incorrect_password", 400),
        )
        for password, expected_status in attempts:
            form = {"username": "test", "password": password}
            reply = self.client.post("/login", data=form)
            self.assertEqual(reply.status_code, expected_status)
# Run the unittest runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()