Merge pull request #769 from gradio-app/new-state

Rewrite state to be backend-based
This commit is contained in:
Abubakar Abid 2022-03-07 13:37:49 -06:00 committed by GitHub
commit a457f6a446
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 128 additions and 36 deletions

View File

@ -73,6 +73,7 @@ class Blocks(Launchable, BlockContext):
self.theme = theme self.theme = theme
self.requires_permissions = False # TODO: needs to be implemented self.requires_permissions = False # TODO: needs to be implemented
self.enable_queue = False self.enable_queue = False
self.stateful = False # TODO: implement state
super().__init__() super().__init__()
Context.root_block = self Context.root_block = self

View File

@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
from gradio.outputs import State as o_State # type: ignore from gradio.outputs import State as o_State # type: ignore
from gradio.outputs import get_output_instance from gradio.outputs import get_output_instance
from gradio.process_examples import load_from_cache, process_example from gradio.process_examples import load_from_cache, process_example
from gradio.routes import predict from gradio.routes import PredictBody
if TYPE_CHECKING: # Only import for type checking (is False at runtime). if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import flask import flask
@ -176,6 +176,7 @@ class Interface(Launchable):
if repeat_outputs_per_model: if repeat_outputs_per_model:
self.output_components *= len(fn) self.output_components *= len(fn)
self.stateful = False
if sum(isinstance(i, i_State) for i in self.input_components) > 1: if sum(isinstance(i, i_State) for i in self.input_components) > 1:
raise ValueError("Only one input component can be State.") raise ValueError("Only one input component can be State.")
if sum(isinstance(o, o_State) for o in self.output_components) > 1: if sum(isinstance(o, o_State) for o in self.output_components) > 1:
@ -187,10 +188,24 @@ class Interface(Launchable):
state_param_index = [ state_param_index = [
isinstance(i, i_State) for i in self.input_components isinstance(i, i_State) for i in self.input_components
].index(True) ].index(True)
self.stateful = True
self.state_param_index = state_param_index
state: i_State = self.input_components[state_param_index] state: i_State = self.input_components[state_param_index]
if state.default is None: if state.default is None:
default = utils.get_default_args(fn[0])[state_param_index] default = utils.get_default_args(fn[0])[state_param_index]
state.default = default state.default = default
self.state_default = state.default
if sum(isinstance(i, o_State) for i in self.output_components) == 1:
state_return_index = [
isinstance(i, o_State) for i in self.output_components
].index(True)
self.state_return_index = state_return_index
else:
raise ValueError(
"For a stateful interface, there must be exactly one State"
" input component and one State output component."
)
if ( if (
interpretation is None interpretation is None
@ -543,17 +558,19 @@ class Interface(Launchable):
else: else:
return predictions return predictions
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]: def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]:
flag_index = None flag_index = None
if data.get("example_id") is not None: if data.example_id is not None:
example_id = data["example_id"]
if self.cache_examples: if self.cache_examples:
prediction = load_from_cache(self, example_id) prediction = load_from_cache(self, data.example_id)
durations = None durations = None
else: else:
prediction, durations = process_example(self, example_id) prediction, durations = process_example(self, data.example_id)
else: else:
raw_input = data["data"] raw_input = data.data
if self.stateful:
state = data.state
raw_input[self.state_param_index] = state
prediction, durations = self.process(raw_input) prediction, durations = self.process(raw_input)
if self.allow_flagging == "auto": if self.allow_flagging == "auto":
flag_index = self.flagging_callback.flag( flag_index = self.flagging_callback.flag(
@ -563,12 +580,18 @@ class Interface(Launchable):
flag_option="" if self.flagging_options else None, flag_option="" if self.flagging_options else None,
username=username, username=username,
) )
if self.stateful:
updated_state = prediction[self.state_return_index]
prediction[self.state_return_index] = None
else:
updated_state = None
return { return {
"data": prediction, "data": prediction,
"durations": durations, "durations": durations,
"avg_durations": self.config.get("avg_durations"), "avg_durations": self.config.get("avg_durations"),
"flag_index": flag_index, "flag_index": flag_index,
"updated_state": updated_state,
} }
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]: def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:

View File

