Flagging fixes (#1081)

* only show flagging button if manual

* fixing flagging

* fixed flagging examples issue

* formatting

* cleanup

* fixed tests

* predictbody

* formatting

* fixed tests
This commit is contained in:
Abubakar Abid 2022-04-27 00:32:57 -07:00 committed by GitHub
parent cee698bd43
commit 82e95e259c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 69 additions and 62 deletions

View File

@ -10,12 +10,12 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from gradio import encryptor, networking, queueing, strings, utils
from gradio.context import Context
from gradio.routes import PredictBody
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from fastapi.applications import FastAPI
from gradio.components import Component, StatusTracker
from gradio.routes import PredictBody
class Block:
@ -240,7 +240,7 @@ class Blocks(BlockContext):
def process_api(
self,
data: Dict[str, Any],
data: PredictBody,
username: str = None,
state: Optional[Dict[int, any]] = None,
) -> Dict[str, Any]:
@ -252,8 +252,8 @@ class Blocks(BlockContext):
state: data stored from stateful components for session
Returns: None
"""
raw_input = data["data"]
fn_index = data["fn_index"]
raw_input = data.data
fn_index = data.fn_index
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]

View File

@ -82,20 +82,10 @@ class Component(Block):
"""
return data
def save_flagged_file(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool,
file_path: Optional[str] = None,
) -> Optional[str]:
def save_file(self, file: tempfile._TemporaryFileWrapper, dir: str, label: str):
"""
Saved flagged data (e.g. image or audio) as a file and returns filepath
Saved flagged file and returns filepath
"""
if data is None:
return None
file = processing_utils.decode_base64_to_file(data, encryption_key, file_path)
label = "".join([char for char in label if char.isalnum() or char in "._- "])
old_file_name = file.name
output_dir = os.path.join(dir, label)
@ -112,6 +102,22 @@ class Component(Block):
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
return label + "/" + new_file_name
def save_flagged_file(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool,
file_path: Optional[str] = None,
) -> Optional[str]:
"""
Saved flagged data (e.g. image or audio) as a file and returns filepath
"""
if data is None:
return None
file = processing_utils.decode_base64_to_file(data, encryption_key, file_path)
return self.save_file(file, dir, label)
def restore_flagged_file(
self,
dir: str,
@ -1809,9 +1815,18 @@ class Audio(Component):
"""
Returns: (str) path to audio file
"""
return self.save_flagged_file(
dir, label, None if data is None else data["data"], encryption_key
)
if data is None:
data_string = None
elif isinstance(data, str):
data_string = data
else:
data_string = data["data"]
is_example = data.get("is_example", False)
if is_example:
file_obj = processing_utils.create_tmp_copy_of_file(data["name"])
return self.save_file(file_obj, dir, label)
return self.save_flagged_file(dir, label, data_string, encryption_key)
def generate_sample(self):
return deepcopy(media_data.BASE64_AUDIO)

View File

@ -550,7 +550,15 @@ class Interface(Blocks):
for component in self.output_components:
component.render()
with Row():
flag_btn = Button("Flag")
if self.allow_flagging == "manual":
flag_btn = Button("Flag")
flag_btn._click_no_preprocess(
lambda *flag_data: self.flagging_callback.flag(
flag_data
),
inputs=self.input_components + self.output_components,
outputs=[],
)
if self.interpretation:
interpretation_btn = Button("Interpret")
submit_fn = (
@ -617,11 +625,6 @@ class Interface(Blocks):
+ (self.output_components if self.cache_examples else []),
)
flag_btn._click_no_preprocess(
lambda *flag_data: self.flagging_callback.flag(flag_data),
inputs=self.input_components + self.output_components,
outputs=[],
)
if self.interpretation:
interpretation_btn._click_no_preprocess(
lambda *data: self.interpret(data) + [False, True],

View File

@ -56,30 +56,6 @@ templates = Jinja2Templates(directory=STATIC_TEMPLATE_LIB)
###########
class PredictBody(BaseModel):
session_hash: Optional[str]
example_id: Optional[int]
data: List[Any]
state: Optional[Any]
fn_index: Optional[int]
cleared: Optional[bool]
class FlagData(BaseModel):
input_data: List[Any]
output_data: List[Any]
flag_option: Optional[str]
flag_index: Optional[int]
class FlagBody(BaseModel):
data: FlagData
class InterpretBody(BaseModel):
data: List[Any]
class QueueStatusBody(BaseModel):
hash: str
@ -89,6 +65,12 @@ class QueuePushBody(BaseModel):
data: Any
class PredictBody(BaseModel):
session_hash: Optional[str]
data: Any
fn_index: int
###########
# Auth
###########
@ -250,16 +232,15 @@ def create_app() -> FastAPI:
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
if "session_hash" in body:
if body["session_hash"] not in app.state_holder:
app.state_holder[body["session_hash"]] = {
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
_id: getattr(block, "default_value", None)
for _id, block in app.blocks.blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body["session_hash"]]
session_state = app.state_holder[body.session_hash]
else:
session_state = {}
try:

View File

@ -36,14 +36,22 @@ class TestPort(unittest.TestCase):
warnings.warn("Unable to test, no ports available")
class TestInterfaceCustomParameters(unittest.TestCase):
def test_show_error(self):
class TestInterfaceErrors(unittest.TestCase):
def test_processing_error(self):
io = Interface(lambda x: 1 / x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": [0], "fn_index": 1})
self.assertEqual(response.status_code, 500)
self.assertTrue("error" in response.json())
io.close()
def test_validation_error(self):
io = Interface(lambda x: 1 / x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": [0]})
self.assertEqual(response.status_code, 500)
self.assertTrue("error" in response.json())
self.assertEqual(response.status_code, 422)
io.close()

View File

@ -38,7 +38,7 @@ class TestRoutes(unittest.TestCase):
def test_predict_route(self):
response = self.client.post(
"/api/predict/", json={"data": ["test"], "fn_index": 0}
"/api/predict/", json={"data": ["test"], "fn_index": 1}
)
self.assertEqual(response.status_code, 200)
output = dict(response.json())
@ -56,14 +56,14 @@ class TestRoutes(unittest.TestCase):
client = TestClient(app)
response = client.post(
"/api/predict/",
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
)
output = dict(response.json())
print("output", output)
self.assertEqual(output["data"], ["test", None])
response = client.post(
"/api/predict/",
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
)
output = dict(response.json())
self.assertEqual(output["data"], ["testtest", None])