mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Fix flagged files and ensure that flagging_mode="auto" saves output components as well (#7704)
* interface * docstring * changes * changes * add changeset * changes * add changeset * changes * add changeset * changes * changes * fix * changes * add changeset * flaggin * simplify * changes * helpers * filedata --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
598ad7baf7
commit
95c6bc897b
5
.changeset/social-lies-nail.md
Normal file
5
.changeset/social-lies-nail.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
fix:Fix flagged files and ensure that flagging_mode="auto" saves output components as well
|
@ -14,7 +14,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from gradio_client.utils import is_file_obj
|
||||
import gradio_client.utils as client_utils
|
||||
|
||||
from gradio import utils
|
||||
from gradio.blocks import Block, BlockContext
|
||||
@ -202,7 +202,7 @@ class Component(ComponentBase, Block):
|
||||
postprocess=True,
|
||||
keep_in_cache=True,
|
||||
)
|
||||
if is_file_obj(self.value):
|
||||
if client_utils.is_file_obj(self.value):
|
||||
self.keep_in_cache.add(self.value["path"])
|
||||
|
||||
if callable(load_fn):
|
||||
@ -300,7 +300,9 @@ class Component(ComponentBase, Block):
|
||||
if self.data_model:
|
||||
payload = self.data_model.from_json(payload)
|
||||
Path(flag_dir).mkdir(exist_ok=True)
|
||||
return payload.copy_to_dir(flag_dir).model_dump_json()
|
||||
payload = payload.copy_to_dir(flag_dir).model_dump()
|
||||
if not isinstance(payload, str):
|
||||
payload = json.dumps(payload)
|
||||
return payload
|
||||
|
||||
def read_from_flag(self, payload: Any):
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
from gradio_client.documentation import document
|
||||
@ -156,9 +155,6 @@ class Code(Component):
|
||||
else:
|
||||
return value.strip()
|
||||
|
||||
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
|
||||
return super().flag(payload, flag_dir)
|
||||
|
||||
def api_info(self) -> dict[str, Any]:
|
||||
return {"type": "string"}
|
||||
|
||||
|
@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from gradio_client.documentation import document
|
||||
@ -94,13 +93,6 @@ class JSON(Component):
|
||||
def example_value(self) -> Any:
|
||||
return {"foo": "bar"}
|
||||
|
||||
def flag(
|
||||
self,
|
||||
payload: Any,
|
||||
flag_dir: str | Path = "", # noqa: ARG002
|
||||
) -> str:
|
||||
return json.dumps(payload)
|
||||
|
||||
def read_from_flag(self, payload: Any):
|
||||
return json.loads(payload)
|
||||
|
||||
|
@ -127,8 +127,8 @@ class CSVLogger(FlaggingCallback):
|
||||
Guides: using-flagging
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, simplify_file_data: bool = True):
|
||||
self.simplify_file_data = simplify_file_data
|
||||
|
||||
def setup(
|
||||
self,
|
||||
@ -167,11 +167,14 @@ class CSVLogger(FlaggingCallback):
|
||||
if utils.is_update(sample):
|
||||
csv_data.append(str(sample))
|
||||
else:
|
||||
csv_data.append(
|
||||
data = (
|
||||
component.flag(sample, flag_dir=save_dir)
|
||||
if sample is not None
|
||||
else ""
|
||||
)
|
||||
if self.simplify_file_data:
|
||||
data = utils.simplify_file_data_in_str(data)
|
||||
csv_data.append(data)
|
||||
csv_data.append(flag_option)
|
||||
csv_data.append(username if username is not None else "")
|
||||
csv_data.append(str(datetime.datetime.now()))
|
||||
@ -416,7 +419,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
label = component.label or ""
|
||||
save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
deserialized = component.flag(sample, save_dir)
|
||||
deserialized = utils.simplify_file_data_in_str(
|
||||
component.flag(sample, save_dir)
|
||||
)
|
||||
|
||||
# Add deserialized object to row
|
||||
features[label] = {"dtype": "string", "_type": "Value"}
|
||||
|
@ -309,7 +309,7 @@ class Examples:
|
||||
)
|
||||
else:
|
||||
print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
|
||||
cache_logger = CSVLogger()
|
||||
cache_logger = CSVLogger(simplify_file_data=False)
|
||||
|
||||
generated_values = []
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
|
@ -25,7 +25,7 @@ from gradio.components import (
|
||||
get_component_instance,
|
||||
)
|
||||
from gradio.data_classes import InterfaceTypes
|
||||
from gradio.events import Events, on
|
||||
from gradio.events import Dependency, Events, on
|
||||
from gradio.exceptions import RenderError
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
|
||||
from gradio.layouts import Accordion, Column, Row, Tab, Tabs
|
||||
@ -105,7 +105,10 @@ class Interface(Blocks):
|
||||
thumbnail: str | None = None,
|
||||
theme: Theme | str | None = None,
|
||||
css: str | None = None,
|
||||
allow_flagging: str | None = None,
|
||||
allow_flagging: Literal["never"]
|
||||
| Literal["auto"]
|
||||
| Literal["manual"]
|
||||
| None = None,
|
||||
flagging_options: list[str] | list[tuple[str, str]] | None = None,
|
||||
flagging_dir: str = "flagged",
|
||||
flagging_callback: FlaggingCallback | None = None,
|
||||
@ -142,7 +145,7 @@ class Interface(Blocks):
|
||||
thumbnail: String path or url to image to use as display image when the web demo is shared on social media.
|
||||
theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
|
||||
css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
|
||||
allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
|
||||
allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged, along with the generated output. If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
|
||||
flagging_options: If provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will ["Flag as X", "Flag as Y"], etc.
|
||||
flagging_dir: What to name the directory where flagged data is stored.
|
||||
flagging_callback: None or an instance of a subclass of FlaggingCallback which will be called when a sample is flagged. If set to None, an instance of gradio.flagging.CSVLogger will be created and logs will be saved to a local CSV file in flagging_dir. Default to None.
|
||||
@ -360,7 +363,7 @@ class Interface(Blocks):
|
||||
# For allow_flagging: (1) first check for parameter,
|
||||
# (2) check for env variable, (3) default to True/"manual"
|
||||
if allow_flagging is None:
|
||||
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
|
||||
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual") # type: ignore
|
||||
if allow_flagging is True:
|
||||
warnings.warn(
|
||||
"The `allow_flagging` parameter in `Interface` now"
|
||||
@ -451,10 +454,7 @@ class Interface(Blocks):
|
||||
component.label = f"output {i}"
|
||||
|
||||
if self.allow_flagging != "never":
|
||||
if (
|
||||
self.interface_type == InterfaceTypes.UNIFIED
|
||||
or self.allow_flagging == "auto"
|
||||
):
|
||||
if self.interface_type == InterfaceTypes.UNIFIED:
|
||||
self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
|
||||
elif self.interface_type == InterfaceTypes.INPUT_ONLY:
|
||||
pass
|
||||
@ -509,12 +509,12 @@ class Interface(Blocks):
|
||||
if _clear_btn is None:
|
||||
raise RenderError("Clear button not rendered")
|
||||
|
||||
self.attach_submit_events(_submit_btn, _stop_btn)
|
||||
_submit_event = self.attach_submit_events(_submit_btn, _stop_btn)
|
||||
self.attach_clear_events(_clear_btn, input_component_column)
|
||||
if duplicate_btn is not None:
|
||||
duplicate_btn.activate()
|
||||
|
||||
self.attach_flagging_events(flag_btns, _clear_btn)
|
||||
self.attach_flagging_events(flag_btns, _clear_btn, _submit_event)
|
||||
self.render_examples()
|
||||
self.render_article()
|
||||
|
||||
@ -648,7 +648,7 @@ class Interface(Blocks):
|
||||
|
||||
def attach_submit_events(
|
||||
self, _submit_btn: Button | None, _stop_btn: Button | None
|
||||
):
|
||||
) -> Dependency:
|
||||
if self.live:
|
||||
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
|
||||
if _submit_btn is None:
|
||||
@ -656,7 +656,7 @@ class Interface(Blocks):
|
||||
super().load(self.fn, None, self.output_components)
|
||||
# For output-only interfaces, the user probably still want a "generate"
|
||||
# button even if the Interface is live
|
||||
_submit_btn.click(
|
||||
return _submit_btn.click(
|
||||
self.fn,
|
||||
None,
|
||||
self.output_components,
|
||||
@ -675,7 +675,7 @@ class Interface(Blocks):
|
||||
streaming_event = True
|
||||
elif component.has_event("change"):
|
||||
events.append(component.change) # type: ignore
|
||||
on(
|
||||
return on(
|
||||
events,
|
||||
self.fn,
|
||||
self.input_components,
|
||||
@ -726,7 +726,7 @@ class Interface(Blocks):
|
||||
concurrency_limit=self.concurrency_limit,
|
||||
)
|
||||
|
||||
predict_event.then(
|
||||
final_event = predict_event.then(
|
||||
cleanup,
|
||||
inputs=None,
|
||||
outputs=extra_output, # type: ignore
|
||||
@ -742,8 +742,9 @@ class Interface(Blocks):
|
||||
queue=False,
|
||||
show_api=False,
|
||||
)
|
||||
return final_event
|
||||
else:
|
||||
on(
|
||||
return on(
|
||||
triggers,
|
||||
fn,
|
||||
self.input_components,
|
||||
@ -783,7 +784,10 @@ class Interface(Blocks):
|
||||
)
|
||||
|
||||
def attach_flagging_events(
|
||||
self, flag_btns: list[Button] | None, _clear_btn: ClearButton
|
||||
self,
|
||||
flag_btns: list[Button] | None,
|
||||
_clear_btn: ClearButton,
|
||||
_submit_event: Dependency,
|
||||
):
|
||||
if not (
|
||||
flag_btns
|
||||
@ -800,9 +804,9 @@ class Interface(Blocks):
|
||||
flag_method = FlagMethod(
|
||||
self.flagging_callback, "", "", visual_feedback=False
|
||||
)
|
||||
flag_btns[0].click( # flag_btns[0] is just the "Submit" button
|
||||
_submit_event.success(
|
||||
flag_method,
|
||||
inputs=self.input_components,
|
||||
inputs=self.input_components + self.output_components,
|
||||
outputs=None,
|
||||
preprocess=False,
|
||||
queue=False,
|
||||
|
@ -39,6 +39,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import anyio
|
||||
import gradio_client.utils as client_utils
|
||||
import httpx
|
||||
from gradio_client.documentation import document
|
||||
from typing_extensions import ParamSpec
|
||||
@ -1152,3 +1153,19 @@ def get_upload_folder() -> str:
|
||||
return os.environ.get("GRADIO_TEMP_DIR") or str(
|
||||
(Path(tempfile.gettempdir()) / "gradio").resolve()
|
||||
)
|
||||
|
||||
|
||||
def simplify_file_data_in_str(s):
|
||||
"""
|
||||
If a FileData dictionary has been dumped as part of a string, this function will replace the dict with just the str filepath
|
||||
"""
|
||||
try:
|
||||
payload = json.loads(s)
|
||||
except json.JSONDecodeError:
|
||||
return s
|
||||
payload = client_utils.traverse(
|
||||
payload, lambda x: x["path"], client_utils.is_file_obj_with_meta
|
||||
)
|
||||
if isinstance(payload, str):
|
||||
return payload
|
||||
return json.dumps(payload)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -21,6 +22,24 @@ class TestDefaultFlagging:
|
||||
assert row_count == 2 # 3 rows written including header
|
||||
io.close()
|
||||
|
||||
def test_files_saved_as_file_paths(self):
|
||||
image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")}
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"image",
|
||||
"image",
|
||||
flagging_dir=tmpdirname,
|
||||
allow_flagging="auto",
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
io.flagging_callback.flag([image, image])
|
||||
io.close()
|
||||
with open(os.path.join(tmpdirname, "log.csv")) as f:
|
||||
flagged_data = f.readlines()[1].split(",")[0]
|
||||
assert flagged_data.endswith("bus.png")
|
||||
io.close()
|
||||
|
||||
def test_flagging_does_not_create_unnecessary_directories(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
|
||||
|
Loading…
x
Reference in New Issue
Block a user