mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-23 11:39:17 +08:00
Flagging fixes (#1081)
* only show flagging button if manual * fixing flagging * fixed flagging examples issue * formatting * cleanup * fixed tests * predictbody * formatting * fixed tests
This commit is contained in:
parent
cee698bd43
commit
82e95e259c
@ -10,12 +10,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from gradio import encryptor, networking, queueing, strings, utils
|
||||
from gradio.context import Context
|
||||
from gradio.routes import PredictBody
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from fastapi.applications import FastAPI
|
||||
|
||||
from gradio.components import Component, StatusTracker
|
||||
from gradio.routes import PredictBody
|
||||
|
||||
|
||||
class Block:
|
||||
@ -240,7 +240,7 @@ class Blocks(BlockContext):
|
||||
|
||||
def process_api(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
data: PredictBody,
|
||||
username: str = None,
|
||||
state: Optional[Dict[int, any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
@ -252,8 +252,8 @@ class Blocks(BlockContext):
|
||||
state: data stored from stateful components for session
|
||||
Returns: None
|
||||
"""
|
||||
raw_input = data["data"]
|
||||
fn_index = data["fn_index"]
|
||||
raw_input = data.data
|
||||
fn_index = data.fn_index
|
||||
block_fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
|
@ -82,20 +82,10 @@ class Component(Block):
|
||||
"""
|
||||
return data
|
||||
|
||||
def save_flagged_file(
|
||||
self,
|
||||
dir: str,
|
||||
label: str,
|
||||
data: Any,
|
||||
encryption_key: bool,
|
||||
file_path: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
def save_file(self, file: tempfile._TemporaryFileWrapper, dir: str, label: str):
|
||||
"""
|
||||
Saved flagged data (e.g. image or audio) as a file and returns filepath
|
||||
Saved flagged file and returns filepath
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
file = processing_utils.decode_base64_to_file(data, encryption_key, file_path)
|
||||
label = "".join([char for char in label if char.isalnum() or char in "._- "])
|
||||
old_file_name = file.name
|
||||
output_dir = os.path.join(dir, label)
|
||||
@ -112,6 +102,22 @@ class Component(Block):
|
||||
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
|
||||
return label + "/" + new_file_name
|
||||
|
||||
def save_flagged_file(
|
||||
self,
|
||||
dir: str,
|
||||
label: str,
|
||||
data: Any,
|
||||
encryption_key: bool,
|
||||
file_path: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Saved flagged data (e.g. image or audio) as a file and returns filepath
|
||||
"""
|
||||
if data is None:
|
||||
return None
|
||||
file = processing_utils.decode_base64_to_file(data, encryption_key, file_path)
|
||||
return self.save_file(file, dir, label)
|
||||
|
||||
def restore_flagged_file(
|
||||
self,
|
||||
dir: str,
|
||||
@ -1809,9 +1815,18 @@ class Audio(Component):
|
||||
"""
|
||||
Returns: (str) path to audio file
|
||||
"""
|
||||
return self.save_flagged_file(
|
||||
dir, label, None if data is None else data["data"], encryption_key
|
||||
)
|
||||
if data is None:
|
||||
data_string = None
|
||||
elif isinstance(data, str):
|
||||
data_string = data
|
||||
else:
|
||||
data_string = data["data"]
|
||||
is_example = data.get("is_example", False)
|
||||
if is_example:
|
||||
file_obj = processing_utils.create_tmp_copy_of_file(data["name"])
|
||||
return self.save_file(file_obj, dir, label)
|
||||
|
||||
return self.save_flagged_file(dir, label, data_string, encryption_key)
|
||||
|
||||
def generate_sample(self):
|
||||
return deepcopy(media_data.BASE64_AUDIO)
|
||||
|
@ -550,7 +550,15 @@ class Interface(Blocks):
|
||||
for component in self.output_components:
|
||||
component.render()
|
||||
with Row():
|
||||
flag_btn = Button("Flag")
|
||||
if self.allow_flagging == "manual":
|
||||
flag_btn = Button("Flag")
|
||||
flag_btn._click_no_preprocess(
|
||||
lambda *flag_data: self.flagging_callback.flag(
|
||||
flag_data
|
||||
),
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=[],
|
||||
)
|
||||
if self.interpretation:
|
||||
interpretation_btn = Button("Interpret")
|
||||
submit_fn = (
|
||||
@ -617,11 +625,6 @@ class Interface(Blocks):
|
||||
+ (self.output_components if self.cache_examples else []),
|
||||
)
|
||||
|
||||
flag_btn._click_no_preprocess(
|
||||
lambda *flag_data: self.flagging_callback.flag(flag_data),
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=[],
|
||||
)
|
||||
if self.interpretation:
|
||||
interpretation_btn._click_no_preprocess(
|
||||
lambda *data: self.interpret(data) + [False, True],
|
||||
|
@ -56,30 +56,6 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
|
||||
###########
|
||||
|
||||
|
||||
class PredictBody(BaseModel):
|
||||
session_hash: Optional[str]
|
||||
example_id: Optional[int]
|
||||
data: List[Any]
|
||||
state: Optional[Any]
|
||||
fn_index: Optional[int]
|
||||
cleared: Optional[bool]
|
||||
|
||||
|
||||
class FlagData(BaseModel):
|
||||
input_data: List[Any]
|
||||
output_data: List[Any]
|
||||
flag_option: Optional[str]
|
||||
flag_index: Optional[int]
|
||||
|
||||
|
||||
class FlagBody(BaseModel):
|
||||
data: FlagData
|
||||
|
||||
|
||||
class InterpretBody(BaseModel):
|
||||
data: List[Any]
|
||||
|
||||
|
||||
class QueueStatusBody(BaseModel):
|
||||
hash: str
|
||||
|
||||
@ -89,6 +65,12 @@ class QueuePushBody(BaseModel):
|
||||
data: Any
|
||||
|
||||
|
||||
class PredictBody(BaseModel):
|
||||
session_hash: Optional[str]
|
||||
data: Any
|
||||
fn_index: int
|
||||
|
||||
|
||||
###########
|
||||
# Auth
|
||||
###########
|
||||
@ -250,16 +232,15 @@ def create_app() -> FastAPI:
|
||||
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
|
||||
|
||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||
async def predict(request: Request, username: str = Depends(get_current_user)):
|
||||
body = await request.json()
|
||||
if "session_hash" in body:
|
||||
if body["session_hash"] not in app.state_holder:
|
||||
app.state_holder[body["session_hash"]] = {
|
||||
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
|
||||
if hasattr(body, "session_hash"):
|
||||
if body.session_hash not in app.state_holder:
|
||||
app.state_holder[body.session_hash] = {
|
||||
_id: getattr(block, "default_value", None)
|
||||
for _id, block in app.blocks.blocks.items()
|
||||
if getattr(block, "stateful", False)
|
||||
}
|
||||
session_state = app.state_holder[body["session_hash"]]
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
else:
|
||||
session_state = {}
|
||||
try:
|
||||
|
@ -36,14 +36,22 @@ class TestPort(unittest.TestCase):
|
||||
warnings.warn("Unable to test, no ports available")
|
||||
|
||||
|
||||
class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
def test_show_error(self):
|
||||
class TestInterfaceErrors(unittest.TestCase):
|
||||
def test_processing_error(self):
|
||||
io = Interface(lambda x: 1 / x, "number", "number")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.post("/api/predict/", json={"data": [0], "fn_index": 1})
|
||||
self.assertEqual(response.status_code, 500)
|
||||
self.assertTrue("error" in response.json())
|
||||
io.close()
|
||||
|
||||
def test_validation_error(self):
|
||||
io = Interface(lambda x: 1 / x, "number", "number")
|
||||
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
|
||||
client = TestClient(app)
|
||||
response = client.post("/api/predict/", json={"data": [0]})
|
||||
self.assertEqual(response.status_code, 500)
|
||||
self.assertTrue("error" in response.json())
|
||||
self.assertEqual(response.status_code, 422)
|
||||
io.close()
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ class TestRoutes(unittest.TestCase):
|
||||
|
||||
def test_predict_route(self):
|
||||
response = self.client.post(
|
||||
"/api/predict/", json={"data": ["test"], "fn_index": 0}
|
||||
"/api/predict/", json={"data": ["test"], "fn_index": 1}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
@ -56,14 +56,14 @@ class TestRoutes(unittest.TestCase):
|
||||
client = TestClient(app)
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
|
||||
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
print("output", output)
|
||||
self.assertEqual(output["data"], ["test", None])
|
||||
response = client.post(
|
||||
"/api/predict/",
|
||||
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
|
||||
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
|
||||
)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["testtest", None])
|
||||
|
Loading…
Reference in New Issue
Block a user