From 863b4287a9e51e53bafb44176c99b31e13644ded Mon Sep 17 00:00:00 2001 From: Abubakar Abid <abubakar@huggingface.co> Date: Fri, 4 Mar 2022 11:03:10 -0500 Subject: [PATCH] updated to use pydantic models --- gradio/interface.py | 15 +++++++-------- gradio/routes.py | 20 +++++++++++++++----- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 42b1821aa5..1a9c55226c 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent from gradio.outputs import State as o_State # type: ignore from gradio.outputs import get_output_instance from gradio.process_examples import load_from_cache, process_example -from gradio.routes import predict +from gradio.routes import PredictRequest if TYPE_CHECKING: # Only import for type checking (is False at runtime). import flask @@ -559,19 +559,18 @@ class Interface(Launchable): else: return predictions - def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: + def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]: flag_index = None - if data.get("example_id") is not None: - example_id = data["example_id"] + if data.example_id is not None: if self.cache_examples: - prediction = load_from_cache(self, example_id) + prediction = load_from_cache(self, data.example_id) durations = None else: - prediction, durations = process_example(self, example_id) + prediction, durations = process_example(self, data.example_id) else: - raw_input = data["data"] + raw_input = data.data if self.stateful: - state = data["state"] + state = data.state raw_input[self.state_param_index] = state prediction, durations = self.process(raw_input) if self.allow_flagging == "auto": diff --git a/gradio/routes.py b/gradio/routes.py index 880cd7a7ab..72e079675a 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse from fastapi.security import OAuth2PasswordRequestForm from fastapi.templating import Jinja2Templates from jinja2.exceptions import TemplateNotFound +from pydantic import BaseModel from starlette.responses import RedirectResponse from gradio import encryptor, queueing, utils @@ -60,6 +61,17 @@ app.state_holder = state_holder templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) +########### +# Data Models +########### + +class PredictRequest(BaseModel): + session_hash: Optional[str] + example_id: Optional[int] + data: Any + state: Optional[Any] + + ########### # Auth ########### @@ -211,15 +223,13 @@ def api_docs(request: Request): @app.post("/api/predict/", dependencies=[Depends(login_check)]) -async def predict(request: Request, username: str = Depends(get_current_user)): - body = await request.json() - +async def predict(body: PredictRequest, username: str = Depends(get_current_user)): if app.launchable.stateful: - session_hash = body.get("session_hash", None) + session_hash = body.session_hash state = app.state_holder.get( (session_hash, "state"), app.launchable.state_default ) - body["state"] = state + body.state = state try: output = await run_in_threadpool(app.launchable.process_api, body, username) if app.launchable.stateful: