From 04ddce05b3cdd63af432dc599810cfca77cfe2e6 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 28 Feb 2023 10:29:34 -0800 Subject: [PATCH] Some improvements to Flag (#3289) * Fixes to button disable * button * formatting * flagging fix * fixes * formatter * changelog * ormatting * tests * saving * adding optionality for flagging * updatest * error catching * updates * changelog * tests * typing * flag button * formatting * tests * tests * tests * increased latency * queue fix * clear * formatting * fix * fix tests --- CHANGELOG.md | 3 ++ gradio/flagging.py | 48 ++++++++++++++++++------- gradio/interface.py | 81 +++++++++++++++++++++++++++++-------------- test/test_examples.py | 8 ++--- test/test_flagging.py | 16 +++++++-- 5 files changed, 111 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c7d3e83538..e9bb6aed60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,11 +23,14 @@ By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3297](https://git - Updated image upload component to accept all image formats, including lossless formats like .webp by [@fienestar](https://github.com/fienestar) in [PR 3225](https://github.com/gradio-app/gradio/pull/3225) - Adds a disabled mode to the `gr.Button` component by setting `interactive=False` by [@abidlabs](https://github.com/abidlabs) in [PR 3266](https://github.com/gradio-app/gradio/pull/3266) and [PR 3288](https://github.com/gradio-app/gradio/pull/3288) +- Adds visual feedback to the when the Flag button is clicked, by [@abidlabs](https://github.com/abidlabs) in [PR 3289](https://github.com/gradio-app/gradio/pull/3289) +- Adds ability to set `flagging_options` display text and saved flag separately by [@abidlabs](https://github.com/abidlabs) in [PR 3289](https://github.com/gradio-app/gradio/pull/3289) - Allow the setting of `brush_radius` for the `Image` component both as a default and via `Image.update()` by [@pngwn](https://github.com/pngwn) in [PR 3277](https://github.com/gradio-app/gradio/pull/3277) - Added `info=` argument to form components to enable extra context provided to users, by [@aliabid94](https://github.com/aliabid94) in [PR 3291](https://github.com/gradio-app/gradio/pull/3291) - Allow developers to access the username of a logged-in user from the `gr.Request()` object using the `.username` attribute by [@abidlabs](https://github.com/abidlabs) in [PR 3296](https://github.com/gradio-app/gradio/pull/3296) - Add `preview` option to `Gallery.style` that launches the gallery in preview mode when first loaded by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3345](https://github.com/gradio-app/gradio/pull/3345) + ## Bug Fixes: - Ensure `mirror_webcam` is always respected by [@pngwn](https://github.com/pngwn) in [PR 3245](https://github.com/gradio-app/gradio/pull/3245) - Fix issue where updated markdown links were not being opened in a new tab by [@gante](https://github.com/gante) in [PR 3236](https://github.com/gradio-app/gradio/pull/3236) diff --git a/gradio/flagging.py b/gradio/flagging.py index f5064cc4a7..af6e57a73e 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -5,6 +5,7 @@ import datetime import io import json import os +import time import uuid from abc import ABC, abstractmethod from distutils.version import StrictVersion @@ -89,7 +90,7 @@ class FlaggingCallback(ABC): def flag( self, flag_data: List[Any], - flag_option: str | None = None, + flag_option: str = "", flag_index: int | None = None, username: str | None = None, ) -> int: @@ -133,7 +134,7 @@ class SimpleCSVLogger(FlaggingCallback): def flag( self, flag_data: List[Any], - flag_option: str | None = None, + flag_option: str = "", flag_index: int | None = None, username: str | None = None, ) -> int: @@ -193,7 +194,7 @@ class CSVLogger(FlaggingCallback): def flag( self, flag_data: List[Any], - flag_option: str | None = None, + flag_option: str = "", flag_index: int | None = None, username: str | None = None, ) -> int: @@ -226,7 +227,7 @@ class CSVLogger(FlaggingCallback): if sample is not None else "" ) - csv_data.append(flag_option if flag_option is not None else "") + csv_data.append(flag_option) csv_data.append(username if username is not None else "") csv_data.append(str(datetime.datetime.now())) @@ -235,7 +236,7 @@ class CSVLogger(FlaggingCallback): content = list(csv.reader(file_content_)) header = content[0] flag_col_index = header.index("flag") - content[flag_index][flag_col_index] = flag_option # type: ignore + content[flag_index][flag_col_index] = flag_option output = io.StringIO() writer = csv.writer(output) writer.writerows(utils.sanitize_list_for_csv(content)) @@ -366,7 +367,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): def flag( self, flag_data: List[Any], - flag_option: str | None = None, + flag_option: str = "", flag_index: int | None = None, username: str | None = None, ) -> int: @@ -398,7 +399,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): csv_data.append( "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath) ) - csv_data.append(flag_option if flag_option is not None else "") + csv_data.append(flag_option) writer.writerow(utils.sanitize_list_for_csv(csv_data)) if is_new: @@ -498,7 +499,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): def flag( self, flag_data: List[Any], - flag_option: str | None = None, + flag_option: str = "", flag_index: int | None = None, username: str | None = None, ) -> str: @@ -550,7 +551,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): csv_data.append(filepath) headers.append("flag") - csv_data.append(flag_option if flag_option is not None else "") + csv_data.append(flag_option) # Creates metadata dict from row data and dumps it metadata_dict = { @@ -575,13 +576,34 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback): class FlagMethod: """ - Helper class that contains the flagging button option and callback + Helper class that contains the flagging options and calls the flagging method. Also + provides visual feedback to the user when flag is clicked. """ - def __init__(self, flagging_callback: FlaggingCallback, flag_option=None): + def __init__( + self, + flagging_callback: FlaggingCallback, + label: str, + value: str, + visual_feedback: bool = True, + ): self.flagging_callback = flagging_callback - self.flag_option = flag_option + self.label = label + self.value = value self.__name__ = "Flag" + self.visual_feedback = visual_feedback def __call__(self, *flag_data): - self.flagging_callback.flag(list(flag_data), flag_option=self.flag_option) + try: + self.flagging_callback.flag(list(flag_data), flag_option=self.value) + except Exception as e: + print("Error while flagging: {}".format(e)) + if self.visual_feedback: + return "Error!" + if not self.visual_feedback: + return + time.sleep(0.8) # to provide enough time for the user to observe button change + return self.reset() + + def reset(self): + return gr.Button.update(value=self.label, interactive=True) diff --git a/gradio/interface.py b/gradio/interface.py index e9340e8df3..8fd6444549 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -135,7 +135,7 @@ class Interface(Blocks): theme: str = "default", css: str | None = None, allow_flagging: str | None = None, - flagging_options: List[str] | None = None, + flagging_options: List[str] | List[Tuple[str, str]] | None = None, flagging_dir: str = "flagged", flagging_callback: FlaggingCallback = CSVLogger(), analytics_enabled: bool | None = None, @@ -162,7 +162,7 @@ class Interface(Blocks): theme: Theme to use - right now, only "default" is supported. Can be set with the GRADIO_THEME environment variable. css: custom css or path to custom css file to use with interface. allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual". - flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". + flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will ["Flag as X", "Flag as Y"], etc. flagging_dir: what to name the directory where flagged data is stored. flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file. analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True. @@ -352,7 +352,21 @@ class Interface(Blocks): "Must be: 'auto', 'manual', or 'never'." ) - self.flagging_options = flagging_options + if flagging_options is None: + self.flagging_options = [("Flag", "")] + elif not (isinstance(flagging_options, list)): + raise ValueError( + "flagging_options must be a list of strings or list of (string, string) tuples." + ) + elif all([isinstance(x, str) for x in flagging_options]): + self.flagging_options = [(f"Flag as {x}", x) for x in flagging_options] + elif all([isinstance(x, tuple) for x in flagging_options]): + self.flagging_options = flagging_options + else: + raise ValueError( + "flagging_options must be a list of strings or list of (string, string) tuples." + ) + self.flagging_callback = flagging_callback self.flagging_dir = flagging_dir self.batch = batch @@ -463,7 +477,7 @@ class Interface(Blocks): interpret_component_column, ) - self.render_flagging_buttons(flag_btns) + self.attach_flagging_events(flag_btns, clear_btn) self.render_examples() self.render_article() @@ -479,17 +493,8 @@ class Interface(Blocks): if self.description: Markdown(self.description) - def render_flag_btns(self) -> List[Tuple[Button, str | None]]: - if self.flagging_options is None: - return [(Button("Flag"), None)] - else: - return [ - ( - Button("Flag as " + flag_option), - flag_option, - ) - for flag_option in self.flagging_options - ] + def render_flag_btns(self) -> List[Button]: + return [Button(label) for label, _ in self.flagging_options] def render_input_column( self, @@ -497,7 +502,7 @@ class Interface(Blocks): Button | None, Button | None, Button | None, - List | None, + List[Button] | None, Column, Column | None, List[Interpretation] | None, @@ -539,7 +544,7 @@ class Interface(Blocks): if self.allow_flagging == "manual": flag_btns = self.render_flag_btns() elif self.allow_flagging == "auto": - flag_btns = [(submit_btn, None)] + flag_btns = [submit_btn] return ( submit_btn, clear_btn, @@ -576,7 +581,7 @@ class Interface(Blocks): flag_btns = self.render_flag_btns() elif self.allow_flagging == "auto": assert submit_btn is not None, "Submit button not rendered" - flag_btns = [(submit_btn, None)] + flag_btns = [submit_btn] if self.interpretation: interpretation_btn = Button("Interpret") @@ -733,29 +738,53 @@ class Interface(Blocks): preprocess=False, ) - def render_flagging_buttons(self, flag_btns: List | None): + def attach_flagging_events(self, flag_btns: List[Button] | None, clear_btn: Button): if flag_btns: if self.interface_type in [ InterfaceTypes.STANDARD, InterfaceTypes.OUTPUT_ONLY, InterfaceTypes.UNIFIED, ]: - if ( - self.interface_type == InterfaceTypes.UNIFIED - or self.allow_flagging == "auto" - ): + if self.allow_flagging == "auto": + flag_method = FlagMethod( + self.flagging_callback, "", "", visual_feedback=False + ) + flag_btns[0].click( # flag_btns[0] is just the "Submit" button + flag_method, + inputs=self.input_components, + outputs=None, + preprocess=False, + queue=False, + ) + return + + if self.interface_type == InterfaceTypes.UNIFIED: flag_components = self.input_components else: flag_components = self.input_components + self.output_components - for flag_btn, flag_option in flag_btns: - flag_method = FlagMethod(self.flagging_callback, flag_option) + + for flag_btn, (label, value) in zip(flag_btns, self.flagging_options): + assert isinstance(value, str) + flag_method = FlagMethod(self.flagging_callback, label, value) + flag_btn.click( + lambda: Button.update(value="Saving...", interactive=False), + None, + flag_btn, + queue=False, + ) flag_btn.click( flag_method, inputs=flag_components, - outputs=[], + outputs=flag_btn, preprocess=False, queue=False, ) + clear_btn.click( + flag_method.reset, + None, + flag_btn, + queue=False, + ) def render_examples(self): if self.examples: diff --git a/test/test_examples.py b/test/test_examples.py index 766e88b263..05dc4bae85 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -330,10 +330,10 @@ class TestProcessExamples: app, _, _ = io.launch(prevent_thread_lock=True) client = TestClient(app) - response = client.post("/api/predict/", json={"fn_index": 3, "data": [0]}) + response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]}) assert response.json()["data"] == ["Hello,"] - response = client.post("/api/predict/", json={"fn_index": 3, "data": [1]}) + response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]}) assert response.json()["data"] == ["Michael"] def test_end_to_end_cache_examples(self): @@ -351,8 +351,8 @@ class TestProcessExamples: app, _, _ = io.launch(prevent_thread_lock=True) client = TestClient(app) - response = client.post("/api/predict/", json={"fn_index": 3, "data": [0]}) + response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]}) assert response.json()["data"] == ["Hello,", "World", "Hello, World"] - response = client.post("/api/predict/", json={"fn_index": 3, "data": [1]}) + response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]}) assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"] diff --git a/test/test_flagging.py b/test/test_flagging.py index 45afe8d8f7..9a3a1b5bc9 100644 --- a/test/test_flagging.py +++ b/test/test_flagging.py @@ -65,7 +65,7 @@ class TestHuggingFaceDatasetSaver: ) os.mkdir(os.path.join(tmpdirname, "test")) io.launch(prevent_thread_lock=True) - row_count = io.flagging_callback.flag(["test", "test"]) + row_count = io.flagging_callback.flag(["test", "test"], "") assert row_count == 1 # 2 rows written including header row_count = io.flagging_callback.flag(["test", "test"]) assert row_count == 2 # 3 rows written including header @@ -124,7 +124,7 @@ class TestDisableFlagging: io.close() -class TestInterfaceConstructsFlagMethod: +class TestInterfaceSetsUpFlagging: @pytest.mark.parametrize( "allow_flagging, called", [ @@ -138,3 +138,15 @@ class TestInterfaceConstructsFlagMethod: flagging.FlagMethod.__init__.return_value = None gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging) assert flagging.FlagMethod.__init__.called == called + + @pytest.mark.parametrize( + "options, processed_options", + [ + (None, [("Flag", "")]), + (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]), + ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]), + ], + ) + def test_flagging_options_processed_correctly(self, options, processed_options): + io = gr.Interface(lambda x: x, "text", "text", flagging_options=options) + assert io.flagging_options == processed_options