Merge pull request #769 from gradio-app/new-state

Rewrite state to be backend-based
This commit is contained in:
Abubakar Abid 2022-03-07 13:37:49 -06:00 committed by GitHub
commit a457f6a446
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 128 additions and 36 deletions

View File

@ -73,6 +73,7 @@ class Blocks(Launchable, BlockContext):
self.theme = theme
self.requires_permissions = False # TODO: needs to be implemented
self.enable_queue = False
self.stateful = False # TODO: implement state
super().__init__()
Context.root_block = self

View File

@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
from gradio.outputs import State as o_State # type: ignore
from gradio.outputs import get_output_instance
from gradio.process_examples import load_from_cache, process_example
from gradio.routes import predict
from gradio.routes import PredictBody
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import flask
@ -176,6 +176,7 @@ class Interface(Launchable):
if repeat_outputs_per_model:
self.output_components *= len(fn)
self.stateful = False
if sum(isinstance(i, i_State) for i in self.input_components) > 1:
raise ValueError("Only one input component can be State.")
if sum(isinstance(o, o_State) for o in self.output_components) > 1:
@ -187,10 +188,24 @@ class Interface(Launchable):
state_param_index = [
isinstance(i, i_State) for i in self.input_components
].index(True)
self.stateful = True
self.state_param_index = state_param_index
state: i_State = self.input_components[state_param_index]
if state.default is None:
default = utils.get_default_args(fn[0])[state_param_index]
state.default = default
self.state_default = state.default
if sum(isinstance(i, o_State) for i in self.output_components) == 1:
state_return_index = [
isinstance(i, o_State) for i in self.output_components
].index(True)
self.state_return_index = state_return_index
else:
raise ValueError(
"For a stateful interface, there must be exactly one State"
" input component and one State output component."
)
if (
interpretation is None
@ -543,17 +558,19 @@ class Interface(Launchable):
else:
return predictions
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]:
flag_index = None
if data.get("example_id") is not None:
example_id = data["example_id"]
if data.example_id is not None:
if self.cache_examples:
prediction = load_from_cache(self, example_id)
prediction = load_from_cache(self, data.example_id)
durations = None
else:
prediction, durations = process_example(self, example_id)
prediction, durations = process_example(self, data.example_id)
else:
raw_input = data["data"]
raw_input = data.data
if self.stateful:
state = data.state
raw_input[self.state_param_index] = state
prediction, durations = self.process(raw_input)
if self.allow_flagging == "auto":
flag_index = self.flagging_callback.flag(
@ -563,12 +580,18 @@ class Interface(Launchable):
flag_option="" if self.flagging_options else None,
username=username,
)
if self.stateful:
updated_state = prediction[self.state_return_index]
prediction[self.state_return_index] = None
else:
updated_state = None
return {
"data": prediction,
"durations": durations,
"avg_durations": self.config.get("avg_durations"),
"flag_index": flag_index,
"updated_state": updated_state,
}
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:

View File

@ -7,6 +7,8 @@ from typing import Dict, Tuple
import requests
from gradio.routes import QueuePushBody
DB_FILE = "gradio_queue.db"
@ -106,8 +108,9 @@ def pop() -> Tuple[int, str, Dict, str]:
return result[0], result[1], json.loads(result[2]), result[3]
def push(input_data: Dict, action: str) -> Tuple[str, int]:
input_data = json.dumps(input_data)
def push(body: QueuePushBody) -> Tuple[str, int]:
action = body.action
input_data = json.dumps({"data": body.data})
hash = generate_hash()
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()

View File

@ -9,7 +9,7 @@ import posixpath
import secrets
import traceback
import urllib
from typing import Any, List, Optional, Type
from typing import Any, Dict, List, Optional, Tuple, Type
import orjson
import pkg_resources
@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.templating import Jinja2Templates
from jinja2.exceptions import TemplateNotFound
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from gradio import encryptor, queueing, utils
@ -54,9 +55,48 @@ app.add_middleware(
allow_headers=["*"],
)
state_holder: Dict[Tuple[str, str], Any] = {}
app.state_holder = state_holder
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
# Data Models
###########
class PredictBody(BaseModel):
    """Request body for the /api/predict/ endpoint."""

    # Per-browser-session identifier; used as part of the state_holder key
    # so each session gets its own server-side state.
    session_hash: Optional[str]
    # When set, the prediction is served from a cached/processed example
    # instead of running the function on `data`.
    example_id: Optional[int]
    # Raw input values, one entry per input component.
    data: List[Any]
    # Server-side state injected by the route for stateful interfaces;
    # not expected to be supplied by the client.
    state: Optional[Any]
class FlagData(BaseModel):
    """Payload describing one flagged prediction (nested inside FlagBody)."""

    # Input values, one entry per input component.
    input_data: List[Any]
    # Output values, one entry per output component.
    output_data: List[Any]
    # The flag option chosen by the user, if flagging_options are configured.
    flag_option: Optional[str]
    # Index of an existing flagged row to update, if any.
    flag_index: Optional[int]
class FlagBody(BaseModel):
    """Request body for the /api/flag/ endpoint."""

    # Wrapped payload; kept nested for backward compatibility with the
    # previous {"data": {...}} JSON shape.
    data: FlagData
class InterpretBody(BaseModel):
    """Request body for the /api/interpret/ endpoint."""

    # Raw input values to run interpretation on, one per input component.
    data: List[Any]
class QueueStatusBody(BaseModel):
    """Request body for the /api/queue/status/ endpoint."""

    # Job hash returned by a prior /api/queue/push/ call.
    hash: str
class QueuePushBody(BaseModel):
    """Request body for the /api/queue/push/ endpoint."""

    # The action to queue (e.g. "predict" or "interpret").
    action: str
    # The payload to run the action on once the job is popped.
    data: Any
###########
# Auth
###########
@ -208,10 +248,19 @@ def api_docs(request: Request):
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
if app.launchable.stateful:
session_hash = body.session_hash
state = app.state_holder.get(
(session_hash, "state"), app.launchable.state_default
)
body.state = state
try:
output = await run_in_threadpool(app.launchable.process_api, body, username)
if app.launchable.stateful:
updated_state = output.pop("updated_state")
app.state_holder[(session_hash, "state")] = updated_state
except BaseException as error:
if app.launchable.show_error:
traceback.print_exc()
@ -222,29 +271,26 @@ async def predict(request: Request, username: str = Depends(get_current_user)):
@app.post("/api/flag/", dependencies=[Depends(login_check)])
async def flag(request: Request, username: str = Depends(get_current_user)):
async def flag(body: FlagBody, username: str = Depends(get_current_user)):
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
body = await request.json()
data = body["data"]
await run_in_threadpool(
app.launchable.flagging_callback.flag,
app.launchable,
data["input_data"],
data["output_data"],
flag_option=data.get("flag_option"),
flag_index=data.get("flag_index"),
body.data.input_data,
body.data.output_data,
flag_option=body.data.flag_option,
flag_index=body.data.flag_index,
username=username,
)
return {"success": True}
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(request: Request):
async def interpret(body: InterpretBody):
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
body = await request.json()
raw_input = body["data"]
raw_input = body.data
interpretation_scores, alternative_outputs = await run_in_threadpool(
app.launchable.interpret, raw_input
)
@ -255,18 +301,14 @@ async def interpret(request: Request):
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
async def queue_push(request: Request):
body = await request.json()
action = body["action"]
job_hash, queue_position = queueing.push(body, action)
async def queue_push(body: QueuePushBody):
    """Enqueue a job and report its hash and position in the queue.

    The whole validated body is handed to the queueing layer, which
    persists the action and data and assigns the job a unique hash.
    """
    job_hash, position = queueing.push(body)
    return {"hash": job_hash, "queue_position": position}
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
async def queue_status(request: Request):
body = await request.json()
hash = body["hash"]
status, data = queueing.get_status(hash)
async def queue_status(body: QueueStatusBody):
    """Look up the current status and result data for a queued job."""
    job_status, job_data = queueing.get_status(body.hash)
    return {"status": job_status, "data": job_data}

