diff --git a/gradio/blocks.py b/gradio/blocks.py
index 53b1745f0e..ad87aacd09 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -73,6 +73,7 @@ class Blocks(Launchable, BlockContext):
         self.theme = theme
         self.requires_permissions = False  # TODO: needs to be implemented
         self.enable_queue = False
+        self.stateful = False  # TODO: implement state
         super().__init__()
 
         Context.root_block = self
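For orientation, here is a minimal sketch of the end-user pattern this diff is wiring up; it mirrors the `test_state` case added to `test/test_routes.py` below, and the echo-style `predict` function and component choices are illustrative only:

```python
import gradio as gr

# Illustrative stateful demo: the "state" input receives the previous
# session state, and the "state" output returns the value to store for
# the next request in the same session.
def predict(input, history=""):
    history += input
    return history, history

io = gr.Interface(predict, ["textbox", "state"], ["textbox", "state"])
io.launch()
```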
diff --git a/gradio/interface.py b/gradio/interface.py
index aa5a8a5732..8891f52ee7 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
 from gradio.outputs import State as o_State  # type: ignore
 from gradio.outputs import get_output_instance
 from gradio.process_examples import load_from_cache, process_example
-from gradio.routes import predict
+from gradio.routes import PredictBody
 
 if TYPE_CHECKING:  # Only import for type checking (is False at runtime).
     import flask
@@ -176,6 +176,7 @@ class Interface(Launchable):
         if repeat_outputs_per_model:
             self.output_components *= len(fn)
 
+        self.stateful = False
         if sum(isinstance(i, i_State) for i in self.input_components) > 1:
             raise ValueError("Only one input component can be State.")
         if sum(isinstance(o, o_State) for o in self.output_components) > 1:
@@ -187,10 +188,24 @@ class Interface(Launchable):
             state_param_index = [
                 isinstance(i, i_State) for i in self.input_components
             ].index(True)
+            self.stateful = True
+            self.state_param_index = state_param_index
             state: i_State = self.input_components[state_param_index]
             if state.default is None:
                 default = utils.get_default_args(fn[0])[state_param_index]
                 state.default = default
+            self.state_default = state.default
+
+            if sum(isinstance(i, o_State) for i in self.output_components) == 1:
+                state_return_index = [
+                    isinstance(i, o_State) for i in self.output_components
+                ].index(True)
+                self.state_return_index = state_return_index
+            else:
+                raise ValueError(
+                    "For a stateful interface, there must be exactly one State"
+                    " input component and one State output component."
+                )
 
         if (
             interpretation is None
@@ -543,17 +558,19 @@ class Interface(Launchable):
         else:
             return predictions
 
-    def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
+    def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]:
         flag_index = None
-        if data.get("example_id") is not None:
-            example_id = data["example_id"]
+        if data.example_id is not None:
             if self.cache_examples:
-                prediction = load_from_cache(self, example_id)
+                prediction = load_from_cache(self, data.example_id)
                 durations = None
             else:
-                prediction, durations = process_example(self, example_id)
+                prediction, durations = process_example(self, data.example_id)
         else:
-            raw_input = data["data"]
+            raw_input = data.data
+            if self.stateful:
+                state = data.state
+                raw_input[self.state_param_index] = state
             prediction, durations = self.process(raw_input)
         if self.allow_flagging == "auto":
             flag_index = self.flagging_callback.flag(
@@ -563,12 +580,18 @@ class Interface(Launchable):
                 flag_option="" if self.flagging_options else None,
                 username=username,
             )
+        if self.stateful:
+            updated_state = prediction[self.state_return_index]
+            prediction[self.state_return_index] = None
+        else:
+            updated_state = None
 
         return {
             "data": prediction,
             "durations": durations,
             "avg_durations": self.config.get("avg_durations"),
             "flag_index": flag_index,
+            "updated_state": updated_state,
         }
 
     def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
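Roughly, the state handling added to `Interface.process_api` above boils down to the following simplified sketch (not the exact implementation; the attribute names are taken from the diff, and error handling, examples, and flagging are omitted):

```python
# Simplified sketch of the stateful branch in Interface.process_api:
# inject the session's state into the inputs, run the prediction, then
# strip the returned state from the payload and report it separately.
def process_api_sketch(interface, raw_input, state):
    if interface.stateful:
        raw_input[interface.state_param_index] = state
    prediction, durations = interface.process(raw_input)
    if interface.stateful:
        updated_state = prediction[interface.state_return_index]
        prediction[interface.state_return_index] = None  # state never goes to the client
    else:
        updated_state = None
    return {"data": prediction, "durations": durations, "updated_state": updated_state}
```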
diff --git a/gradio/queueing.py b/gradio/queueing.py
index 9b6ab16abf..7d3bbf89c0 100644
--- a/gradio/queueing.py
+++ b/gradio/queueing.py
@@ -7,6 +7,8 @@ from typing import Dict, Tuple
 
 import requests
 
+from gradio.routes import QueuePushBody
+
 DB_FILE = "gradio_queue.db"
 
 
@@ -106,8 +108,9 @@ def pop() -> Tuple[int, str, Dict, str]:
     return result[0], result[1], json.loads(result[2]), result[3]
 
 
-def push(input_data: Dict, action: str) -> Tuple[str, int]:
-    input_data = json.dumps(input_data)
+def push(body: QueuePushBody) -> Tuple[str, int]:
+    action = body.action
+    input_data = json.dumps({"data": body.data})
     hash = generate_hash()
     conn = sqlite3.connect(DB_FILE)
     c = conn.cursor()
diff --git a/gradio/routes.py b/gradio/routes.py
index 817f0baa86..159ecfa726 100644
--- a/gradio/routes.py
+++ b/gradio/routes.py
@@ -9,7 +9,7 @@ import posixpath
 import secrets
 import traceback
 import urllib
-from typing import Any, List, Optional, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import orjson
 import pkg_resources
@@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
 from fastapi.security import OAuth2PasswordRequestForm
 from fastapi.templating import Jinja2Templates
 from jinja2.exceptions import TemplateNotFound
+from pydantic import BaseModel
 from starlette.responses import RedirectResponse
 
 from gradio import encryptor, queueing, utils
@@ -54,9 +55,48 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+state_holder: Dict[Tuple[str, str], Any] = {}
+app.state_holder = state_holder
+
 templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
 
+###########
+# Data Models
+###########
+
+
+class PredictBody(BaseModel):
+    session_hash: Optional[str]
+    example_id: Optional[int]
+    data: List[Any]
+    state: Optional[Any]
+
+
+class FlagData(BaseModel):
+    input_data: List[Any]
+    output_data: List[Any]
+    flag_option: Optional[str]
+    flag_index: Optional[int]
+
+
+class FlagBody(BaseModel):
+    data: FlagData
+
+
+class InterpretBody(BaseModel):
+    data: List[Any]
+
+
+class QueueStatusBody(BaseModel):
+    hash: str
+
+
+class QueuePushBody(BaseModel):
+    action: str
+    data: Any
+
+
 
 ###########
 # Auth
 ###########
@@ -208,10 +248,19 @@ def api_docs(request: Request):
 
 
 @app.post("/api/predict/", dependencies=[Depends(login_check)])
-async def predict(request: Request, username: str = Depends(get_current_user)):
-    body = await request.json()
+async def predict(body: PredictBody, username: str = Depends(get_current_user)):
+    if app.launchable.stateful:
+        session_hash = body.session_hash
+        state = app.state_holder.get(
+            (session_hash, "state"), app.launchable.state_default
+        )
+        body.state = state
     try:
         output = await run_in_threadpool(app.launchable.process_api, body, username)
+        if app.launchable.stateful:
+            updated_state = output.pop("updated_state")
+            app.state_holder[(session_hash, "state")] = updated_state
+
     except BaseException as error:
         if app.launchable.show_error:
             traceback.print_exc()
@@ -222,29 +271,26 @@ async def predict(request: Request, username: str = Depends(get_current_user)):
 
 
 @app.post("/api/flag/", dependencies=[Depends(login_check)])
-async def flag(request: Request, username: str = Depends(get_current_user)):
+async def flag(body: FlagBody, username: str = Depends(get_current_user)):
     if app.launchable.analytics_enabled:
         await utils.log_feature_analytics(app.launchable.ip_address, "flag")
-    body = await request.json()
-    data = body["data"]
     await run_in_threadpool(
         app.launchable.flagging_callback.flag,
         app.launchable,
-        data["input_data"],
-        data["output_data"],
-        flag_option=data.get("flag_option"),
-        flag_index=data.get("flag_index"),
+        body.data.input_data,
+        body.data.output_data,
+        flag_option=body.data.flag_option,
+        flag_index=body.data.flag_index,
        username=username,
     )
     return {"success": True}
 
 
 @app.post("/api/interpret/", dependencies=[Depends(login_check)])
-async def interpret(request: Request):
+async def interpret(body: InterpretBody):
     if app.launchable.analytics_enabled:
         await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
-    body = await request.json()
-    raw_input = body["data"]
+    raw_input = body.data
     interpretation_scores, alternative_outputs = await run_in_threadpool(
         app.launchable.interpret, raw_input
     )
@@ -255,18 +301,14 @@ async def interpret(request: Request):
 
 
 @app.post("/api/queue/push/", dependencies=[Depends(login_check)])
-async def queue_push(request: Request):
-    body = await request.json()
-    action = body["action"]
-    job_hash, queue_position = queueing.push(body, action)
+async def queue_push(body: QueuePushBody):
+    job_hash, queue_position = queueing.push(body)
     return {"hash": job_hash, "queue_position": queue_position}
 
 
 @app.post("/api/queue/status/", dependencies=[Depends(login_check)])
-async def queue_status(request: Request):
-    body = await request.json()
-    hash = body["hash"]
-    status, data = queueing.get_status(hash)
+async def queue_status(body: QueueStatusBody):
+    status, data = queueing.get_status(body.hash)
     return {"status": status, "data": data}
 
 
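The `/api/predict/` route above keys stored state on `(session_hash, "state")`, so repeated calls that share a `session_hash` see each other's updates. A rough client-side sketch of that flow (the localhost URL and port are assumptions for a locally launched app; the browser client generates the hash in `main.ts` below):

```python
import requests

url = "http://127.0.0.1:7860/api/predict/"  # assumed local launch address
payload = {"session_hash": "abc123", "data": ["test", None]}

first = requests.post(url, json=payload).json()
print(first["data"])   # expected ["test", None]: the State slot comes back as None

second = requests.post(url, json=payload).json()
print(second["data"])  # expected ["testtest", None]: state persisted for this session
```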
diff --git a/test/test_external.py b/test/test_external.py
index 6f081f68a0..98d80a22ab 100644
--- a/test/test_external.py
+++ b/test/test_external.py
@@ -214,7 +214,7 @@ class TestLoadInterface(unittest.TestCase):
 
     def test_speech_recognition_model(self):
         interface_info = gr.external.load_interface(
-            "models/jonatasgrosman/wav2vec2-large-xlsr-53-english"
+            "models/facebook/wav2vec2-base-960h"
         )
         io = gr.Interface(**interface_info)
         io.api_mode = True
diff --git a/test/test_queuing.py b/test/test_queuing.py
index 7bbfe5d159..93aabc28bf 100644
--- a/test/test_queuing.py
+++ b/test/test_queuing.py
@@ -4,6 +4,7 @@ import os
 import unittest
 
 from gradio import queueing
+from gradio.routes import QueuePushBody
 
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 
@@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase):
         queueing.close()
 
     def test_push_pop_status(self):
-        hash1, position = queueing.push({"data": "test1"}, "predict")
+        request = QueuePushBody(data="test1", action="predict")
+        hash1, position = queueing.push(request)
         self.assertEquals(position, 0)
-        hash2, position = queueing.push({"data": "test2"}, "predict")
+        request = QueuePushBody(data="test2", action="predict")
+        hash2, position = queueing.push(request)
         self.assertEquals(position, 1)
         status, position = queueing.get_status(hash2)
         self.assertEquals(status, "QUEUED")
@@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase):
         self.assertEquals(action, "predict")
 
     def test_jobs(self):
-        hash1, _ = queueing.push({"data": "test1"}, "predict")
-        hash2, position = queueing.push({"data": "test1"}, "predict")
+        request = QueuePushBody(data="test1", action="predict")
+        hash1, _ = queueing.push(request)
+        hash2, position = queueing.push(request)
         self.assertEquals(position, 1)
 
         queueing.start_job(hash1)
diff --git a/test/test_routes.py b/test/test_routes.py
index 8ce0b94bb1..b50117c2d3 100644
--- a/test/test_routes.py
+++ b/test/test_routes.py
@@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase):
         self.assertTrue("durations" in output)
         self.assertTrue("avg_durations" in output)
 
+    def test_state(self):
+        def predict(input, history=""):
+            history += input
+            return history, history
+
+        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
+        app, _, _ = io.launch(prevent_thread_lock=True)
+        client = TestClient(app)
+        response = client.post("/api/predict/", json={"data": ["test", None]})
+        output = dict(response.json())
+        print("output", output)
+        self.assertEqual(output["data"], ["test", None])
+        response = client.post("/api/predict/", json={"data": ["test", None]})
+        output = dict(response.json())
+        self.assertEqual(output["data"], ["testtest", None])
+
     def test_queue_push_route(self):
         queueing.push = mock.MagicMock(return_value=(None, None))
         response = self.client.post(
diff --git a/ui/packages/app/src/api.ts b/ui/packages/app/src/api.ts
index 38447af3f9..e162ebf5ba 100644
--- a/ui/packages/app/src/api.ts
+++ b/ui/packages/app/src/api.ts
@@ -14,12 +14,14 @@ let postData = async (url: string, body: unknown) => {
 };
 
 export const fn = async (
+	session_hash: string,
 	api_endpoint: string,
 	action: string,
 	data: Record<string, unknown>,
 	queue: boolean,
 	queue_callback: (pos: number | null, is_initial?: boolean) => void
 ) => {
+	data["session_hash"] = session_hash;
 	if (queue && ["predict", "interpret"].includes(action)) {
 		data["action"] = action;
 		const output = await postData(api_endpoint + "queue/push/", data);
diff --git a/ui/packages/app/src/main.ts b/ui/packages/app/src/main.ts
index 146dcd991e..c51de8d57a 100644
--- a/ui/packages/app/src/main.ts
+++ b/ui/packages/app/src/main.ts
@@ -96,7 +96,8 @@ window.launchGradio = (config: Config, element_query: string) => {
 		config.dark = true;
 		target.classList.add("dark");
 	}
-	config.fn = fn.bind(null, config.root + "api/");
+	let session_hash = Math.random().toString(36).substring(2);
+	config.fn = fn.bind(null, session_hash, config.root + "api/");
 	if (config.mode === "blocks") {
 		new Blocks({
 			target: target,
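Finally, note the call-signature change to `queueing.push`: it now takes a single validated `QueuePushBody` instead of a raw dict plus an action string. A usage sketch matching the updated tests above:

```python
from gradio import queueing
from gradio.routes import QueuePushBody

# Assumes the queue database has already been created, as the tests'
# setUp does before pushing jobs.
body = QueuePushBody(data="test1", action="predict")
job_hash, position = queueing.push(body)
print(job_hash, position)  # a generated job hash and its queue position
```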