@ -7,6 +7,8 @@ from typing import Dict, Tuple
import requests import requests
from gradio.routes import QueuePushBody
DB_FILE = "gradio_queue.db" DB_FILE = "gradio_queue.db"
@ -106,8 +108,9 @@ def pop() -> Tuple[int, str, Dict, str]:
return result[0], result[1], json.loads(result[2]), result[3] return result[0], result[1], json.loads(result[2]), result[3]
def push(input_data: Dict, action: str) -> Tuple[str, int]: def push(body: QueuePushBody) -> Tuple[str, int]:
input_data = json.dumps(input_data) action = body.action
input_data = json.dumps({"data": body.data})
hash = generate_hash() hash = generate_hash()
conn = sqlite3.connect(DB_FILE) conn = sqlite3.connect(DB_FILE)
c = conn.cursor() c = conn.cursor()

View File

@ -9,7 +9,7 @@ import posixpath
import secrets import secrets
import traceback import traceback
import urllib import urllib
from typing import Any, List, Optional, Type from typing import Any, Dict, List, Optional, Tuple, Type
import orjson import orjson
import pkg_resources import pkg_resources
@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound from jinja2.exceptions import TemplateNotFound
from pydantic import BaseModel
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from gradio import encryptor, queueing, utils from gradio import encryptor, queueing, utils
@ -54,9 +55,48 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
state_holder: Dict[Tuple[str, str], Any] = {}
app.state_holder = state_holder
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB) templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
# Data Models
###########
class PredictBody(BaseModel):
session_hash: Optional[str]
example_id: Optional[int]
data: List[Any]
state: Optional[Any]
class FlagData(BaseModel):
input_data: List[Any]
output_data: List[Any]
flag_option: Optional[str]
flag_index: Optional[int]
class FlagBody(BaseModel):
data: FlagData
class InterpretBody(BaseModel):
data: List[Any]
class QueueStatusBody(BaseModel):
hash: str
class QueuePushBody(BaseModel):
action: str
data: Any
########### ###########
# Auth # Auth
########### ###########
@ -208,10 +248,19 @@ def api_docs(request: Request):
@app.post("/api/predict/", dependencies=[Depends(login_check)]) @app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(request: Request, username: str = Depends(get_current_user)): async def predict(body: PredictBody, username: str = Depends(get_current_user)):
body = await request.json() if app.launchable.stateful:
session_hash = body.session_hash
state = app.state_holder.get(
(session_hash, "state"), app.launchable.state_default
)
body.state = state
try: try:
output = await run_in_threadpool(app.launchable.process_api, body, username) output = await run_in_threadpool(app.launchable.process_api, body, username)
if app.launchable.stateful:
updated_state = output.pop("updated_state")
app.state_holder[(session_hash, "state")] = updated_state
except BaseException as error: except BaseException as error:
if app.launchable.show_error: if app.launchable.show_error:
traceback.print_exc() traceback.print_exc()
@ -222,29 +271,26 @@ async def predict(request: Request, username: str = Depends(get_current_user)):
@app.post("/api/flag/", dependencies=[Depends(login_check)]) @app.post("/api/flag/", dependencies=[Depends(login_check)])
async def flag(request: Request, username: str = Depends(get_current_user)): async def flag(body: FlagBody, username: str = Depends(get_current_user)):
if app.launchable.analytics_enabled: if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "flag") await utils.log_feature_analytics(app.launchable.ip_address, "flag")
body = await request.json()
data = body["data"]
await run_in_threadpool( await run_in_threadpool(
app.launchable.flagging_callback.flag, app.launchable.flagging_callback.flag,
app.launchable, app.launchable,
data["input_data"], body.data.input_data,
data["output_data"], body.data.output_data,
flag_option=data.get("flag_option"), flag_option=body.data.flag_option,
flag_index=data.get("flag_index"), flag_index=body.data.flag_index,
username=username, username=username,
) )
return {"success": True} return {"success": True}
@app.post("/api/interpret/", dependencies=[Depends(login_check)]) @app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(request: Request): async def interpret(body: InterpretBody):
if app.launchable.analytics_enabled: if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "interpret") await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
body = await request.json() raw_input = body.data
raw_input = body["data"]
interpretation_scores, alternative_outputs = await run_in_threadpool( interpretation_scores, alternative_outputs = await run_in_threadpool(
app.launchable.interpret, raw_input app.launchable.interpret, raw_input
) )
@ -255,18 +301,14 @@ async def interpret(request: Request):
@app.post("/api/queue/push/", dependencies=[Depends(login_check)]) @app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(request: Request): async def queue_push(body: QueuePushBody):
body = await request.json() job_hash, queue_position = queueing.push(body)
action = body["action"]
job_hash, queue_position = queueing.push(body, action)
return {"hash": job_hash, "queue_position": queue_position} return {"hash": job_hash, "queue_position": queue_position}
@app.post("/api/queue/status/", dependencies=[Depends(login_check)]) @app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(request: Request): async def queue_status(body: QueueStatusBody):
body = await request.json() status, data = queueing.get_status(body.hash)
hash = body["hash"]
status, data = queueing.get_status(hash)
return {"status": status, "data": data} return {"status": status, "data": data}

