From 95c6bc897be14e28a568242ea31516bfe2df13e8 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 14 Mar 2024 14:03:34 -0700 Subject: [PATCH] Fix flagged files and ensure that flagging_mode="auto" saves output components as well (#7704) * interface * docstring * changes * changes * add changeset * changes * add changeset * changes * add changeset * changes * changes * fix * changes * add changeset * flaggin * simplify * changes * helpers * filedata --------- Co-authored-by: gradio-pr-bot --- .changeset/social-lies-nail.md | 5 ++++ gradio/components/base.py | 8 +++--- gradio/components/code.py | 4 --- gradio/components/json_component.py | 8 ------ gradio/flagging.py | 13 +++++++--- gradio/helpers.py | 2 +- gradio/interface.py | 40 ++++++++++++++++------------- gradio/utils.py | 17 ++++++++++++ test/test_flagging.py | 19 ++++++++++++++ 9 files changed, 78 insertions(+), 38 deletions(-) create mode 100644 .changeset/social-lies-nail.md diff --git a/.changeset/social-lies-nail.md b/.changeset/social-lies-nail.md new file mode 100644 index 0000000000..76e33ec43f --- /dev/null +++ b/.changeset/social-lies-nail.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +fix:Fix flagged files and ensure that flagging_mode="auto" saves output components as well diff --git a/gradio/components/base.py b/gradio/components/base.py index 738da50fc6..f57fea5ad3 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -14,7 +14,7 @@ from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any, Callable -from gradio_client.utils import is_file_obj +import gradio_client.utils as client_utils from gradio import utils from gradio.blocks import Block, BlockContext @@ -202,7 +202,7 @@ class Component(ComponentBase, Block): postprocess=True, keep_in_cache=True, ) - if is_file_obj(self.value): + if client_utils.is_file_obj(self.value): self.keep_in_cache.add(self.value["path"]) if callable(load_fn): @@ -300,7 +300,9 @@ class Component(ComponentBase, Block): if self.data_model: payload = self.data_model.from_json(payload) Path(flag_dir).mkdir(exist_ok=True) - return payload.copy_to_dir(flag_dir).model_dump_json() + payload = payload.copy_to_dir(flag_dir).model_dump() + if not isinstance(payload, str): + payload = json.dumps(payload) return payload def read_from_flag(self, payload: Any): diff --git a/gradio/components/code.py b/gradio/components/code.py index 7add8f28b9..ed5209decb 100644 --- a/gradio/components/code.py +++ b/gradio/components/code.py @@ -2,7 +2,6 @@ from __future__ import annotations -from pathlib import Path from typing import Any, Callable, Literal from gradio_client.documentation import document @@ -156,9 +155,6 @@ class Code(Component): else: return value.strip() - def flag(self, payload: Any, flag_dir: str | Path = "") -> str: - return super().flag(payload, flag_dir) - def api_info(self) -> dict[str, Any]: return {"type": "string"} diff --git a/gradio/components/json_component.py b/gradio/components/json_component.py index 639f801cc5..331306cdf6 100644 --- a/gradio/components/json_component.py +++ b/gradio/components/json_component.py @@ -3,7 +3,6 @@ from __future__ import annotations import json -from pathlib import Path from typing import Any, Callable from gradio_client.documentation import document @@ -94,13 +93,6 @@ class JSON(Component): def example_value(self) -> Any: return {"foo": "bar"} - def flag( - self, - payload: Any, - flag_dir: str | Path = "", # noqa: ARG002 - ) -> str: - return json.dumps(payload) - def read_from_flag(self, payload: Any): return json.loads(payload) diff --git a/gradio/flagging.py b/gradio/flagging.py index 3f9fdaf33b..f8f7171253 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -127,8 +127,8 @@ class CSVLogger(FlaggingCallback): Guides: using-flagging """ - def __init__(self): - pass + def __init__(self, simplify_file_data: bool = True): + self.simplify_file_data = simplify_file_data def setup( self, @@ -167,11 +167,14 @@ class CSVLogger(FlaggingCallback): if utils.is_update(sample): csv_data.append(str(sample)) else: - csv_data.append( + data = ( component.flag(sample, flag_dir=save_dir) if sample is not None else "" ) + if self.simplify_file_data: + data = utils.simplify_file_data_in_str(data) + csv_data.append(data) csv_data.append(flag_option) csv_data.append(username if username is not None else "") csv_data.append(str(datetime.datetime.now())) @@ -416,7 +419,9 @@ class HuggingFaceDatasetSaver(FlaggingCallback): label = component.label or "" save_dir = data_dir / client_utils.strip_invalid_filename_characters(label) save_dir.mkdir(exist_ok=True, parents=True) - deserialized = component.flag(sample, save_dir) + deserialized = utils.simplify_file_data_in_str( + component.flag(sample, save_dir) + ) # Add deserialized object to row features[label] = {"dtype": "string", "_type": "Value"} diff --git a/gradio/helpers.py b/gradio/helpers.py index 346f354220..0c96d986e5 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -309,7 +309,7 @@ class Examples: ) else: print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'") - cache_logger = CSVLogger() + cache_logger = CSVLogger(simplify_file_data=False) generated_values = [] if inspect.isgeneratorfunction(self.fn): diff --git a/gradio/interface.py b/gradio/interface.py index 297851ae63..955e3e81d2 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -25,7 +25,7 @@ from gradio.components import ( get_component_instance, ) from gradio.data_classes import InterfaceTypes -from gradio.events import Events, on +from gradio.events import Dependency, Events, on from gradio.exceptions import RenderError from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod from gradio.layouts import Accordion, Column, Row, Tab, Tabs @@ -105,7 +105,10 @@ class Interface(Blocks): thumbnail: str | None = None, theme: Theme | str | None = None, css: str | None = None, - allow_flagging: str | None = None, + allow_flagging: Literal["never"] + | Literal["auto"] + | Literal["manual"] + | None = None, flagging_options: list[str] | list[tuple[str, str]] | None = None, flagging_dir: str = "flagged", flagging_callback: FlaggingCallback | None = None, @@ -142,7 +145,7 @@ class Interface(Blocks): thumbnail: String path or url to image to use as display image when the web demo is shared on social media. theme: A Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme. css: Custom css as a string or path to a css file. This css will be included in the demo webpage. - allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual". + allow_flagging: One of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged, along with the generated output. If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual". flagging_options: If provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will ["Flag as X", "Flag as Y"], etc. flagging_dir: What to name the directory where flagged data is stored. flagging_callback: None or an instance of a subclass of FlaggingCallback which will be called when a sample is flagged. If set to None, an instance of gradio.flagging.CSVLogger will be created and logs will be saved to a local CSV file in flagging_dir. Default to None. @@ -360,7 +363,7 @@ class Interface(Blocks): # For allow_flagging: (1) first check for parameter, # (2) check for env variable, (3) default to True/"manual" if allow_flagging is None: - allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual") + allow_flagging = os.getenv("GRADIO_ALLOW_FLAGGING", "manual") # type: ignore if allow_flagging is True: warnings.warn( "The `allow_flagging` parameter in `Interface` now" @@ -451,10 +454,7 @@ class Interface(Blocks): component.label = f"output {i}" if self.allow_flagging != "never": - if ( - self.interface_type == InterfaceTypes.UNIFIED - or self.allow_flagging == "auto" - ): + if self.interface_type == InterfaceTypes.UNIFIED: self.flagging_callback.setup(self.input_components, self.flagging_dir) # type: ignore elif self.interface_type == InterfaceTypes.INPUT_ONLY: pass @@ -509,12 +509,12 @@ class Interface(Blocks): if _clear_btn is None: raise RenderError("Clear button not rendered") - self.attach_submit_events(_submit_btn, _stop_btn) + _submit_event = self.attach_submit_events(_submit_btn, _stop_btn) self.attach_clear_events(_clear_btn, input_component_column) if duplicate_btn is not None: duplicate_btn.activate() - self.attach_flagging_events(flag_btns, _clear_btn) + self.attach_flagging_events(flag_btns, _clear_btn, _submit_event) self.render_examples() self.render_article() @@ -648,7 +648,7 @@ class Interface(Blocks): def attach_submit_events( self, _submit_btn: Button | None, _stop_btn: Button | None - ): + ) -> Dependency: if self.live: if self.interface_type == InterfaceTypes.OUTPUT_ONLY: if _submit_btn is None: @@ -656,7 +656,7 @@ class Interface(Blocks): super().load(self.fn, None, self.output_components) # For output-only interfaces, the user probably still want a "generate" # button even if the Interface is live - _submit_btn.click( + return _submit_btn.click( self.fn, None, self.output_components, @@ -675,7 +675,7 @@ class Interface(Blocks): streaming_event = True elif component.has_event("change"): events.append(component.change) # type: ignore - on( + return on( events, self.fn, self.input_components, @@ -726,7 +726,7 @@ class Interface(Blocks): concurrency_limit=self.concurrency_limit, ) - predict_event.then( + final_event = predict_event.then( cleanup, inputs=None, outputs=extra_output, # type: ignore @@ -742,8 +742,9 @@ class Interface(Blocks): queue=False, show_api=False, ) + return final_event else: - on( + return on( triggers, fn, self.input_components, @@ -783,7 +784,10 @@ class Interface(Blocks): ) def attach_flagging_events( - self, flag_btns: list[Button] | None, _clear_btn: ClearButton + self, + flag_btns: list[Button] | None, + _clear_btn: ClearButton, + _submit_event: Dependency, ): if not ( flag_btns @@ -800,9 +804,9 @@ class Interface(Blocks): flag_method = FlagMethod( self.flagging_callback, "", "", visual_feedback=False ) - flag_btns[0].click( # flag_btns[0] is just the "Submit" button + _submit_event.success( flag_method, - inputs=self.input_components, + inputs=self.input_components + self.output_components, outputs=None, preprocess=False, queue=False, diff --git a/gradio/utils.py b/gradio/utils.py index f3f96f2379..ca31979876 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -39,6 +39,7 @@ from typing import ( ) import anyio +import gradio_client.utils as client_utils import httpx from gradio_client.documentation import document from typing_extensions import ParamSpec @@ -1152,3 +1153,19 @@ def get_upload_folder() -> str: return os.environ.get("GRADIO_TEMP_DIR") or str( (Path(tempfile.gettempdir()) / "gradio").resolve() ) + + +def simplify_file_data_in_str(s): + """ + If a FileData dictionary has been dumped as part of a string, this function will replace the dict with just the str filepath + """ + try: + payload = json.loads(s) + except json.JSONDecodeError: + return s + payload = client_utils.traverse( + payload, lambda x: x["path"], client_utils.is_file_obj_with_meta + ) + if isinstance(payload, str): + return payload + return json.dumps(payload) diff --git a/test/test_flagging.py b/test/test_flagging.py index 96a99c4f8b..4bd41d25d5 100644 --- a/test/test_flagging.py +++ b/test/test_flagging.py @@ -1,4 +1,5 @@ import os +import pathlib import tempfile from unittest.mock import MagicMock, patch @@ -21,6 +22,24 @@ class TestDefaultFlagging: assert row_count == 2 # 3 rows written including header io.close() + def test_files_saved_as_file_paths(self): + image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")} + with tempfile.TemporaryDirectory() as tmpdirname: + io = gr.Interface( + lambda x: x, + "image", + "image", + flagging_dir=tmpdirname, + allow_flagging="auto", + ) + io.launch(prevent_thread_lock=True) + io.flagging_callback.flag([image, image]) + io.close() + with open(os.path.join(tmpdirname, "log.csv")) as f: + flagged_data = f.readlines()[1].split(",")[0] + assert flagged_data.endswith("bus.png") + io.close() + def test_flagging_does_not_create_unnecessary_directories(self): with tempfile.TemporaryDirectory() as tmpdirname: io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)