mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-19 12:00:39 +08:00
updated to use pydantic models
This commit is contained in:
parent
fd034ee29f
commit
863b4287a9
@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
|
||||
from gradio.outputs import State as o_State # type: ignore
|
||||
from gradio.outputs import get_output_instance
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
from gradio.routes import predict
|
||||
from gradio.routes import PredictRequest
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import flask
|
||||
@ -559,19 +559,18 @@ class Interface(Launchable):
|
||||
else:
|
||||
return predictions
|
||||
|
||||
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
|
||||
def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]:
|
||||
flag_index = None
|
||||
if data.get("example_id") is not None:
|
||||
example_id = data["example_id"]
|
||||
if data.example_id is not None:
|
||||
if self.cache_examples:
|
||||
prediction = load_from_cache(self, example_id)
|
||||
prediction = load_from_cache(self, data.example_id)
|
||||
durations = None
|
||||
else:
|
||||
prediction, durations = process_example(self, example_id)
|
||||
prediction, durations = process_example(self, data.example_id)
|
||||
else:
|
||||
raw_input = data["data"]
|
||||
raw_input = data.data
|
||||
if self.stateful:
|
||||
state = data["state"]
|
||||
state = data.state
|
||||
raw_input[self.state_param_index] = state
|
||||
prediction, durations = self.process(raw_input)
|
||||
if self.allow_flagging == "auto":
|
||||
|
@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from jinja2.exceptions import TemplateNotFound
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from gradio import encryptor, queueing, utils
|
||||
@ -60,6 +61,17 @@ app.state_holder = state_holder
|
||||
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
|
||||
|
||||
###########
|
||||
# Data Models
|
||||
###########
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
session_hash: Optional[str]
|
||||
example_id: Optional[int]
|
||||
data: Any
|
||||
state: Optional[Any]
|
||||
|
||||
|
||||
###########
|
||||
# Auth
|
||||
###########
|
||||
@ -211,15 +223,13 @@ def api_docs(request: Request):
|
||||
|
||||
|
||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||
async def predict(request: Request, username: str = Depends(get_current_user)):
|
||||
body = await request.json()
|
||||
|
||||
async def predict(body: PredictRequest, username: str = Depends(get_current_user)):
|
||||
if app.launchable.stateful:
|
||||
session_hash = body.get("session_hash", None)
|
||||
session_hash = body.session_hash
|
||||
state = app.state_holder.get(
|
||||
(session_hash, "state"), app.launchable.state_default
|
||||
)
|
||||
body["state"] = state
|
||||
body.state = state
|
||||
try:
|
||||
output = await run_in_threadpool(app.launchable.process_api, body, username)
|
||||
if app.launchable.stateful:
|
||||
|
Loading…
x
Reference in New Issue
Block a user