diff --git a/gradio/queueing.py b/gradio/queueing.py index 9b6ab16abf..d402db9290 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -7,6 +7,9 @@ from typing import Dict, Tuple import requests +from gradio.routes import QueuePushRequest + + DB_FILE = "gradio_queue.db" @@ -106,8 +109,9 @@ def pop() -> Tuple[int, str, Dict, str]: return result[0], result[1], json.loads(result[2]), result[3] -def push(input_data: Dict, action: str) -> Tuple[str, int]: - input_data = json.dumps(input_data) +def push(request: QueuePushRequest) -> Tuple[str, int]: + action = request.action + input_data = json.dumps({'data': request.data}) hash = generate_hash() conn = sqlite3.connect(DB_FILE) c = conn.cursor() diff --git a/gradio/routes.py b/gradio/routes.py index 5a101d0042..482141a09d 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -69,10 +69,34 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) class PredictRequest(BaseModel): session_hash: Optional[str] example_id: Optional[int] - data: Any + data: List[Any] state: Optional[Any] +class FlagData(BaseModel): + input_data: List[Any] + output_data: List[Any] + flag_option: Optional[str] + flag_index: Optional[int] + + +class FlagRequest(BaseModel): + data: FlagData + + +class InterpretRequest(BaseModel): + data: List[Any] + + +class QueueStatusRequest(BaseModel): + hash: str + + +class QueuePushRequest(BaseModel): + action: str + data: Any + + ########### # Auth ########### @@ -247,29 +271,26 @@ async def predict(body: PredictRequest, username: str = Depends(get_current_user @app.post("/api/flag/", dependencies=[Depends(login_check)]) -async def flag(request: Request, username: str = Depends(get_current_user)): +async def flag(body: FlagRequest, username: str = Depends(get_current_user)): if app.launchable.analytics_enabled: await utils.log_feature_analytics(app.launchable.ip_address, "flag") - body = await request.json() - data = body["data"] await run_in_threadpool( app.launchable.flagging_callback.flag, app.launchable, - data["input_data"], - data["output_data"], - flag_option=data.get("flag_option"), - flag_index=data.get("flag_index"), + body.data.input_data, + body.data.output_data, + flag_option=body.data.flag_option, + flag_index=body.data.flag_index, username=username, ) return {"success": True} @app.post("/api/interpret/", dependencies=[Depends(login_check)]) -async def interpret(request: Request): +async def interpret(body: InterpretRequest): if app.launchable.analytics_enabled: await utils.log_feature_analytics(app.launchable.ip_address, "interpret") - body = await request.json() - raw_input = body["data"] + raw_input = body.data interpretation_scores, alternative_outputs = await run_in_threadpool( app.launchable.interpret, raw_input ) @@ -280,18 +301,14 @@ async def interpret(request: Request): @app.post("/api/queue/push/", dependencies=[Depends(login_check)]) -async def queue_push(request: Request): - body = await request.json() - action = body["action"] - job_hash, queue_position = queueing.push(body, action) +async def queue_push(body: QueuePushRequest): + job_hash, queue_position = queueing.push(body) return {"hash": job_hash, "queue_position": queue_position} @app.post("/api/queue/status/", dependencies=[Depends(login_check)]) -async def queue_status(request: Request): - body = await request.json() - hash = body["hash"] - status, data = queueing.get_status(hash) +async def queue_status(body: QueueStatusRequest): + status, data = queueing.get_status(body.hash) return {"status": status, "data": data} diff --git a/test/test_queuing.py b/test/test_queuing.py index 7bbfe5d159..fcae876635 100644 --- a/test/test_queuing.py +++ b/test/test_queuing.py @@ -4,6 +4,7 @@ import os import unittest from gradio import queueing +from gradio.routes import QueuePushRequest os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase): queueing.close() def test_push_pop_status(self): - hash1, position = queueing.push({"data": "test1"}, "predict") + request = QueuePushRequest(data="test1", action="predict") + hash1, position = queueing.push(request) self.assertEquals(position, 0) - hash2, position = queueing.push({"data": "test2"}, "predict") + request = QueuePushRequest(data="test2", action="predict") + hash2, position = queueing.push(request) self.assertEquals(position, 1) status, position = queueing.get_status(hash2) self.assertEquals(status, "QUEUED") @@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase): self.assertEquals(action, "predict") def test_jobs(self): - hash1, _ = queueing.push({"data": "test1"}, "predict") - hash2, position = queueing.push({"data": "test1"}, "predict") + request = QueuePushRequest(data="test1", action="predict") + hash1, _ = queueing.push(request) + hash2, position = queueing.push(request) self.assertEquals(position, 1) queueing.start_job(hash1)