View File

@ -214,7 +214,7 @@ class TestLoadInterface(unittest.TestCase):
def test_speech_recognition_model(self): def test_speech_recognition_model(self):
interface_info = gr.external.load_interface( interface_info = gr.external.load_interface(
"models/jonatasgrosman/wav2vec2-large-xlsr-53-english" "models/facebook/wav2vec2-base-960h"
) )
io = gr.Interface(**interface_info) io = gr.Interface(**interface_info)
io.api_mode = True io.api_mode = True

View File

@ -4,6 +4,7 @@ import os
import unittest import unittest
from gradio import queueing from gradio import queueing
from gradio.routes import QueuePushBody
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase):
queueing.close() queueing.close()
def test_push_pop_status(self): def test_push_pop_status(self):
hash1, position = queueing.push({"data": "test1"}, "predict") request = QueuePushBody(data="test1", action="predict")
hash1, position = queueing.push(request)
self.assertEquals(position, 0) self.assertEquals(position, 0)
hash2, position = queueing.push({"data": "test2"}, "predict") request = QueuePushBody(data="test2", action="predict")
hash2, position = queueing.push(request)
self.assertEquals(position, 1) self.assertEquals(position, 1)
status, position = queueing.get_status(hash2) status, position = queueing.get_status(hash2)
self.assertEquals(status, "QUEUED") self.assertEquals(status, "QUEUED")
@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase):
self.assertEquals(action, "predict") self.assertEquals(action, "predict")
def test_jobs(self): def test_jobs(self):
hash1, _ = queueing.push({"data": "test1"}, "predict") request = QueuePushBody(data="test1", action="predict")
hash2, position = queueing.push({"data": "test1"}, "predict") hash1, _ = queueing.push(request)
hash2, position = queueing.push(request)
self.assertEquals(position, 1) self.assertEquals(position, 1)
queueing.start_job(hash1) queueing.start_job(hash1)

View File

@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase):
self.assertTrue("durations" in output) self.assertTrue("durations" in output)
self.assertTrue("avg_durations" in output) self.assertTrue("avg_durations" in output)
def test_state(self):
def predict(input, history=""):
history += input
return history, history
io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": ["test", None]})
output = dict(response.json())
print("output", output)
self.assertEqual(output["data"], ["test", None])
response = client.post("/api/predict/", json={"data": ["test", None]})
output = dict(response.json())
self.assertEqual(output["data"], ["testtest", None])
def test_queue_push_route(self): def test_queue_push_route(self):
queueing.push = mock.MagicMock(return_value=(None, None)) queueing.push = mock.MagicMock(return_value=(None, None))
response = self.client.post( response = self.client.post(

View File

@ -14,12 +14,14 @@ let postData = async (url: string, body: unknown) => {
}; };
export const fn = async ( export const fn = async (
session_hash: string,
api_endpoint: string, api_endpoint: string,
action: string, action: string,
data: Record<string, unknown>, data: Record<string, unknown>,
queue: boolean, queue: boolean,
queue_callback: (pos: number | null, is_initial?: boolean) => void queue_callback: (pos: number | null, is_initial?: boolean) => void
) => { ) => {
data["session_hash"] = session_hash;
if (queue && ["predict", "interpret"].includes(action)) { if (queue && ["predict", "interpret"].includes(action)) {
data["action"] = action; data["action"] = action;
const output = await postData(api_endpoint + "queue/push/", data); const output = await postData(api_endpoint + "queue/push/", data);

View File

@ -96,7 +96,8 @@ window.launchGradio = (config: Config, element_query: string) => {
config.dark = true; config.dark = true;
target.classList.add("dark"); target.classList.add("dark");
} }
config.fn = fn.bind(null, config.root + "api/"); let session_hash = Math.random().toString(36).substring(2);
config.fn = fn.bind(null, session_hash, config.root + "api/");
if (config.mode === "blocks") { if (config.mode === "blocks") {
new Blocks({ new Blocks({
target: target, target: target,