updated to use pydantic models

This commit is contained in:
Abubakar Abid 2022-03-04 11:03:10 -05:00
parent fd034ee29f
commit 863b4287a9
2 changed files with 22 additions and 13 deletions

View File

@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
from gradio.outputs import State as o_State # type: ignore
from gradio.outputs import get_output_instance
from gradio.process_examples import load_from_cache, process_example
from gradio.routes import predict
from gradio.routes import PredictRequest
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import flask
@ -559,19 +559,18 @@ class Interface(Launchable):
else:
return predictions
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]:
flag_index = None
if data.get("example_id") is not None:
example_id = data["example_id"]
if data.example_id is not None:
if self.cache_examples:
prediction = load_from_cache(self, example_id)
prediction = load_from_cache(self, data.example_id)
durations = None
else:
prediction, durations = process_example(self, example_id)
prediction, durations = process_example(self, data.example_id)
else:
raw_input = data["data"]
raw_input = data.data
if self.stateful:
state = data["state"]
state = data.state
raw_input[self.state_param_index] = state
prediction, durations = self.process(raw_input)
if self.allow_flagging == "auto":

View File

@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from gradio import encryptor, queueing, utils
@ -60,6 +61,17 @@ app.state_holder = state_holder
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
# Data Models
###########
class PredictRequest(BaseModel):
session_hash: Optional[str]
example_id: Optional[int]
data: Any
state: Optional[Any]
###########
# Auth
###########
@ -211,15 +223,13 @@ def api_docs(request: Request):
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
async def predict(body: PredictRequest, username: str = Depends(get_current_user)):
if app.launchable.stateful:
session_hash = body.get("session_hash", None)
session_hash = body.session_hash
state = app.state_holder.get(
(session_hash, "state"), app.launchable.state_default
)
body["state"] = state
body.state = state
try:
output = await run_in_threadpool(app.launchable.process_api, body, username)
if app.launchable.stateful: