Merge pull request #769 from gradio-app/new-state
Rewrite state to be backend-based

Commit a457f6a446
@@ -73,6 +73,7 @@ class Blocks(Launchable, BlockContext):
         self.theme = theme
         self.requires_permissions = False  # TODO: needs to be implemented
         self.enable_queue = False
+        self.stateful = False  # TODO: implement state
 
         super().__init__()
         Context.root_block = self
@@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
 from gradio.outputs import State as o_State  # type: ignore
 from gradio.outputs import get_output_instance
 from gradio.process_examples import load_from_cache, process_example
-from gradio.routes import predict
+from gradio.routes import PredictBody
 
 if TYPE_CHECKING:  # Only import for type checking (is False at runtime).
     import flask
@@ -176,6 +176,7 @@ class Interface(Launchable):
         if repeat_outputs_per_model:
             self.output_components *= len(fn)
 
+        self.stateful = False
         if sum(isinstance(i, i_State) for i in self.input_components) > 1:
             raise ValueError("Only one input component can be State.")
         if sum(isinstance(o, o_State) for o in self.output_components) > 1:
@@ -187,10 +188,24 @@ class Interface(Launchable):
             state_param_index = [
                 isinstance(i, i_State) for i in self.input_components
             ].index(True)
+            self.stateful = True
+            self.state_param_index = state_param_index
             state: i_State = self.input_components[state_param_index]
             if state.default is None:
                 default = utils.get_default_args(fn[0])[state_param_index]
                 state.default = default
+            self.state_default = state.default
+
+            if sum(isinstance(i, o_State) for i in self.output_components) == 1:
+                state_return_index = [
+                    isinstance(i, o_State) for i in self.output_components
+                ].index(True)
+                self.state_return_index = state_return_index
+            else:
+                raise ValueError(
+                    "For a stateful interface, there must be exactly one State"
+                    " input component and one State output component."
+                )
 
         if (
             interpretation is None
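Note: the validation above enforces exactly the pairing exercised by the new test_state test further down — one State input matched by one State output, with the state function argument's default seeding state_default. A minimal sketch of the usage it expects, using gradio's string shortcuts:

    import gradio as gr

    def chat(message, history=""):
        # the default of the state argument ("") becomes state_default
        history += message
        return history, history  # second value feeds the State output

    io = gr.Interface(chat, ["textbox", "state"], ["textbox", "state"])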
@@ -543,17 +558,19 @@ class Interface(Launchable):
         else:
             return predictions
 
-    def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
+    def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]:
         flag_index = None
-        if data.get("example_id") is not None:
-            example_id = data["example_id"]
+        if data.example_id is not None:
             if self.cache_examples:
-                prediction = load_from_cache(self, example_id)
+                prediction = load_from_cache(self, data.example_id)
                 durations = None
             else:
-                prediction, durations = process_example(self, example_id)
+                prediction, durations = process_example(self, data.example_id)
         else:
-            raw_input = data["data"]
+            raw_input = data.data
+            if self.stateful:
+                state = data.state
+                raw_input[self.state_param_index] = state
             prediction, durations = self.process(raw_input)
             if self.allow_flagging == "auto":
                 flag_index = self.flagging_callback.flag(
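Note: with the change above, the client keeps a placeholder in the state slot of the data list; the server splices its held state into that slot before processing. A sketch, assuming state_param_index == 1:

    raw_input = ["hi", None]        # as received from the client; state slot is a placeholder
    raw_input[1] = "prior history"  # server-held state spliced in before self.process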
@@ -563,12 +580,18 @@ class Interface(Launchable):
                     flag_option="" if self.flagging_options else None,
                     username=username,
                 )
+        if self.stateful:
+            updated_state = prediction[self.state_return_index]
+            prediction[self.state_return_index] = None
+        else:
+            updated_state = None
 
         return {
             "data": prediction,
             "durations": durations,
             "avg_durations": self.config.get("avg_durations"),
             "flag_index": flag_index,
+            "updated_state": updated_state,
         }
 
     def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
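Note: the dict returned by process_api for a stateful interface now carries the new state out-of-band in "updated_state". Hypothetical values for illustration:

    {
        "data": ["testtest", None],   # State output slot nulled before serialization
        "durations": [0.01],
        "avg_durations": [0.02],
        "flag_index": None,
        "updated_state": "testtest",  # popped by the /api/predict/ route, never sent to the client
    }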
@@ -7,6 +7,8 @@ from typing import Dict, Tuple
 
 import requests
 
+from gradio.routes import QueuePushBody
+
 DB_FILE = "gradio_queue.db"
 
 
@@ -106,8 +108,9 @@ def pop() -> Tuple[int, str, Dict, str]:
     return result[0], result[1], json.loads(result[2]), result[3]
 
 
-def push(input_data: Dict, action: str) -> Tuple[str, int]:
-    input_data = json.dumps(input_data)
+def push(body: QueuePushBody) -> Tuple[str, int]:
+    action = body.action
+    input_data = json.dumps({"data": body.data})
     hash = generate_hash()
     conn = sqlite3.connect(DB_FILE)
     c = conn.cursor()
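Note: callers now construct the pydantic model instead of passing a raw dict plus an action string, mirroring the updated tests below:

    from gradio import queueing
    from gradio.routes import QueuePushBody

    job_hash, position = queueing.push(QueuePushBody(action="predict", data="hello"))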
@@ -9,7 +9,7 @@ import posixpath
 import secrets
 import traceback
 import urllib
-from typing import Any, List, Optional, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import orjson
 import pkg_resources
@@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
 from fastapi.security import OAuth2PasswordRequestForm
 from fastapi.templating import Jinja2Templates
 from jinja2.exceptions import TemplateNotFound
+from pydantic import BaseModel
 from starlette.responses import RedirectResponse
 
 from gradio import encryptor, queueing, utils
@@ -54,9 +55,48 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+state_holder: Dict[Tuple[str, str], Any] = {}
+app.state_holder = state_holder
+
 templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
 
 
+###########
+# Data Models
+###########
+
+
+class PredictBody(BaseModel):
+    session_hash: Optional[str]
+    example_id: Optional[int]
+    data: List[Any]
+    state: Optional[Any]
+
+
+class FlagData(BaseModel):
+    input_data: List[Any]
+    output_data: List[Any]
+    flag_option: Optional[str]
+    flag_index: Optional[int]
+
+
+class FlagBody(BaseModel):
+    data: FlagData
+
+
+class InterpretBody(BaseModel):
+    data: List[Any]
+
+
+class QueueStatusBody(BaseModel):
+    hash: str
+
+
+class QueuePushBody(BaseModel):
+    action: str
+    data: Any
+
+
 ###########
 # Auth
 ###########
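Note: state_holder maps (session_hash, "state") tuples to arbitrary state values, and is exposed as app.state_holder so the routes and tests reach the same dict. A sketch of the keying scheme, with an illustrative hash:

    state_holder[("abc123", "state")] = "accumulated history"
    state_holder.get(("abc123", "state"), app.launchable.state_default)  # falls back to the Interface default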
@@ -208,10 +248,19 @@ def api_docs(request: Request):
 
 
 @app.post("/api/predict/", dependencies=[Depends(login_check)])
-async def predict(request: Request, username: str = Depends(get_current_user)):
-    body = await request.json()
+async def predict(body: PredictBody, username: str = Depends(get_current_user)):
+    if app.launchable.stateful:
+        session_hash = body.session_hash
+        state = app.state_holder.get(
+            (session_hash, "state"), app.launchable.state_default
+        )
+        body.state = state
     try:
         output = await run_in_threadpool(app.launchable.process_api, body, username)
+        if app.launchable.stateful:
+            updated_state = output.pop("updated_state")
+            app.state_holder[(session_hash, "state")] = updated_state
+
     except BaseException as error:
         if app.launchable.show_error:
             traceback.print_exc()
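Note: two requests that share a session_hash now share state across round trips. An illustrative client-side sketch (URL and hash are placeholders):

    import requests

    payload = {"session_hash": "abc123", "data": ["test", None]}
    r1 = requests.post("http://localhost:7860/api/predict/", json=payload).json()
    r2 = requests.post("http://localhost:7860/api/predict/", json=payload).json()
    # with the stateful echo demo sketched earlier, r2["data"][0] == "testtest"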
@@ -222,29 +271,26 @@ async def predict(request: Request, username: str = Depends(get_current_user)):
 
 
 @app.post("/api/flag/", dependencies=[Depends(login_check)])
-async def flag(request: Request, username: str = Depends(get_current_user)):
+async def flag(body: FlagBody, username: str = Depends(get_current_user)):
     if app.launchable.analytics_enabled:
         await utils.log_feature_analytics(app.launchable.ip_address, "flag")
-    body = await request.json()
-    data = body["data"]
     await run_in_threadpool(
         app.launchable.flagging_callback.flag,
         app.launchable,
-        data["input_data"],
-        data["output_data"],
-        flag_option=data.get("flag_option"),
-        flag_index=data.get("flag_index"),
+        body.data.input_data,
+        body.data.output_data,
+        flag_option=body.data.flag_option,
+        flag_index=body.data.flag_index,
         username=username,
     )
     return {"success": True}
 
 
 @app.post("/api/interpret/", dependencies=[Depends(login_check)])
-async def interpret(request: Request):
+async def interpret(body: InterpretBody):
     if app.launchable.analytics_enabled:
         await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
-    body = await request.json()
-    raw_input = body["data"]
+    raw_input = body.data
     interpretation_scores, alternative_outputs = await run_in_threadpool(
         app.launchable.interpret, raw_input
     )
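Note: FastAPI now validates the flag payload against FlagBody, so requests must nest their fields under "data". A hypothetical payload, written as a Python dict:

    {
        "data": {
            "input_data": ["hello"],
            "output_data": ["HELLO"],
            "flag_option": None,
            "flag_index": None,
        }
    }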
@@ -255,18 +301,14 @@ async def interpret(request: Request):
 
 
 @app.post("/api/queue/push/", dependencies=[Depends(login_check)])
-async def queue_push(request: Request):
-    body = await request.json()
-    action = body["action"]
-    job_hash, queue_position = queueing.push(body, action)
+async def queue_push(body: QueuePushBody):
+    job_hash, queue_position = queueing.push(body)
     return {"hash": job_hash, "queue_position": queue_position}
 
 
 @app.post("/api/queue/status/", dependencies=[Depends(login_check)])
-async def queue_status(request: Request):
-    body = await request.json()
-    hash = body["hash"]
-    status, data = queueing.get_status(hash)
+async def queue_status(body: QueueStatusBody):
+    status, data = queueing.get_status(body.hash)
     return {"status": status, "data": data}
 
 
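Note: the queue endpoints get the same pydantic treatment. A sketch using FastAPI's TestClient, with payloads matching QueuePushBody and QueueStatusBody:

    from starlette.testclient import TestClient

    client = TestClient(app)  # app as returned by io.launch(prevent_thread_lock=True)
    job = client.post("/api/queue/push/", json={"action": "predict", "data": "hello"}).json()
    status = client.post("/api/queue/status/", json={"hash": job["hash"]}).json()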
@@ -214,7 +214,7 @@ class TestLoadInterface(unittest.TestCase):
 
     def test_speech_recognition_model(self):
         interface_info = gr.external.load_interface(
-            "models/jonatasgrosman/wav2vec2-large-xlsr-53-english"
+            "models/facebook/wav2vec2-base-960h"
         )
         io = gr.Interface(**interface_info)
         io.api_mode = True
@@ -4,6 +4,7 @@ import os
 import unittest
 
 from gradio import queueing
+from gradio.routes import QueuePushBody
 
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 
@@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase):
         queueing.close()
 
     def test_push_pop_status(self):
-        hash1, position = queueing.push({"data": "test1"}, "predict")
+        request = QueuePushBody(data="test1", action="predict")
+        hash1, position = queueing.push(request)
         self.assertEquals(position, 0)
-        hash2, position = queueing.push({"data": "test2"}, "predict")
+        request = QueuePushBody(data="test2", action="predict")
+        hash2, position = queueing.push(request)
         self.assertEquals(position, 1)
         status, position = queueing.get_status(hash2)
         self.assertEquals(status, "QUEUED")
@@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase):
         self.assertEquals(action, "predict")
 
     def test_jobs(self):
-        hash1, _ = queueing.push({"data": "test1"}, "predict")
-        hash2, position = queueing.push({"data": "test1"}, "predict")
+        request = QueuePushBody(data="test1", action="predict")
+        hash1, _ = queueing.push(request)
+        hash2, position = queueing.push(request)
         self.assertEquals(position, 1)
 
         queueing.start_job(hash1)
@@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase):
         self.assertTrue("durations" in output)
         self.assertTrue("avg_durations" in output)
 
+    def test_state(self):
+        def predict(input, history=""):
+            history += input
+            return history, history
+
+        io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
+        app, _, _ = io.launch(prevent_thread_lock=True)
+        client = TestClient(app)
+        response = client.post("/api/predict/", json={"data": ["test", None]})
+        output = dict(response.json())
+        print("output", output)
+        self.assertEqual(output["data"], ["test", None])
+        response = client.post("/api/predict/", json={"data": ["test", None]})
+        output = dict(response.json())
+        self.assertEqual(output["data"], ["testtest", None])
+
     def test_queue_push_route(self):
         queueing.push = mock.MagicMock(return_value=(None, None))
         response = self.client.post(
@@ -14,12 +14,14 @@ let postData = async (url: string, body: unknown) => {
 };
 
 export const fn = async (
+	session_hash: string,
 	api_endpoint: string,
 	action: string,
 	data: Record<string, unknown>,
 	queue: boolean,
 	queue_callback: (pos: number | null, is_initial?: boolean) => void
 ) => {
+	data["session_hash"] = session_hash;
 	if (queue && ["predict", "interpret"].includes(action)) {
 		data["action"] = action;
 		const output = await postData(api_endpoint + "queue/push/", data);
@@ -96,7 +96,8 @@ window.launchGradio = (config: Config, element_query: string) => {
 		config.dark = true;
 		target.classList.add("dark");
 	}
-	config.fn = fn.bind(null, config.root + "api/");
+	let session_hash = Math.random().toString(36).substring(2);
+	config.fn = fn.bind(null, session_hash, config.root + "api/");
 	if (config.mode === "blocks") {
 		new Blocks({
 			target: target,
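Note: generating the session hash at page load means each browser tab gets its own key into state_holder, so two tabs of the same demo hold independent state, and a reload starts a fresh session from state_default.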