From 863b4287a9e51e53bafb44176c99b31e13644ded Mon Sep 17 00:00:00 2001
From: Abubakar Abid <abubakar@huggingface.co>
Date: Fri, 4 Mar 2022 11:03:10 -0500
Subject: [PATCH] updated to use pydantic models

---
 gradio/interface.py | 15 +++++++--------
 gradio/routes.py    | 20 +++++++++++++++-----
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/gradio/interface.py b/gradio/interface.py
index 42b1821aa5..1a9c55226c 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
 from gradio.outputs import State as o_State  # type: ignore
 from gradio.outputs import get_output_instance
 from gradio.process_examples import load_from_cache, process_example
-from gradio.routes import predict
+from gradio.routes import PredictRequest
 
 if TYPE_CHECKING:  # Only import for type checking (is False at runtime).
     import flask
@@ -559,19 +559,18 @@ class Interface(Launchable):
         else:
             return predictions
 
-    def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
+    def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]:
         flag_index = None
-        if data.get("example_id") is not None:
-            example_id = data["example_id"]
+        if data.example_id is not None:
             if self.cache_examples:
-                prediction = load_from_cache(self, example_id)
+                prediction = load_from_cache(self, data.example_id)
                 durations = None
             else:
-                prediction, durations = process_example(self, example_id)
+                prediction, durations = process_example(self, data.example_id)
         else:
-            raw_input = data["data"]
+            raw_input = data.data
             if self.stateful:
-                state = data["state"]
+                state = data.state
                 raw_input[self.state_param_index] = state
             prediction, durations = self.process(raw_input)
             if self.allow_flagging == "auto":
diff --git a/gradio/routes.py b/gradio/routes.py
index 880cd7a7ab..72e079675a 100644
--- a/gradio/routes.py
+++ b/gradio/routes.py
@@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
 from fastapi.security import OAuth2PasswordRequestForm
 from fastapi.templating import Jinja2Templates
 from jinja2.exceptions import TemplateNotFound
+from pydantic import BaseModel
 from starlette.responses import RedirectResponse
 
 from gradio import encryptor, queueing, utils
@@ -60,6 +61,17 @@ app.state_holder = state_holder
 templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
 
 
+###########
+# Data Models
+###########
+
+class PredictRequest(BaseModel):
+    session_hash: Optional[str]
+    example_id: Optional[int]
+    data: Any
+    state: Optional[Any]
+
+
 ###########
 # Auth
 ###########
@@ -211,15 +223,13 @@ def api_docs(request: Request):
 
 
 @app.post("/api/predict/", dependencies=[Depends(login_check)])
-async def predict(request: Request, username: str = Depends(get_current_user)):
-    body = await request.json()
-
+async def predict(body: PredictRequest, username: str = Depends(get_current_user)):
     if app.launchable.stateful:
-        session_hash = body.get("session_hash", None)
+        session_hash = body.session_hash
         state = app.state_holder.get(
             (session_hash, "state"), app.launchable.state_default
         )
-        body["state"] = state
+        body.state = state
     try:
         output = await run_in_threadpool(app.launchable.process_api, body, username)
         if app.launchable.stateful: