Fix flagged files and ensure that flagging_mode="auto" saves output components as well (#7704)

* interface

* docstring

* changes

* changes

* add changeset

* changes

* add changeset

* changes

* add changeset

* changes

* changes

* fix

* changes

* add changeset

* flaggin

* simplify

* changes

* helpers

* filedata

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
Abubakar Abid 2024-03-14 14:03:34 -07:00 committed by GitHub
parent 598ad7baf7
commit 95c6bc897b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 78 additions and 38 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
fix:Fix flagged files and ensure that flagging_mode="auto" saves output components as well

View File

@ -14,7 +14,7 @@ from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from gradio_client.utils import is_file_obj
import gradio_client.utils as client_utils
from gradio import utils
from gradio.blocks import Block, BlockContext
@ -202,7 +202,7 @@ class Component(ComponentBase, Block):
postprocess=True,
keep_in_cache=True,
)
if is_file_obj(self.value):
if client_utils.is_file_obj(self.value):
self.keep_in_cache.add(self.value["path"])
if callable(load_fn):
@ -300,7 +300,9 @@ class Component(ComponentBase, Block):
if self.data_model:
payload = self.data_model.from_json(payload)
Path(flag_dir).mkdir(exist_ok=True)
return payload.copy_to_dir(flag_dir).model_dump_json()
payload = payload.copy_to_dir(flag_dir).model_dump()
if not isinstance(payload, str):
payload = json.dumps(payload)
return payload
def read_from_flag(self, payload: Any):

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Callable, Literal
from gradio_client.documentation import document
@ -156,9 +155,6 @@ class Code(Component):
else:
return value.strip()
def flag(self, payload: Any, flag_dir: str | Path = "") -> str:
return super().flag(payload, flag_dir)
def api_info(self) -> dict[str, Any]:
return {"type": "string"}

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Callable
from gradio_client.documentation import document
@ -94,13 +93,6 @@ class JSON(Component):
def example_value(self) -> Any:
return {"foo": "bar"}
def flag(
self,
payload: Any,
flag_dir: str | Path = "", # noqa: ARG002
) -> str:
return json.dumps(payload)
def read_from_flag(self, payload: Any):
return json.loads(payload)

View File

@ -127,8 +127,8 @@ class CSVLogger(FlaggingCallback):
Guides: using-flagging
"""
def __init__(self):
pass
def __init__(self, simplify_file_data: bool = True):
self.simplify_file_data = simplify_file_data
def setup(
self,
@ -167,11 +167,14 @@ class CSVLogger(FlaggingCallback):
if utils.is_update(sample):
csv_data.append(str(sample))
else:
csv_data.append(
data = (
component.flag(sample, flag_dir=save_dir)
if sample is not None
else ""
)
if self.simplify_file_data:
data = utils.simplify_file_data_in_str(data)
csv_data.append(data)
csv_data.append(flag_option)
csv_data.append(username if username is not None else "")
csv_data.append(str(datetime.datetime.now()))
@ -416,7 +419,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
label = component.label or ""
save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
save_dir.mkdir(exist_ok=True, parents=True)
deserialized = component.flag(sample, save_dir)
deserialized = utils.simplify_file_data_in_str(
component.flag(sample, save_dir)
)
# Add deserialized object to row
features[label] = {"dtype": "string", "_type": "Value"}

View File

@ -309,7 +309,7 @@ class Examples:
)
else:
print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
cache_logger = CSVLogger()
cache_logger = CSVLogger(simplify_file_data=False)
generated_values = []
if inspect.isgeneratorfunction(self.fn):

View File

@ -25,7 +25,7 @@ from gradio.components import (
get_component_instance,
)
from gradio.data_classes import InterfaceTypes
from gradio.events import Events, on
from gradio.events import Dependency, Events, on
from gradio.exceptions import RenderError
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Accordion, Column, Row, Tab, Tabs
@ -105,7 +105,10 @@ class Interface(Blocks):
thumbnail: str | None = None,
theme: Theme | str | None = None,
css: str | None = None,
allow_flagging: str | None = None,
allow_flagging: Literal["never"]
| Literal["auto"]
| Literal["manual"]
| None = None,
flagging_options: list[str] | list[tuple[str, str]] | None = None,
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback | None = None,
@ -142,7 +145,7 @@ class Interface(Blocks):
thumbnail: String path or url to image to use as display image when the web demo is shared on social media.
theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
css: Custom css as a string or path to a css file. This css will be included in the demo webpage.
allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged, along with the generated output. If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
flagging_options: If provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will ["Flag as X", "Flag as Y"], etc.
flagging_dir: What to name the directory where flagged data is stored.
flagging_callback: None or an instance of a subclass of FlaggingCallback which will be called when a sample is flagged. If set to None, an instance of gradio.flagging.CSVLogger will be created and logs will be saved to a local CSV file in flagging_dir. Default to None.
@ -360,7 +363,7 @@ class Interface(Blocks):
# For allow_flagging: (1) first check for parameter,
# (2) check for env variable, (3) default to True/"manual"
if allow_flagging is None:
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual")
allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual") # type: ignore
if allow_flagging is True:
warnings.warn(
"The `allow_flagging` parameter in `Interface` now"
@ -451,10 +454,7 @@ class Interface(Blocks):
component.label = f"output {i}"
if self.allow_flagging != "never":
if (
self.interface_type == InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
if self.interface_type == InterfaceTypes.UNIFIED:
self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore
elif self.interface_type == InterfaceTypes.INPUT_ONLY:
pass
@ -509,12 +509,12 @@ class Interface(Blocks):
if _clear_btn is None:
raise RenderError("Clear button not rendered")
self.attach_submit_events(_submit_btn, _stop_btn)
_submit_event = self.attach_submit_events(_submit_btn, _stop_btn)
self.attach_clear_events(_clear_btn, input_component_column)
if duplicate_btn is not None:
duplicate_btn.activate()
self.attach_flagging_events(flag_btns, _clear_btn)
self.attach_flagging_events(flag_btns, _clear_btn, _submit_event)
self.render_examples()
self.render_article()
@ -648,7 +648,7 @@ class Interface(Blocks):
def attach_submit_events(
self, _submit_btn: Button | None, _stop_btn: Button | None
):
) -> Dependency:
if self.live:
if self.interface_type == InterfaceTypes.OUTPUT_ONLY:
if _submit_btn is None:
@ -656,7 +656,7 @@ class Interface(Blocks):
super().load(self.fn, None, self.output_components)
# For output-only interfaces, the user probably still want a "generate"
# button even if the Interface is live
_submit_btn.click(
return _submit_btn.click(
self.fn,
None,
self.output_components,
@ -675,7 +675,7 @@ class Interface(Blocks):
streaming_event = True
elif component.has_event("change"):
events.append(component.change) # type: ignore
on(
return on(
events,
self.fn,
self.input_components,
@ -726,7 +726,7 @@ class Interface(Blocks):
concurrency_limit=self.concurrency_limit,
)
predict_event.then(
final_event = predict_event.then(
cleanup,
inputs=None,
outputs=extra_output, # type: ignore
@ -742,8 +742,9 @@ class Interface(Blocks):
queue=False,
show_api=False,
)
return final_event
else:
on(
return on(
triggers,
fn,
self.input_components,
@ -783,7 +784,10 @@ class Interface(Blocks):
)
def attach_flagging_events(
self, flag_btns: list[Button] | None, _clear_btn: ClearButton
self,
flag_btns: list[Button] | None,
_clear_btn: ClearButton,
_submit_event: Dependency,
):
if not (
flag_btns
@ -800,9 +804,9 @@ class Interface(Blocks):
flag_method = FlagMethod(
self.flagging_callback, "", "", visual_feedback=False
)
flag_btns[0].click( # flag_btns[0] is just the "Submit" button
_submit_event.success(
flag_method,
inputs=self.input_components,
inputs=self.input_components + self.output_components,
outputs=None,
preprocess=False,
queue=False,

View File

@ -39,6 +39,7 @@ from typing import (
)
import anyio
import gradio_client.utils as client_utils
import httpx
from gradio_client.documentation import document
from typing_extensions import ParamSpec
@ -1152,3 +1153,19 @@ def get_upload_folder() -> str:
return os.environ.get("GRADIO_TEMP_DIR") or str(
(Path(tempfile.gettempdir()) / "gradio").resolve()
)
def simplify_file_data_in_str(s):
"""
If a FileData dictionary has been dumped as part of a string, this function will replace the dict with just the str filepath
"""
try:
payload = json.loads(s)
except json.JSONDecodeError:
return s
payload = client_utils.traverse(
payload, lambda x: x["path"], client_utils.is_file_obj_with_meta
)
if isinstance(payload, str):
return payload
return json.dumps(payload)

View File

@ -1,4 +1,5 @@
import os
import pathlib
import tempfile
from unittest.mock import MagicMock, patch
@ -21,6 +22,24 @@ class TestDefaultFlagging:
assert row_count == 2 # 3 rows written including header
io.close()
def test_files_saved_as_file_paths(self):
image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")}
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
lambda x: x,
"image",
"image",
flagging_dir=tmpdirname,
allow_flagging="auto",
)
io.launch(prevent_thread_lock=True)
io.flagging_callback.flag([image, image])
io.close()
with open(os.path.join(tmpdirname, "log.csv")) as f:
flagged_data = f.readlines()[1].split(",")[0]
assert flagged_data.endswith("bus.png")
io.close()
def test_flagging_does_not_create_unnecessary_directories(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)