From 04a7199502a4f1c22bc9053a722d31afc594e79d Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 3 Mar 2022 13:55:53 -0500 Subject: [PATCH 01/19] working on a new backend-based state --- gradio/state.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 gradio/state.py diff --git a/gradio/state.py b/gradio/state.py new file mode 100644 index 0000000000..d425d7fe1a --- /dev/null +++ b/gradio/state.py @@ -0,0 +1,2 @@ +"""Implements a State class to store state in the backend.""" + From 651d67a0c69b0a9003dc03c136caaeab08f99e3a Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 3 Mar 2022 15:06:04 -0500 Subject: [PATCH 02/19] state work --- gradio/inputs.py | 37 ++++++++++++++++++++++++++++++++++++- gradio/state.py | 25 ++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/gradio/inputs.py b/gradio/inputs.py index 33da90ee3e..8fab4f81de 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -1609,7 +1609,8 @@ class State(InputComponent): default (Any): the initial value of the state. optional (bool): this parameter is ignored. """ - + warnings.warn("The State input component will be deprecated. Please use the " + "new Stateful component.") self.default = default super().__init__(label) @@ -1621,6 +1622,40 @@ class State(InputComponent): return { "state": {}, } + + +class Stateful(InputComponent): + """ + Special hidden component that stores state across runs of the interface. + Input type: Any + Demos: chatbot + """ + + def __init__( + self, + label: str = None, + default: Any = None, + optional: bool = False, + ): + """ + Parameters: + label (str): component name in interface (not used). + default (Any): the initial value of the state. + optional (bool): this parameter is ignored. + """ + + self.default = default + super().__init__(label) + + def get_template_context(self): + return {"default": self.default, **super().get_template_context()} + + @classmethod + def get_shortcut_implementations(cls): + return { + "stateful": {}, + } + def get_input_instance(iface: Interface): diff --git a/gradio/state.py b/gradio/state.py index d425d7fe1a..0d8980e163 100644 --- a/gradio/state.py +++ b/gradio/state.py @@ -1,2 +1,25 @@ -"""Implements a State class to store state in the backend.""" +"""Implements a StateHolder class to store state in the backend.""" +from __future__ import annotations +from typing import Any, Dict + + +class StateHolder: + state_dict: Dict[str, Any] = {} + + def __init__(self, id): + self.id = id + + def __setattr__(self, name, value): + if name == "state": + StateHolder.state_dict[self.id] = value + else: + self.__dict__[name] = value + + + def __getattr__(self, name): + if name == "state": + return StateHolder.state_dict.get(self.id, None) + else: + return self.__dict__[name] + From 8e1577e6debd76caffac1b1102a00f94348d7a3f Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 3 Mar 2022 15:25:14 -0500 Subject: [PATCH 03/19] state fixes; deprecation --- gradio/outputs.py | 2 ++ gradio/state.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/gradio/outputs.py b/gradio/outputs.py index aecbdfbb49..9ab153919a 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -844,6 +844,8 @@ class State(OutputComponent): Parameters: label (str): component name in interface (not used). """ + warnings.warn("The State output component will be deprecated. Please use the " + "new Stateful component.") super().__init__(label) @classmethod diff --git a/gradio/state.py b/gradio/state.py index 0d8980e163..f2b88dafef 100644 --- a/gradio/state.py +++ b/gradio/state.py @@ -8,18 +8,17 @@ class StateHolder: state_dict: Dict[str, Any] = {} def __init__(self, id): - self.id = id + self.__id = id def __setattr__(self, name, value): - if name == "state": - StateHolder.state_dict[self.id] = value - else: + if name.startswith("_"): self.__dict__[name] = value - + else: + StateHolder.state_dict[(self.__id, name)] = value def __getattr__(self, name): - if name == "state": - return StateHolder.state_dict.get(self.id, None) - else: + if name.startswith("_"): return self.__dict__[name] + else: + return StateHolder.state_dict.get((self.__id, name), None) From eadee8fbf5098d7eea2dba4e45918e649325cc23 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 3 Mar 2022 16:38:40 -0500 Subject: [PATCH 04/19] redid state --- gradio/blocks.py | 4 +++- gradio/inputs.py | 34 ---------------------------------- gradio/interface.py | 30 ++++++++++++++++++++++++++++-- gradio/routes.py | 14 +++++++++++--- gradio/state.py | 24 ------------------------ 5 files changed, 42 insertions(+), 64 deletions(-) delete mode 100644 gradio/state.py diff --git a/gradio/blocks.py b/gradio/blocks.py index 53b1745f0e..7256fed695 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -73,6 +73,7 @@ class Blocks(Launchable, BlockContext): self.theme = theme self.requires_permissions = False # TODO: needs to be implemented self.enable_queue = False + self.stateful = False #TODO: implement state super().__init__() Context.root_block = self @@ -80,7 +81,8 @@ class Blocks(Launchable, BlockContext): self.fns = [] self.dependencies = [] - def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: + def process_api(self, data: Dict[str, Any], username: str = None, state=None) -> Dict[str, Any]: + #TODO: implement state raw_input = data["data"] fn_index = data["fn_index"] fn = self.fns[fn_index] diff --git a/gradio/inputs.py b/gradio/inputs.py index 8fab4f81de..e674963d84 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -1624,40 +1624,6 @@ class State(InputComponent): } -class Stateful(InputComponent): - """ - Special hidden component that stores state across runs of the interface. - Input type: Any - Demos: chatbot - """ - - def __init__( - self, - label: str = None, - default: Any = None, - optional: bool = False, - ): - """ - Parameters: - label (str): component name in interface (not used). - default (Any): the initial value of the state. - optional (bool): this parameter is ignored. - """ - - self.default = default - super().__init__(label) - - def get_template_context(self): - return {"default": self.default, **super().get_template_context()} - - @classmethod - def get_shortcut_implementations(cls): - return { - "stateful": {}, - } - - - def get_input_instance(iface: Interface): if isinstance(iface, str): shortcut = InputComponent.get_all_shortcut_implementations()[iface] diff --git a/gradio/interface.py b/gradio/interface.py index 09756afc84..17d1df7d33 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -176,6 +176,8 @@ class Interface(Launchable): if repeat_outputs_per_model: self.output_components *= len(fn) + + self.stateful = False if sum(isinstance(i, i_State) for i in self.input_components) > 1: raise ValueError("Only one input component can be State.") if sum(isinstance(o, o_State) for o in self.output_components) > 1: @@ -187,10 +189,23 @@ class Interface(Launchable): state_param_index = [ isinstance(i, i_State) for i in self.input_components ].index(True) + self.stateful = True + self.state_param_index = state_param_index state: i_State = self.input_components[state_param_index] if state.default is None: default = utils.get_default_args(fn[0])[state_param_index] state.default = default + self.state_default = state.default + + if sum(isinstance(i, o_State) for i in self.output_components) == 1: + state_return_index = [ + isinstance(i, o_State) for i in self.output_components + ].index(True) + self.state_return_index = state_return_index + else: + raise ValueError("Exactly one input and one output component must be State") + + if ( interpretation is None @@ -544,7 +559,12 @@ class Interface(Launchable): else: return predictions - def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: + def process_api( + self, + data: Dict[str, Any], + username: str = None, + state: Any = None, + ) -> Dict[str, Any]: flag_index = None if data.get("example_id") is not None: example_id = data["example_id"] @@ -555,6 +575,8 @@ class Interface(Launchable): prediction, durations = process_example(self, example_id) else: raw_input = data["data"] + if self.stateful: + raw_input[self.state_param_index] = state prediction, durations = self.process(raw_input) if self.allow_flagging == "auto": flag_index = self.flagging_callback.flag( @@ -564,13 +586,17 @@ class Interface(Launchable): flag_option="" if self.flagging_options else None, username=username, ) + if self.stateful: + updated_state = prediction[self.state_return_index] + else: + updated_state = None return { "data": prediction, "durations": durations, "avg_durations": self.config.get("avg_durations"), "flag_index": flag_index, - } + }, updated_state def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]: """ diff --git a/gradio/routes.py b/gradio/routes.py index 817f0baa86..71da2986b1 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -9,7 +9,7 @@ import posixpath import secrets import traceback import urllib -from typing import Any, List, Optional, Type +from typing import Any, Dict, List, Optional, Tuple, Type import orjson import pkg_resources @@ -44,7 +44,7 @@ class ORJSONResponse(JSONResponse): def render(self, content: Any) -> bytes: return orjson.dumps(content) - + app = FastAPI(default_response_class=ORJSONResponse) app.add_middleware( @@ -53,6 +53,7 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"], ) +app.state_holder = {} templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) @@ -210,8 +211,15 @@ def api_docs(request: Request): @app.post("/api/predict/", dependencies=[Depends(login_check)]) async def predict(request: Request, username: str = Depends(get_current_user)): body = await request.json() + session_hash = body.get("session_hash", None) # TODO(aliabid94): send from frontend + if app.launchable.stateful: + state = app.state_holder.get(session_hash, app.launchable.state_default) + try: - output = await run_in_threadpool(app.launchable.process_api, body, username) + output, updated_state = await run_in_threadpool( + app.launchable.process_api, body, username, state) + app.state_holder[session_hash] = updated_state + except BaseException as error: if app.launchable.show_error: traceback.print_exc() diff --git a/gradio/state.py b/gradio/state.py deleted file mode 100644 index f2b88dafef..0000000000 --- a/gradio/state.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Implements a StateHolder class to store state in the backend.""" -from __future__ import annotations - -from typing import Any, Dict - - -class StateHolder: - state_dict: Dict[str, Any] = {} - - def __init__(self, id): - self.__id = id - - def __setattr__(self, name, value): - if name.startswith("_"): - self.__dict__[name] = value - else: - StateHolder.state_dict[(self.__id, name)] = value - - def __getattr__(self, name): - if name.startswith("_"): - return self.__dict__[name] - else: - return StateHolder.state_dict.get((self.__id, name), None) - From 6f2b57f99d579423bd8de6bf5dcbd91d5e15bd6f Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 3 Mar 2022 16:40:46 -0500 Subject: [PATCH 05/19] redid state --- gradio/inputs.py | 4 +--- gradio/outputs.py | 2 -- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/gradio/inputs.py b/gradio/inputs.py index e674963d84..a06116719b 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -1609,8 +1609,6 @@ class State(InputComponent): default (Any): the initial value of the state. optional (bool): this parameter is ignored. """ - warnings.warn("The State input component will be deprecated. Please use the " - "new Stateful component.") self.default = default super().__init__(label) @@ -1622,7 +1620,7 @@ class State(InputComponent): return { "state": {}, } - + def get_input_instance(iface: Interface): if isinstance(iface, str): diff --git a/gradio/outputs.py b/gradio/outputs.py index 9ab153919a..aecbdfbb49 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -844,8 +844,6 @@ class State(OutputComponent): Parameters: label (str): component name in interface (not used). """ - warnings.warn("The State output component will be deprecated. Please use the " - "new Stateful component.") super().__init__(label) @classmethod From 95578e9926086f737f7b87ea5038aba14d7a6300 Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Thu, 3 Mar 2022 16:23:18 -0600 Subject: [PATCH 06/19] add session hash to frontend --- ui/packages/app/src/api.ts | 2 ++ ui/packages/app/src/main.ts | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ui/packages/app/src/api.ts b/ui/packages/app/src/api.ts index 38447af3f9..e162ebf5ba 100644 --- a/ui/packages/app/src/api.ts +++ b/ui/packages/app/src/api.ts @@ -14,12 +14,14 @@ let postData = async (url: string, body: unknown) => { }; export const fn = async ( + session_hash: string, api_endpoint: string, action: string, data: Record, queue: boolean, queue_callback: (pos: number | null, is_initial?: boolean) => void ) => { + data["session_hash"] = session_hash; if (queue && ["predict", "interpret"].includes(action)) { data["action"] = action; const output = await postData(api_endpoint + "queue/push/", data); diff --git a/ui/packages/app/src/main.ts b/ui/packages/app/src/main.ts index 146dcd991e..c51de8d57a 100644 --- a/ui/packages/app/src/main.ts +++ b/ui/packages/app/src/main.ts @@ -96,7 +96,8 @@ window.launchGradio = (config: Config, element_query: string) => { config.dark = true; target.classList.add("dark"); } - config.fn = fn.bind(null, config.root + "api/"); + let session_hash = Math.random().toString(36).substring(2); + config.fn = fn.bind(null, session_hash, config.root + "api/"); if (config.mode === "blocks") { new Blocks({ target: target, From c01f574e3144b30c1a31ee7187bdef985b41229a Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 3 Mar 2022 23:59:21 -0500 Subject: [PATCH 07/19] cleaned up state --- gradio/blocks.py | 3 +-- gradio/interface.py | 5 +++-- gradio/routes.py | 19 ++++++++++++------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index 7256fed695..53a003ae8c 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -81,8 +81,7 @@ class Blocks(Launchable, BlockContext): self.fns = [] self.dependencies = [] - def process_api(self, data: Dict[str, Any], username: str = None, state=None) -> Dict[str, Any]: - #TODO: implement state + def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: raw_input = data["data"] fn_index = data["fn_index"] fn = self.fns[fn_index] diff --git a/gradio/interface.py b/gradio/interface.py index 17d1df7d33..b58155b0a3 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -563,7 +563,6 @@ class Interface(Launchable): self, data: Dict[str, Any], username: str = None, - state: Any = None, ) -> Dict[str, Any]: flag_index = None if data.get("example_id") is not None: @@ -576,6 +575,7 @@ class Interface(Launchable): else: raw_input = data["data"] if self.stateful: + state = data["state"] raw_input[self.state_param_index] = state prediction, durations = self.process(raw_input) if self.allow_flagging == "auto": @@ -596,7 +596,8 @@ class Interface(Launchable): "durations": durations, "avg_durations": self.config.get("avg_durations"), "flag_index": flag_index, - }, updated_state + "updated_state": updated_state + } def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]: """ diff --git a/gradio/routes.py b/gradio/routes.py index 71da2986b1..693609c6a1 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -53,7 +53,9 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"], ) -app.state_holder = {} + +state_holder: Dict[Tuple[str, str], Any] = {} +app.state_holder = state_holder templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) @@ -211,14 +213,17 @@ def api_docs(request: Request): @app.post("/api/predict/", dependencies=[Depends(login_check)]) async def predict(request: Request, username: str = Depends(get_current_user)): body = await request.json() - session_hash = body.get("session_hash", None) # TODO(aliabid94): send from frontend - if app.launchable.stateful: - state = app.state_holder.get(session_hash, app.launchable.state_default) + if app.launchable.stateful: + session_hash = body.get("session_hash", None) + state = app.state_holder.get((session_hash, "state"), app.launchable.state_default) + body['state'] = state try: - output, updated_state = await run_in_threadpool( - app.launchable.process_api, body, username, state) - app.state_holder[session_hash] = updated_state + output = await run_in_threadpool( + app.launchable.process_api, body, username) + if app.launchable.stateful: + updated_state = output.pop("updated_state") + app.state_holder[(session_hash, "state")] = updated_state except BaseException as error: if app.launchable.show_error: From 495e61f59898056648881504942de214eaf13e4e Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 00:00:58 -0500 Subject: [PATCH 08/19] formatting --- gradio/blocks.py | 2 +- gradio/interface.py | 13 +++++-------- gradio/routes.py | 15 ++++++++------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index 53a003ae8c..ad87aacd09 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -73,7 +73,7 @@ class Blocks(Launchable, BlockContext): self.theme = theme self.requires_permissions = False # TODO: needs to be implemented self.enable_queue = False - self.stateful = False #TODO: implement state + self.stateful = False # TODO: implement state super().__init__() Context.root_block = self diff --git a/gradio/interface.py b/gradio/interface.py index b58155b0a3..1b05676f24 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -176,7 +176,6 @@ class Interface(Launchable): if repeat_outputs_per_model: self.output_components *= len(fn) - self.stateful = False if sum(isinstance(i, i_State) for i in self.input_components) > 1: raise ValueError("Only one input component can be State.") @@ -196,7 +195,7 @@ class Interface(Launchable): default = utils.get_default_args(fn[0])[state_param_index] state.default = default self.state_default = state.default - + if sum(isinstance(i, o_State) for i in self.output_components) == 1: state_return_index = [ isinstance(i, o_State) for i in self.output_components @@ -204,8 +203,6 @@ class Interface(Launchable): self.state_return_index = state_return_index else: raise ValueError("Exactly one input and one output component must be State") - - if ( interpretation is None @@ -560,8 +557,8 @@ class Interface(Launchable): return predictions def process_api( - self, - data: Dict[str, Any], + self, + data: Dict[str, Any], username: str = None, ) -> Dict[str, Any]: flag_index = None @@ -588,7 +585,7 @@ class Interface(Launchable): ) if self.stateful: updated_state = prediction[self.state_return_index] - else: + else: updated_state = None return { @@ -596,7 +593,7 @@ class Interface(Launchable): "durations": durations, "avg_durations": self.config.get("avg_durations"), "flag_index": flag_index, - "updated_state": updated_state + "updated_state": updated_state, } def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]: diff --git a/gradio/routes.py b/gradio/routes.py index 693609c6a1..880cd7a7ab 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -44,7 +44,7 @@ class ORJSONResponse(JSONResponse): def render(self, content: Any) -> bytes: return orjson.dumps(content) - + app = FastAPI(default_response_class=ORJSONResponse) app.add_middleware( @@ -213,18 +213,19 @@ def api_docs(request: Request): @app.post("/api/predict/", dependencies=[Depends(login_check)]) async def predict(request: Request, username: str = Depends(get_current_user)): body = await request.json() - + if app.launchable.stateful: session_hash = body.get("session_hash", None) - state = app.state_holder.get((session_hash, "state"), app.launchable.state_default) - body['state'] = state + state = app.state_holder.get( + (session_hash, "state"), app.launchable.state_default + ) + body["state"] = state try: - output = await run_in_threadpool( - app.launchable.process_api, body, username) + output = await run_in_threadpool(app.launchable.process_api, body, username) if app.launchable.stateful: updated_state = output.pop("updated_state") app.state_holder[(session_hash, "state")] = updated_state - + except BaseException as error: if app.launchable.show_error: traceback.print_exc() From 1b3f7333a9bcc4002d46b29f3fb51437a34629ae Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 00:01:28 -0500 Subject: [PATCH 09/19] formatting --- gradio/inputs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gradio/inputs.py b/gradio/inputs.py index a06116719b..85a7805f9f 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -1609,6 +1609,7 @@ class State(InputComponent): default (Any): the initial value of the state. optional (bool): this parameter is ignored. """ + self.default = default super().__init__(label) From 9765c0a5367a060be5b41a7d423df36870f821f6 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 00:01:48 -0500 Subject: [PATCH 10/19] formatting --- gradio/inputs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradio/inputs.py b/gradio/inputs.py index 85a7805f9f..33da90ee3e 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -1609,7 +1609,7 @@ class State(InputComponent): default (Any): the initial value of the state. optional (bool): this parameter is ignored. """ - + self.default = default super().__init__(label) From 5cf8bcbc01d1814629b2f6232268850a33eb527e Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 00:06:30 -0500 Subject: [PATCH 11/19] cleanups --- gradio/interface.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 1b05676f24..8e74538bf3 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -202,7 +202,8 @@ class Interface(Launchable): ].index(True) self.state_return_index = state_return_index else: - raise ValueError("Exactly one input and one output component must be State") + raise ValueError("For a stateful interface, there must be exactly one State" + " input component and one State output component.") if ( interpretation is None @@ -556,11 +557,7 @@ class Interface(Launchable): else: return predictions - def process_api( - self, - data: Dict[str, Any], - username: str = None, - ) -> Dict[str, Any]: + def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: flag_index = None if data.get("example_id") is not None: example_id = data["example_id"] @@ -585,6 +582,7 @@ class Interface(Launchable): ) if self.stateful: updated_state = prediction[self.state_return_index] + prediction[self.state_return_index] = None else: updated_state = None From a38a428072991820a6591fd84fecfdba9fd73215 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 00:09:20 -0500 Subject: [PATCH 12/19] formatting --- gradio/interface.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 8e74538bf3..3f1c8fcdaf 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -202,8 +202,10 @@ class Interface(Launchable): ].index(True) self.state_return_index = state_return_index else: - raise ValueError("For a stateful interface, there must be exactly one State" - " input component and one State output component.") + raise ValueError( + "For a stateful interface, there must be exactly one State" + " input component and one State output component." + ) if ( interpretation is None From fd034ee29f5cbfb290a483ec3bb1e3025744ffa8 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 00:24:17 -0500 Subject: [PATCH 13/19] test state --- gradio/interface.py | 20 ++++++++++---------- test/test_routes.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 3f1c8fcdaf..42b1821aa5 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -196,16 +196,16 @@ class Interface(Launchable): state.default = default self.state_default = state.default - if sum(isinstance(i, o_State) for i in self.output_components) == 1: - state_return_index = [ - isinstance(i, o_State) for i in self.output_components - ].index(True) - self.state_return_index = state_return_index - else: - raise ValueError( - "For a stateful interface, there must be exactly one State" - " input component and one State output component." - ) + if sum(isinstance(i, o_State) for i in self.output_components) == 1: + state_return_index = [ + isinstance(i, o_State) for i in self.output_components + ].index(True) + self.state_return_index = state_return_index + else: + raise ValueError( + "For a stateful interface, there must be exactly one State" + " input component and one State output component." + ) if ( interpretation is None diff --git a/test/test_routes.py b/test/test_routes.py index 8ce0b94bb1..b50117c2d3 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase): self.assertTrue("durations" in output) self.assertTrue("avg_durations" in output) + def test_state(self): + def predict(input, history=""): + history += input + return history, history + + io = Interface(predict, ["textbox", "state"], ["textbox", "state"]) + app, _, _ = io.launch(prevent_thread_lock=True) + client = TestClient(app) + response = client.post("/api/predict/", json={"data": ["test", None]}) + output = dict(response.json()) + print("output", output) + self.assertEqual(output["data"], ["test", None]) + response = client.post("/api/predict/", json={"data": ["test", None]}) + output = dict(response.json()) + self.assertEqual(output["data"], ["testtest", None]) + def test_queue_push_route(self): queueing.push = mock.MagicMock(return_value=(None, None)) response = self.client.post( From 863b4287a9e51e53bafb44176c99b31e13644ded Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 11:03:10 -0500 Subject: [PATCH 14/19] updated to use pydantic models --- gradio/interface.py | 15 +++++++-------- gradio/routes.py | 20 +++++++++++++++----- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 42b1821aa5..1a9c55226c 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent from gradio.outputs import State as o_State # type: ignore from gradio.outputs import get_output_instance from gradio.process_examples import load_from_cache, process_example -from gradio.routes import predict +from gradio.routes import PredictRequest if TYPE_CHECKING: # Only import for type checking (is False at runtime). import flask @@ -559,19 +559,18 @@ class Interface(Launchable): else: return predictions - def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: + def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]: flag_index = None - if data.get("example_id") is not None: - example_id = data["example_id"] + if data.example_id is not None: if self.cache_examples: - prediction = load_from_cache(self, example_id) + prediction = load_from_cache(self, data.example_id) durations = None else: - prediction, durations = process_example(self, example_id) + prediction, durations = process_example(self, data.example_id) else: - raw_input = data["data"] + raw_input = data.data if self.stateful: - state = data["state"] + state = data.state raw_input[self.state_param_index] = state prediction, durations = self.process(raw_input) if self.allow_flagging == "auto": diff --git a/gradio/routes.py b/gradio/routes.py index 880cd7a7ab..72e079675a 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse from fastapi.security import OAuth2PasswordRequestForm from fastapi.templating import Jinja2Templates from jinja2.exceptions import TemplateNotFound +from pydantic import BaseModel from starlette.responses import RedirectResponse from gradio import encryptor, queueing, utils @@ -60,6 +61,17 @@ app.state_holder = state_holder templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) +########### +# Data Models +########### + +class PredictRequest(BaseModel): + session_hash: Optional[str] + example_id: Optional[int] + data: Any + state: Optional[Any] + + ########### # Auth ########### @@ -211,15 +223,13 @@ def api_docs(request: Request): @app.post("/api/predict/", dependencies=[Depends(login_check)]) -async def predict(request: Request, username: str = Depends(get_current_user)): - body = await request.json() - +async def predict(body: PredictRequest, username: str = Depends(get_current_user)): if app.launchable.stateful: - session_hash = body.get("session_hash", None) + session_hash = body.session_hash state = app.state_holder.get( (session_hash, "state"), app.launchable.state_default ) - body["state"] = state + body.state = state try: output = await run_in_threadpool(app.launchable.process_api, body, username) if app.launchable.stateful: From 9e2cac6e4c1ed9b62968060bd5bf63e61f4e9c43 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 4 Mar 2022 11:06:16 -0500 Subject: [PATCH 15/19] formatting --- gradio/routes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gradio/routes.py b/gradio/routes.py index 72e079675a..5a101d0042 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -65,12 +65,13 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) # Data Models ########### + class PredictRequest(BaseModel): session_hash: Optional[str] example_id: Optional[int] data: Any state: Optional[Any] - + ########### # Auth From c65f9a599f55b5a82329f18879d9a132815d1fb8 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Sat, 5 Mar 2022 21:56:50 -0500 Subject: [PATCH 16/19] added pydantic data models for all requests --- gradio/queueing.py | 8 +++++-- gradio/routes.py | 55 +++++++++++++++++++++++++++++--------------- test/test_queuing.py | 12 ++++++---- 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/gradio/queueing.py b/gradio/queueing.py index 9b6ab16abf..d402db9290 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -7,6 +7,9 @@ from typing import Dict, Tuple import requests +from gradio.routes import QueuePushRequest + + DB_FILE = "gradio_queue.db" @@ -106,8 +109,9 @@ def pop() -> Tuple[int, str, Dict, str]: return result[0], result[1], json.loads(result[2]), result[3] -def push(input_data: Dict, action: str) -> Tuple[str, int]: - input_data = json.dumps(input_data) +def push(request: QueuePushRequest) -> Tuple[str, int]: + action = request.action + input_data = json.dumps({'data': request.data}) hash = generate_hash() conn = sqlite3.connect(DB_FILE) c = conn.cursor() diff --git a/gradio/routes.py b/gradio/routes.py index 5a101d0042..482141a09d 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -69,10 +69,34 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) class PredictRequest(BaseModel): session_hash: Optional[str] example_id: Optional[int] - data: Any + data: List[Any] state: Optional[Any] +class FlagData(BaseModel): + input_data: List[Any] + output_data: List[Any] + flag_option: Optional[str] + flag_index: Optional[int] + + +class FlagRequest(BaseModel): + data: FlagData + + +class InterpretRequest(BaseModel): + data: List[Any] + + +class QueueStatusRequest(BaseModel): + hash: str + + +class QueuePushRequest(BaseModel): + action: str + data: Any + + ########### # Auth ########### @@ -247,29 +271,26 @@ async def predict(body: PredictRequest, username: str = Depends(get_current_user @app.post("/api/flag/", dependencies=[Depends(login_check)]) -async def flag(request: Request, username: str = Depends(get_current_user)): +async def flag(body: FlagRequest, username: str = Depends(get_current_user)): if app.launchable.analytics_enabled: await utils.log_feature_analytics(app.launchable.ip_address, "flag") - body = await request.json() - data = body["data"] await run_in_threadpool( app.launchable.flagging_callback.flag, app.launchable, - data["input_data"], - data["output_data"], - flag_option=data.get("flag_option"), - flag_index=data.get("flag_index"), + body.data.input_data, + body.data.output_data, + flag_option=body.data.flag_option, + flag_index=body.data.flag_index, username=username, ) return {"success": True} @app.post("/api/interpret/", dependencies=[Depends(login_check)]) -async def interpret(request: Request): +async def interpret(body: InterpretRequest): if app.launchable.analytics_enabled: await utils.log_feature_analytics(app.launchable.ip_address, "interpret") - body = await request.json() - raw_input = body["data"] + raw_input = body.data interpretation_scores, alternative_outputs = await run_in_threadpool( app.launchable.interpret, raw_input ) @@ -280,18 +301,14 @@ async def interpret(request: Request): @app.post("/api/queue/push/", dependencies=[Depends(login_check)]) -async def queue_push(request: Request): - body = await request.json() - action = body["action"] - job_hash, queue_position = queueing.push(body, action) +async def queue_push(body: QueuePushRequest): + job_hash, queue_position = queueing.push(body) return {"hash": job_hash, "queue_position": queue_position} @app.post("/api/queue/status/", dependencies=[Depends(login_check)]) -async def queue_status(request: Request): - body = await request.json() - hash = body["hash"] - status, data = queueing.get_status(hash) +async def queue_status(body: QueueStatusRequest): + status, data = queueing.get_status(body.hash) return {"status": status, "data": data} diff --git a/test/test_queuing.py b/test/test_queuing.py index 7bbfe5d159..fcae876635 100644 --- a/test/test_queuing.py +++ b/test/test_queuing.py @@ -4,6 +4,7 @@ import os import unittest from gradio import queueing +from gradio.routes import QueuePushRequest os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase): queueing.close() def test_push_pop_status(self): - hash1, position = queueing.push({"data": "test1"}, "predict") + request = QueuePushRequest(data="test1", action="predict") + hash1, position = queueing.push(request) self.assertEquals(position, 0) - hash2, position = queueing.push({"data": "test2"}, "predict") + request = QueuePushRequest(data="test2", action="predict") + hash2, position = queueing.push(request) self.assertEquals(position, 1) status, position = queueing.get_status(hash2) self.assertEquals(status, "QUEUED") @@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase): self.assertEquals(action, "predict") def test_jobs(self): - hash1, _ = queueing.push({"data": "test1"}, "predict") - hash2, position = queueing.push({"data": "test1"}, "predict") + request = QueuePushRequest(data="test1", action="predict") + hash1, _ = queueing.push(request) + hash2, position = queueing.push(request) self.assertEquals(position, 1) queueing.start_job(hash1) From f9034db75b8d9fa85c68e106a38850e61c09b0e0 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 7 Mar 2022 13:01:43 -0600 Subject: [PATCH 17/19] fixed naming --- gradio/interface.py | 4 ++-- gradio/queueing.py | 8 ++++---- gradio/routes.py | 20 ++++++++++---------- test/test_queuing.py | 8 ++++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/gradio/interface.py b/gradio/interface.py index 1a9c55226c..7dd36aeefb 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent from gradio.outputs import State as o_State # type: ignore from gradio.outputs import get_output_instance from gradio.process_examples import load_from_cache, process_example -from gradio.routes import PredictRequest +from gradio.routes import PredictBody if TYPE_CHECKING: # Only import for type checking (is False at runtime). import flask @@ -559,7 +559,7 @@ class Interface(Launchable): else: return predictions - def process_api(self, data: PredictRequest, username: str = None) -> Dict[str, Any]: + def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]: flag_index = None if data.example_id is not None: if self.cache_examples: diff --git a/gradio/queueing.py b/gradio/queueing.py index d402db9290..ef5162b0f4 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -7,7 +7,7 @@ from typing import Dict, Tuple import requests -from gradio.routes import QueuePushRequest +from gradio.routes import QueuePushBody DB_FILE = "gradio_queue.db" @@ -109,9 +109,9 @@ def pop() -> Tuple[int, str, Dict, str]: return result[0], result[1], json.loads(result[2]), result[3] -def push(request: QueuePushRequest) -> Tuple[str, int]: - action = request.action - input_data = json.dumps({'data': request.data}) +def push(body: QueuePushBody) -> Tuple[str, int]: + action = body.action + input_data = json.dumps({'data': body.data}) hash = generate_hash() conn = sqlite3.connect(DB_FILE) c = conn.cursor() diff --git a/gradio/routes.py b/gradio/routes.py index 482141a09d..159ecfa726 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -66,7 +66,7 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) ########### -class PredictRequest(BaseModel): +class PredictBody(BaseModel): session_hash: Optional[str] example_id: Optional[int] data: List[Any] @@ -80,19 +80,19 @@ class FlagData(BaseModel): flag_index: Optional[int] -class FlagRequest(BaseModel): +class FlagBody(BaseModel): data: FlagData -class InterpretRequest(BaseModel): +class InterpretBody(BaseModel): data: List[Any] -class QueueStatusRequest(BaseModel): +class QueueStatusBody(BaseModel): hash: str -class QueuePushRequest(BaseModel): +class QueuePushBody(BaseModel): action: str data: Any @@ -248,7 +248,7 @@ def api_docs(request: Request): @app.post("/api/predict/", dependencies=[Depends(login_check)]) -async def predict(body: PredictRequest, username: str = Depends(get_current_user)): +async def predict(body: PredictBody, username: str = Depends(get_current_user)): if app.launchable.stateful: session_hash = body.session_hash state = app.state_holder.get( @@ -271,7 +271,7 @@ async def predict(body: PredictRequest, username: str = Depends(get_current_user @app.post("/api/flag/", dependencies=[Depends(login_check)]) -async def flag(body: FlagRequest, username: str = Depends(get_current_user)): +async def flag(body: FlagBody, username: str = Depends(get_current_user)): if app.launchable.analytics_enabled: await utils.log_feature_analytics(app.launchable.ip_address, "flag") await run_in_threadpool( @@ -287,7 +287,7 @@ async def flag(body: FlagRequest, username: str = Depends(get_current_user)): @app.post("/api/interpret/", dependencies=[Depends(login_check)]) -async def interpret(body: InterpretRequest): +async def interpret(body: InterpretBody): if app.launchable.analytics_enabled: await utils.log_feature_analytics(app.launchable.ip_address, "interpret") raw_input = body.data @@ -301,13 +301,13 @@ async def interpret(body: InterpretRequest): @app.post("/api/queue/push/", dependencies=[Depends(login_check)]) -async def queue_push(body: QueuePushRequest): +async def queue_push(body: QueuePushBody): job_hash, queue_position = queueing.push(body) return {"hash": job_hash, "queue_position": queue_position} @app.post("/api/queue/status/", dependencies=[Depends(login_check)]) -async def queue_status(body: QueueStatusRequest): +async def queue_status(body: QueueStatusBody): status, data = queueing.get_status(body.hash) return {"status": status, "data": data} diff --git a/test/test_queuing.py b/test/test_queuing.py index fcae876635..93aabc28bf 100644 --- a/test/test_queuing.py +++ b/test/test_queuing.py @@ -4,7 +4,7 @@ import os import unittest from gradio import queueing -from gradio.routes import QueuePushRequest +from gradio.routes import QueuePushBody os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -31,10 +31,10 @@ class TestQueuingActions(unittest.TestCase): queueing.close() def test_push_pop_status(self): - request = QueuePushRequest(data="test1", action="predict") + request = QueuePushBody(data="test1", action="predict") hash1, position = queueing.push(request) self.assertEquals(position, 0) - request = QueuePushRequest(data="test2", action="predict") + request = QueuePushBody(data="test2", action="predict") hash2, position = queueing.push(request) self.assertEquals(position, 1) status, position = queueing.get_status(hash2) @@ -46,7 +46,7 @@ class TestQueuingActions(unittest.TestCase): self.assertEquals(action, "predict") def test_jobs(self): - request = QueuePushRequest(data="test1", action="predict") + request = QueuePushBody(data="test1", action="predict") hash1, _ = queueing.push(request) hash2, position = queueing.push(request) self.assertEquals(position, 1) From befa7bc8a2ae37530338a526d40e58a060e5b248 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 7 Mar 2022 13:13:17 -0600 Subject: [PATCH 18/19] fixed speech text model --- test/test_external.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_external.py b/test/test_external.py index 6f081f68a0..98d80a22ab 100644 --- a/test/test_external.py +++ b/test/test_external.py @@ -214,7 +214,7 @@ class TestLoadInterface(unittest.TestCase): def test_speech_recognition_model(self): interface_info = gr.external.load_interface( - "models/jonatasgrosman/wav2vec2-large-xlsr-53-english" + "models/facebook/wav2vec2-base-960h" ) io = gr.Interface(**interface_info) io.api_mode = True From 86b6a5bffad42e86856edfba5c907714cbb11d9e Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 7 Mar 2022 13:25:28 -0600 Subject: [PATCH 19/19] formatting --- gradio/queueing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gradio/queueing.py b/gradio/queueing.py index ef5162b0f4..7d3bbf89c0 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -9,7 +9,6 @@ import requests from gradio.routes import QueuePushBody - DB_FILE = "gradio_queue.db" @@ -110,8 +109,8 @@ def pop() -> Tuple[int, str, Dict, str]: def push(body: QueuePushBody) -> Tuple[str, int]: - action = body.action - input_data = json.dumps({'data': body.data}) + action = body.action + input_data = json.dumps({"data": body.data}) hash = generate_hash() conn = sqlite3.connect(DB_FILE) c = conn.cursor()