added pydantic data models for all requests

This commit is contained in:
Abubakar Abid 2022-03-05 21:56:50 -05:00
parent 9e2cac6e4c
commit c65f9a599f
3 changed files with 50 additions and 25 deletions

View File

@ -7,6 +7,9 @@ from typing import Dict, Tuple
import requests
from gradio.routes import QueuePushRequest
DB_FILE = "gradio_queue.db"
@ -106,8 +109,9 @@ def pop() -> Tuple[int, str, Dict, str]:
return result[0], result[1], json.loads(result[2]), result[3]
def push(input_data: Dict, action: str) -> Tuple[str, int]:
input_data = json.dumps(input_data)
def push(request: QueuePushRequest) -> Tuple[str, int]:
action = request.action
input_data = json.dumps({'data': request.data})
hash = generate_hash()
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()

View File

@ -69,10 +69,34 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
class PredictRequest(BaseModel):
session_hash: Optional[str]
example_id: Optional[int]
data: Any
data: List[Any]
state: Optional[Any]
class FlagData(BaseModel):
input_data: List[Any]
output_data: List[Any]
flag_option: Optional[str]
flag_index: Optional[int]
class FlagRequest(BaseModel):
data: FlagData
class InterpretRequest(BaseModel):
data: List[Any]
class QueueStatusRequest(BaseModel):
hash: str
class QueuePushRequest(BaseModel):
action: str
data: Any
###########
# Auth
###########
@ -247,29 +271,26 @@ async def predict(body: PredictRequest, username: str = Depends(get_current_user
@app.post("/api/flag/", dependencies=[Depends(login_check)])
async def flag(request: Request, username: str = Depends(get_current_user)):
async def flag(body: FlagRequest, username: str = Depends(get_current_user)):
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
body = await request.json()
data = body["data"]
await run_in_threadpool(
app.launchable.flagging_callback.flag,
app.launchable,
data["input_data"],
data["output_data"],
flag_option=data.get("flag_option"),
flag_index=data.get("flag_index"),
body.data.input_data,
body.data.output_data,
flag_option=body.data.flag_option,
flag_index=body.data.flag_index,
username=username,
)
return {"success": True}
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(request: Request):
async def interpret(body: InterpretRequest):
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
body = await request.json()
raw_input = body["data"]
raw_input = body.data
interpretation_scores, alternative_outputs = await run_in_threadpool(
app.launchable.interpret, raw_input
)
@ -280,18 +301,14 @@ async def interpret(request: Request):
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(request: Request):
body = await request.json()
action = body["action"]
job_hash, queue_position = queueing.push(body, action)
async def queue_push(body: QueuePushRequest):
job_hash, queue_position = queueing.push(body)
return {"hash": job_hash, "queue_position": queue_position}
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(request: Request):
body = await request.json()
hash = body["hash"]
status, data = queueing.get_status(hash)
async def queue_status(body: QueueStatusRequest):
status, data = queueing.get_status(body.hash)
return {"status": status, "data": data}

View File

@ -4,6 +4,7 @@ import os
import unittest
from gradio import queueing
from gradio.routes import QueuePushRequest
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase):
queueing.close()
def test_push_pop_status(self):
hash1, position = queueing.push({"data": "test1"}, "predict")
request = QueuePushRequest(data="test1", action="predict")
hash1, position = queueing.push(request)
self.assertEquals(position, 0)
hash2, position = queueing.push({"data": "test2"}, "predict")
request = QueuePushRequest(data="test2", action="predict")
hash2, position = queueing.push(request)
self.assertEquals(position, 1)
status, position = queueing.get_status(hash2)
self.assertEquals(status, "QUEUED")
@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase):
self.assertEquals(action, "predict")
def test_jobs(self):
hash1, _ = queueing.push({"data": "test1"}, "predict")
hash2, position = queueing.push({"data": "test1"}, "predict")
request = QueuePushRequest(data="test1", action="predict")
hash1, _ = queueing.push(request)
hash2, position = queueing.push(request)
self.assertEquals(position, 1)
queueing.start_job(hash1)