mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
added pydantic data models for all requests
This commit is contained in:
parent
9e2cac6e4c
commit
c65f9a599f
@ -7,6 +7,9 @@ from typing import Dict, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
from gradio.routes import QueuePushRequest
|
||||
|
||||
|
||||
DB_FILE = "gradio_queue.db"
|
||||
|
||||
|
||||
@ -106,8 +109,9 @@ def pop() -> Tuple[int, str, Dict, str]:
|
||||
return result[0], result[1], json.loads(result[2]), result[3]
|
||||
|
||||
|
||||
def push(input_data: Dict, action: str) -> Tuple[str, int]:
|
||||
input_data = json.dumps(input_data)
|
||||
def push(request: QueuePushRequest) -> Tuple[str, int]:
|
||||
action = request.action
|
||||
input_data = json.dumps({'data': request.data})
|
||||
hash = generate_hash()
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
|
@ -69,10 +69,34 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
class PredictRequest(BaseModel):
|
||||
session_hash: Optional[str]
|
||||
example_id: Optional[int]
|
||||
data: Any
|
||||
data: List[Any]
|
||||
state: Optional[Any]
|
||||
|
||||
|
||||
class FlagData(BaseModel):
|
||||
input_data: List[Any]
|
||||
output_data: List[Any]
|
||||
flag_option: Optional[str]
|
||||
flag_index: Optional[int]
|
||||
|
||||
|
||||
class FlagRequest(BaseModel):
|
||||
data: FlagData
|
||||
|
||||
|
||||
class InterpretRequest(BaseModel):
|
||||
data: List[Any]
|
||||
|
||||
|
||||
class QueueStatusRequest(BaseModel):
|
||||
hash: str
|
||||
|
||||
|
||||
class QueuePushRequest(BaseModel):
|
||||
action: str
|
||||
data: Any
|
||||
|
||||
|
||||
###########
|
||||
# Auth
|
||||
###########
|
||||
@ -247,29 +271,26 @@ async def predict(body: PredictRequest, username: str = Depends(get_current_user
|
||||
|
||||
|
||||
@app.post("/api/flag/", dependencies=[Depends(login_check)])
|
||||
async def flag(request: Request, username: str = Depends(get_current_user)):
|
||||
async def flag(body: FlagRequest, username: str = Depends(get_current_user)):
|
||||
if app.launchable.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
|
||||
body = await request.json()
|
||||
data = body["data"]
|
||||
await run_in_threadpool(
|
||||
app.launchable.flagging_callback.flag,
|
||||
app.launchable,
|
||||
data["input_data"],
|
||||
data["output_data"],
|
||||
flag_option=data.get("flag_option"),
|
||||
flag_index=data.get("flag_index"),
|
||||
body.data.input_data,
|
||||
body.data.output_data,
|
||||
flag_option=body.data.flag_option,
|
||||
flag_index=body.data.flag_index,
|
||||
username=username,
|
||||
)
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
|
||||
async def interpret(request: Request):
|
||||
async def interpret(body: InterpretRequest):
|
||||
if app.launchable.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
|
||||
body = await request.json()
|
||||
raw_input = body["data"]
|
||||
raw_input = body.data
|
||||
interpretation_scores, alternative_outputs = await run_in_threadpool(
|
||||
app.launchable.interpret, raw_input
|
||||
)
|
||||
@ -280,18 +301,14 @@ async def interpret(request: Request):
|
||||
|
||||
|
||||
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
||||
async def queue_push(request: Request):
|
||||
body = await request.json()
|
||||
action = body["action"]
|
||||
job_hash, queue_position = queueing.push(body, action)
|
||||
async def queue_push(body: QueuePushRequest):
|
||||
job_hash, queue_position = queueing.push(body)
|
||||
return {"hash": job_hash, "queue_position": queue_position}
|
||||
|
||||
|
||||
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
||||
async def queue_status(request: Request):
|
||||
body = await request.json()
|
||||
hash = body["hash"]
|
||||
status, data = queueing.get_status(hash)
|
||||
async def queue_status(body: QueueStatusRequest):
|
||||
status, data = queueing.get_status(body.hash)
|
||||
return {"status": status, "data": data}
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@ import os
|
||||
import unittest
|
||||
|
||||
from gradio import queueing
|
||||
from gradio.routes import QueuePushRequest
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase):
|
||||
queueing.close()
|
||||
|
||||
def test_push_pop_status(self):
|
||||
hash1, position = queueing.push({"data": "test1"}, "predict")
|
||||
request = QueuePushRequest(data="test1", action="predict")
|
||||
hash1, position = queueing.push(request)
|
||||
self.assertEquals(position, 0)
|
||||
hash2, position = queueing.push({"data": "test2"}, "predict")
|
||||
request = QueuePushRequest(data="test2", action="predict")
|
||||
hash2, position = queueing.push(request)
|
||||
self.assertEquals(position, 1)
|
||||
status, position = queueing.get_status(hash2)
|
||||
self.assertEquals(status, "QUEUED")
|
||||
@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase):
|
||||
self.assertEquals(action, "predict")
|
||||
|
||||
def test_jobs(self):
|
||||
hash1, _ = queueing.push({"data": "test1"}, "predict")
|
||||
hash2, position = queueing.push({"data": "test1"}, "predict")
|
||||
request = QueuePushRequest(data="test1", action="predict")
|
||||
hash1, _ = queueing.push(request)
|
||||
hash2, position = queueing.push(request)
|
||||
self.assertEquals(position, 1)
|
||||
|
||||
queueing.start_job(hash1)
|
||||
|
Loading…
Reference in New Issue
Block a user