mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
Merge pull request #769 from gradio-app/new-state
Rewrite state to be backend-based
This commit is contained in:
commit
a457f6a446
@ -73,6 +73,7 @@ class Blocks(Launchable, BlockContext):
|
|||||||
self.theme = theme
|
self.theme = theme
|
||||||
self.requires_permissions = False # TODO: needs to be implemented
|
self.requires_permissions = False # TODO: needs to be implemented
|
||||||
self.enable_queue = False
|
self.enable_queue = False
|
||||||
|
self.stateful = False # TODO: implement state
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
Context.root_block = self
|
Context.root_block = self
|
||||||
|
@ -28,7 +28,7 @@ from gradio.outputs import OutputComponent
|
|||||||
from gradio.outputs import State as o_State # type: ignore
|
from gradio.outputs import State as o_State # type: ignore
|
||||||
from gradio.outputs import get_output_instance
|
from gradio.outputs import get_output_instance
|
||||||
from gradio.process_examples import load_from_cache, process_example
|
from gradio.process_examples import load_from_cache, process_example
|
||||||
from gradio.routes import predict
|
from gradio.routes import PredictBody
|
||||||
|
|
||||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||||
import flask
|
import flask
|
||||||
@ -176,6 +176,7 @@ class Interface(Launchable):
|
|||||||
if repeat_outputs_per_model:
|
if repeat_outputs_per_model:
|
||||||
self.output_components *= len(fn)
|
self.output_components *= len(fn)
|
||||||
|
|
||||||
|
self.stateful = False
|
||||||
if sum(isinstance(i, i_State) for i in self.input_components) > 1:
|
if sum(isinstance(i, i_State) for i in self.input_components) > 1:
|
||||||
raise ValueError("Only one input component can be State.")
|
raise ValueError("Only one input component can be State.")
|
||||||
if sum(isinstance(o, o_State) for o in self.output_components) > 1:
|
if sum(isinstance(o, o_State) for o in self.output_components) > 1:
|
||||||
@ -187,10 +188,24 @@ class Interface(Launchable):
|
|||||||
state_param_index = [
|
state_param_index = [
|
||||||
isinstance(i, i_State) for i in self.input_components
|
isinstance(i, i_State) for i in self.input_components
|
||||||
].index(True)
|
].index(True)
|
||||||
|
self.stateful = True
|
||||||
|
self.state_param_index = state_param_index
|
||||||
state: i_State = self.input_components[state_param_index]
|
state: i_State = self.input_components[state_param_index]
|
||||||
if state.default is None:
|
if state.default is None:
|
||||||
default = utils.get_default_args(fn[0])[state_param_index]
|
default = utils.get_default_args(fn[0])[state_param_index]
|
||||||
state.default = default
|
state.default = default
|
||||||
|
self.state_default = state.default
|
||||||
|
|
||||||
|
if sum(isinstance(i, o_State) for i in self.output_components) == 1:
|
||||||
|
state_return_index = [
|
||||||
|
isinstance(i, o_State) for i in self.output_components
|
||||||
|
].index(True)
|
||||||
|
self.state_return_index = state_return_index
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"For a stateful interface, there must be exactly one State"
|
||||||
|
" input component and one State output component."
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
interpretation is None
|
interpretation is None
|
||||||
@ -543,17 +558,19 @@ class Interface(Launchable):
|
|||||||
else:
|
else:
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
|
def process_api(self, data: PredictBody, username: str = None) -> Dict[str, Any]:
|
||||||
flag_index = None
|
flag_index = None
|
||||||
if data.get("example_id") is not None:
|
if data.example_id is not None:
|
||||||
example_id = data["example_id"]
|
|
||||||
if self.cache_examples:
|
if self.cache_examples:
|
||||||
prediction = load_from_cache(self, example_id)
|
prediction = load_from_cache(self, data.example_id)
|
||||||
durations = None
|
durations = None
|
||||||
else:
|
else:
|
||||||
prediction, durations = process_example(self, example_id)
|
prediction, durations = process_example(self, data.example_id)
|
||||||
else:
|
else:
|
||||||
raw_input = data["data"]
|
raw_input = data.data
|
||||||
|
if self.stateful:
|
||||||
|
state = data.state
|
||||||
|
raw_input[self.state_param_index] = state
|
||||||
prediction, durations = self.process(raw_input)
|
prediction, durations = self.process(raw_input)
|
||||||
if self.allow_flagging == "auto":
|
if self.allow_flagging == "auto":
|
||||||
flag_index = self.flagging_callback.flag(
|
flag_index = self.flagging_callback.flag(
|
||||||
@ -563,12 +580,18 @@ class Interface(Launchable):
|
|||||||
flag_option="" if self.flagging_options else None,
|
flag_option="" if self.flagging_options else None,
|
||||||
username=username,
|
username=username,
|
||||||
)
|
)
|
||||||
|
if self.stateful:
|
||||||
|
updated_state = prediction[self.state_return_index]
|
||||||
|
prediction[self.state_return_index] = None
|
||||||
|
else:
|
||||||
|
updated_state = None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"data": prediction,
|
"data": prediction,
|
||||||
"durations": durations,
|
"durations": durations,
|
||||||
"avg_durations": self.config.get("avg_durations"),
|
"avg_durations": self.config.get("avg_durations"),
|
||||||
"flag_index": flag_index,
|
"flag_index": flag_index,
|
||||||
|
"updated_state": updated_state,
|
||||||
}
|
}
|
||||||
|
|
||||||
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
||||||
|
@ -7,6 +7,8 @@ from typing import Dict, Tuple
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from gradio.routes import QueuePushBody
|
||||||
|
|
||||||
DB_FILE = "gradio_queue.db"
|
DB_FILE = "gradio_queue.db"
|
||||||
|
|
||||||
|
|
||||||
@ -106,8 +108,9 @@ def pop() -> Tuple[int, str, Dict, str]:
|
|||||||
return result[0], result[1], json.loads(result[2]), result[3]
|
return result[0], result[1], json.loads(result[2]), result[3]
|
||||||
|
|
||||||
|
|
||||||
def push(input_data: Dict, action: str) -> Tuple[str, int]:
|
def push(body: QueuePushBody) -> Tuple[str, int]:
|
||||||
input_data = json.dumps(input_data)
|
action = body.action
|
||||||
|
input_data = json.dumps({"data": body.data})
|
||||||
hash = generate_hash()
|
hash = generate_hash()
|
||||||
conn = sqlite3.connect(DB_FILE)
|
conn = sqlite3.connect(DB_FILE)
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
|
@ -9,7 +9,7 @@ import posixpath
|
|||||||
import secrets
|
import secrets
|
||||||
import traceback
|
import traceback
|
||||||
import urllib
|
import urllib
|
||||||
from typing import Any, List, Optional, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
@ -21,6 +21,7 @@ from fastapi.responses import FileResponse, HTMLResponse, JSONResponse
|
|||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from jinja2.exceptions import TemplateNotFound
|
from jinja2.exceptions import TemplateNotFound
|
||||||
|
from pydantic import BaseModel
|
||||||
from starlette.responses import RedirectResponse
|
from starlette.responses import RedirectResponse
|
||||||
|
|
||||||
from gradio import encryptor, queueing, utils
|
from gradio import encryptor, queueing, utils
|
||||||
@ -54,9 +55,48 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
state_holder: Dict[Tuple[str, str], Any] = {}
|
||||||
|
app.state_holder = state_holder
|
||||||
|
|
||||||
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||||
|
|
||||||
|
|
||||||
|
###########
|
||||||
|
# Data Models
|
||||||
|
###########
|
||||||
|
|
||||||
|
|
||||||
|
class PredictBody(BaseModel):
|
||||||
|
session_hash: Optional[str]
|
||||||
|
example_id: Optional[int]
|
||||||
|
data: List[Any]
|
||||||
|
state: Optional[Any]
|
||||||
|
|
||||||
|
|
||||||
|
class FlagData(BaseModel):
|
||||||
|
input_data: List[Any]
|
||||||
|
output_data: List[Any]
|
||||||
|
flag_option: Optional[str]
|
||||||
|
flag_index: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
|
class FlagBody(BaseModel):
|
||||||
|
data: FlagData
|
||||||
|
|
||||||
|
|
||||||
|
class InterpretBody(BaseModel):
|
||||||
|
data: List[Any]
|
||||||
|
|
||||||
|
|
||||||
|
class QueueStatusBody(BaseModel):
|
||||||
|
hash: str
|
||||||
|
|
||||||
|
|
||||||
|
class QueuePushBody(BaseModel):
|
||||||
|
action: str
|
||||||
|
data: Any
|
||||||
|
|
||||||
|
|
||||||
###########
|
###########
|
||||||
# Auth
|
# Auth
|
||||||
###########
|
###########
|
||||||
@ -208,10 +248,19 @@ def api_docs(request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||||
async def predict(request: Request, username: str = Depends(get_current_user)):
|
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
|
||||||
body = await request.json()
|
if app.launchable.stateful:
|
||||||
|
session_hash = body.session_hash
|
||||||
|
state = app.state_holder.get(
|
||||||
|
(session_hash, "state"), app.launchable.state_default
|
||||||
|
)
|
||||||
|
body.state = state
|
||||||
try:
|
try:
|
||||||
output = await run_in_threadpool(app.launchable.process_api, body, username)
|
output = await run_in_threadpool(app.launchable.process_api, body, username)
|
||||||
|
if app.launchable.stateful:
|
||||||
|
updated_state = output.pop("updated_state")
|
||||||
|
app.state_holder[(session_hash, "state")] = updated_state
|
||||||
|
|
||||||
except BaseException as error:
|
except BaseException as error:
|
||||||
if app.launchable.show_error:
|
if app.launchable.show_error:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -222,29 +271,26 @@ async def predict(request: Request, username: str = Depends(get_current_user)):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/api/flag/", dependencies=[Depends(login_check)])
|
@app.post("/api/flag/", dependencies=[Depends(login_check)])
|
||||||
async def flag(request: Request, username: str = Depends(get_current_user)):
|
async def flag(body: FlagBody, username: str = Depends(get_current_user)):
|
||||||
if app.launchable.analytics_enabled:
|
if app.launchable.analytics_enabled:
|
||||||
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
|
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
|
||||||
body = await request.json()
|
|
||||||
data = body["data"]
|
|
||||||
await run_in_threadpool(
|
await run_in_threadpool(
|
||||||
app.launchable.flagging_callback.flag,
|
app.launchable.flagging_callback.flag,
|
||||||
app.launchable,
|
app.launchable,
|
||||||
data["input_data"],
|
body.data.input_data,
|
||||||
data["output_data"],
|
body.data.output_data,
|
||||||
flag_option=data.get("flag_option"),
|
flag_option=body.data.flag_option,
|
||||||
flag_index=data.get("flag_index"),
|
flag_index=body.data.flag_index,
|
||||||
username=username,
|
username=username,
|
||||||
)
|
)
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
|
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
|
||||||
async def interpret(request: Request):
|
async def interpret(body: InterpretBody):
|
||||||
if app.launchable.analytics_enabled:
|
if app.launchable.analytics_enabled:
|
||||||
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
|
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
|
||||||
body = await request.json()
|
raw_input = body.data
|
||||||
raw_input = body["data"]
|
|
||||||
interpretation_scores, alternative_outputs = await run_in_threadpool(
|
interpretation_scores, alternative_outputs = await run_in_threadpool(
|
||||||
app.launchable.interpret, raw_input
|
app.launchable.interpret, raw_input
|
||||||
)
|
)
|
||||||
@ -255,18 +301,14 @@ async def interpret(request: Request):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
@app.post("/api/queue/push/", dependencies=[Depends(login_check)])
|
||||||
async def queue_push(request: Request):
|
async def queue_push(body: QueuePushBody):
|
||||||
body = await request.json()
|
job_hash, queue_position = queueing.push(body)
|
||||||
action = body["action"]
|
|
||||||
job_hash, queue_position = queueing.push(body, action)
|
|
||||||
return {"hash": job_hash, "queue_position": queue_position}
|
return {"hash": job_hash, "queue_position": queue_position}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
@app.post("/api/queue/status/", dependencies=[Depends(login_check)])
|
||||||
async def queue_status(request: Request):
|
async def queue_status(body: QueueStatusBody):
|
||||||
body = await request.json()
|
status, data = queueing.get_status(body.hash)
|
||||||
hash = body["hash"]
|
|
||||||
status, data = queueing.get_status(hash)
|
|
||||||
return {"status": status, "data": data}
|
return {"status": status, "data": data}
|
||||||
|
|
||||||
|
|
||||||
|
@ -214,7 +214,7 @@ class TestLoadInterface(unittest.TestCase):
|
|||||||
|
|
||||||
def test_speech_recognition_model(self):
|
def test_speech_recognition_model(self):
|
||||||
interface_info = gr.external.load_interface(
|
interface_info = gr.external.load_interface(
|
||||||
"models/jonatasgrosman/wav2vec2-large-xlsr-53-english"
|
"models/facebook/wav2vec2-base-960h"
|
||||||
)
|
)
|
||||||
io = gr.Interface(**interface_info)
|
io = gr.Interface(**interface_info)
|
||||||
io.api_mode = True
|
io.api_mode = True
|
||||||
|
@ -4,6 +4,7 @@ import os
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gradio import queueing
|
from gradio import queueing
|
||||||
|
from gradio.routes import QueuePushBody
|
||||||
|
|
||||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||||
|
|
||||||
@ -30,9 +31,11 @@ class TestQueuingActions(unittest.TestCase):
|
|||||||
queueing.close()
|
queueing.close()
|
||||||
|
|
||||||
def test_push_pop_status(self):
|
def test_push_pop_status(self):
|
||||||
hash1, position = queueing.push({"data": "test1"}, "predict")
|
request = QueuePushBody(data="test1", action="predict")
|
||||||
|
hash1, position = queueing.push(request)
|
||||||
self.assertEquals(position, 0)
|
self.assertEquals(position, 0)
|
||||||
hash2, position = queueing.push({"data": "test2"}, "predict")
|
request = QueuePushBody(data="test2", action="predict")
|
||||||
|
hash2, position = queueing.push(request)
|
||||||
self.assertEquals(position, 1)
|
self.assertEquals(position, 1)
|
||||||
status, position = queueing.get_status(hash2)
|
status, position = queueing.get_status(hash2)
|
||||||
self.assertEquals(status, "QUEUED")
|
self.assertEquals(status, "QUEUED")
|
||||||
@ -43,8 +46,9 @@ class TestQueuingActions(unittest.TestCase):
|
|||||||
self.assertEquals(action, "predict")
|
self.assertEquals(action, "predict")
|
||||||
|
|
||||||
def test_jobs(self):
|
def test_jobs(self):
|
||||||
hash1, _ = queueing.push({"data": "test1"}, "predict")
|
request = QueuePushBody(data="test1", action="predict")
|
||||||
hash2, position = queueing.push({"data": "test1"}, "predict")
|
hash1, _ = queueing.push(request)
|
||||||
|
hash2, position = queueing.push(request)
|
||||||
self.assertEquals(position, 1)
|
self.assertEquals(position, 1)
|
||||||
|
|
||||||
queueing.start_job(hash1)
|
queueing.start_job(hash1)
|
||||||
|
@ -44,6 +44,22 @@ class TestRoutes(unittest.TestCase):
|
|||||||
self.assertTrue("durations" in output)
|
self.assertTrue("durations" in output)
|
||||||
self.assertTrue("avg_durations" in output)
|
self.assertTrue("avg_durations" in output)
|
||||||
|
|
||||||
|
def test_state(self):
|
||||||
|
def predict(input, history=""):
|
||||||
|
history += input
|
||||||
|
return history, history
|
||||||
|
|
||||||
|
io = Interface(predict, ["textbox", "state"], ["textbox", "state"])
|
||||||
|
app, _, _ = io.launch(prevent_thread_lock=True)
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.post("/api/predict/", json={"data": ["test", None]})
|
||||||
|
output = dict(response.json())
|
||||||
|
print("output", output)
|
||||||
|
self.assertEqual(output["data"], ["test", None])
|
||||||
|
response = client.post("/api/predict/", json={"data": ["test", None]})
|
||||||
|
output = dict(response.json())
|
||||||
|
self.assertEqual(output["data"], ["testtest", None])
|
||||||
|
|
||||||
def test_queue_push_route(self):
|
def test_queue_push_route(self):
|
||||||
queueing.push = mock.MagicMock(return_value=(None, None))
|
queueing.push = mock.MagicMock(return_value=(None, None))
|
||||||
response = self.client.post(
|
response = self.client.post(
|
||||||
|
@ -14,12 +14,14 @@ let postData = async (url: string, body: unknown) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const fn = async (
|
export const fn = async (
|
||||||
|
session_hash: string,
|
||||||
api_endpoint: string,
|
api_endpoint: string,
|
||||||
action: string,
|
action: string,
|
||||||
data: Record<string, unknown>,
|
data: Record<string, unknown>,
|
||||||
queue: boolean,
|
queue: boolean,
|
||||||
queue_callback: (pos: number | null, is_initial?: boolean) => void
|
queue_callback: (pos: number | null, is_initial?: boolean) => void
|
||||||
) => {
|
) => {
|
||||||
|
data["session_hash"] = session_hash;
|
||||||
if (queue && ["predict", "interpret"].includes(action)) {
|
if (queue && ["predict", "interpret"].includes(action)) {
|
||||||
data["action"] = action;
|
data["action"] = action;
|
||||||
const output = await postData(api_endpoint + "queue/push/", data);
|
const output = await postData(api_endpoint + "queue/push/", data);
|
||||||
|
@ -96,7 +96,8 @@ window.launchGradio = (config: Config, element_query: string) => {
|
|||||||
config.dark = true;
|
config.dark = true;
|
||||||
target.classList.add("dark");
|
target.classList.add("dark");
|
||||||
}
|
}
|
||||||
config.fn = fn.bind(null, config.root + "api/");
|
let session_hash = Math.random().toString(36).substring(2);
|
||||||
|
config.fn = fn.bind(null, session_hash, config.root + "api/");
|
||||||
if (config.mode === "blocks") {
|
if (config.mode === "blocks") {
|
||||||
new Blocks({
|
new Blocks({
|
||||||
target: target,
|
target: target,
|
||||||
|
Loading…
Reference in New Issue
Block a user