View File

@ -214,7 +214,7 @@ class TestLoadInterface(unittest.TestCase):
def test_speech_recognition_model(self):
interface_info = gr.external.load_interface(
"models/jonatasgrosman/wav2vec2-large-xlsr-53-english"
"models/facebook/wav2vec2-base-960h"
)
io = gr.Interface(**interface_info)
io.api_mode = True

View File

@ -4,6 +4,7 @@ import os
import unittest
from gradio import queueing
from gradio.routes import QueuePushBody
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase):
queueing.close()
def test_push_pop_status(self):
hash1, position = queueing.push({"data": "test1"}, "predict")
request = QueuePushBody(data="test1", action="predict")
hash1, position = queueing.push(request)
self.assertEquals(position, 0)
hash2, position = queueing.push({"data": "test2"}, "predict")
request = QueuePushBody(data="test2", action="predict")
hash2, position = queueing.push(request)
self.assertEquals(position, 1)
status, position = queueing.get_status(hash2)
self.assertEquals(status, "QUEUED")
@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase):
self.assertEquals(action, "predict")
def test_jobs(self):
hash1, _ = queueing.push({"data": "test1"}, "predict")
hash2, position = queueing.push({"data": "test1"}, "predict")
request = QueuePushBody(data="test1", action="predict")
hash1, _ = queueing.push(request)
hash2, position = queueing.push(request)
self.assertEquals(position, 1)
queueing.start_job(hash1)

View File

@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase):
self.assertTrue("durations" in output)
self.assertTrue("avg_durations" in output)
def test_state(self):
    """A stateful Interface must persist state across /api/predict/ calls.

    The predict function accumulates every input into `history`, returning
    it both as the visible output and as the updated State value, so the
    second request should see the history produced by the first.
    """

    def predict(input, history=""):
        history += input
        return history, history

    io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
    app, _, _ = io.launch(prevent_thread_lock=True)
    client = TestClient(app)
    # First call: state starts from the default ("").
    response = client.post("/api/predict/", json={"data": ["test", None]})
    output = dict(response.json())
    # The State slot in the response is nulled out server-side, so only
    # the textbox output carries data.
    self.assertEqual(output["data"], ["test", None])
    # Second call: the server-held state from the first call is applied,
    # so the history now contains both inputs.
    # (Removed a leftover debug print of `output` that polluted test logs.)
    response = client.post("/api/predict/", json={"data": ["test", None]})
    output = dict(response.json())
    self.assertEqual(output["data"], ["testtest", None])
def test_queue_push_route(self):
queueing.push = mock.MagicMock(return_value=(None, None))
response = self.client.post(

View File

@ -14,12 +14,14 @@ let postData = async (url: string, body: unknown) => {
};
export const fn = async (
session_hash: string,
api_endpoint: string,
action: string,
data: Record<string, unknown>,
queue: boolean,
queue_callback: (pos: number | null, is_initial?: boolean) => void
) => {
data["session_hash"] = session_hash;
if (queue && ["predict", "interpret"].includes(action)) {
data["action"] = action;
const output = await postData(api_endpoint + "queue/push/", data);

View File

@ -96,7 +96,8 @@ window.launchGradio = (config: Config, element_query: string) => {
config.dark = true;
target.classList.add("dark");
}
config.fn = fn.bind(null, config.root + "api/");
let session_hash = Math.random().toString(36).substring(2);
config.fn = fn.bind(null, session_hash, config.root + "api/");
if (config.mode === "blocks") {
new Blocks({
target: target,