Some improvements to Flag (#3289)

* Fixes to button disable

* button

* formatting

* flagging fix

* fixes

* formatter

* changelog

* ormatting

* tests

* saving

* adding optionality for flagging

* updatest

* error catching

* updates

* changelog

* tests

* typing

* flag button

* formatting

* tests

* tests

* tests

* increased latency

* queue fix

* clear

* formatting

* fix

* fix tests
This commit is contained in:
Abubakar Abid 2023-02-28 10:29:34 -08:00 committed by GitHub
parent ce0bbdab89
commit 04ddce05b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 111 additions and 45 deletions

View File

@ -23,11 +23,14 @@ By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3297](https://git
- Updated image upload component to accept all image formats, including lossless formats like .webp by [@fienestar](https://github.com/fienestar) in [PR 3225](https://github.com/gradio-app/gradio/pull/3225)
- Adds a disabled mode to the `gr.Button` component by setting `interactive=False` by [@abidlabs](https://github.com/abidlabs) in [PR 3266](https://github.com/gradio-app/gradio/pull/3266) and [PR 3288](https://github.com/gradio-app/gradio/pull/3288)
- Adds visual feedback to the when the Flag button is clicked, by [@abidlabs](https://github.com/abidlabs) in [PR 3289](https://github.com/gradio-app/gradio/pull/3289)
- Adds ability to set `flagging_options` display text and saved flag separately by [@abidlabs](https://github.com/abidlabs) in [PR 3289](https://github.com/gradio-app/gradio/pull/3289)
- Allow the setting of `brush_radius` for the `Image` component both as a default and via `Image.update()` by [@pngwn](https://github.com/pngwn) in [PR 3277](https://github.com/gradio-app/gradio/pull/3277)
- Added `info=` argument to form components to enable extra context provided to users, by [@aliabid94](https://github.com/aliabid94) in [PR 3291](https://github.com/gradio-app/gradio/pull/3291)
- Allow developers to access the username of a logged-in user from the `gr.Request()` object using the `.username` attribute by [@abidlabs](https://github.com/abidlabs) in [PR 3296](https://github.com/gradio-app/gradio/pull/3296)
- Add `preview` option to `Gallery.style` that launches the gallery in preview mode when first loaded by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3345](https://github.com/gradio-app/gradio/pull/3345)
## Bug Fixes:
- Ensure `mirror_webcam` is always respected by [@pngwn](https://github.com/pngwn) in [PR 3245](https://github.com/gradio-app/gradio/pull/3245)
- Fix issue where updated markdown links were not being opened in a new tab by [@gante](https://github.com/gante) in [PR 3236](https://github.com/gradio-app/gradio/pull/3236)

View File

@ -5,6 +5,7 @@ import datetime
import io
import json
import os
import time
import uuid
from abc import ABC, abstractmethod
from distutils.version import StrictVersion
@ -89,7 +90,7 @@ class FlaggingCallback(ABC):
def flag(
self,
flag_data: List[Any],
flag_option: str | None = None,
flag_option: str = "",
flag_index: int | None = None,
username: str | None = None,
) -> int:
@ -133,7 +134,7 @@ class SimpleCSVLogger(FlaggingCallback):
def flag(
self,
flag_data: List[Any],
flag_option: str | None = None,
flag_option: str = "",
flag_index: int | None = None,
username: str | None = None,
) -> int:
@ -193,7 +194,7 @@ class CSVLogger(FlaggingCallback):
def flag(
self,
flag_data: List[Any],
flag_option: str | None = None,
flag_option: str = "",
flag_index: int | None = None,
username: str | None = None,
) -> int:
@ -226,7 +227,7 @@ class CSVLogger(FlaggingCallback):
if sample is not None
else ""
)
csv_data.append(flag_option if flag_option is not None else "")
csv_data.append(flag_option)
csv_data.append(username if username is not None else "")
csv_data.append(str(datetime.datetime.now()))
@ -235,7 +236,7 @@ class CSVLogger(FlaggingCallback):
content = list(csv.reader(file_content_))
header = content[0]
flag_col_index = header.index("flag")
content[flag_index][flag_col_index] = flag_option # type: ignore
content[flag_index][flag_col_index] = flag_option
output = io.StringIO()
writer = csv.writer(output)
writer.writerows(utils.sanitize_list_for_csv(content))
@ -366,7 +367,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
def flag(
self,
flag_data: List[Any],
flag_option: str | None = None,
flag_option: str = "",
flag_index: int | None = None,
username: str | None = None,
) -> int:
@ -398,7 +399,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
csv_data.append(
"{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
)
csv_data.append(flag_option if flag_option is not None else "")
csv_data.append(flag_option)
writer.writerow(utils.sanitize_list_for_csv(csv_data))
if is_new:
@ -498,7 +499,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
def flag(
self,
flag_data: List[Any],
flag_option: str | None = None,
flag_option: str = "",
flag_index: int | None = None,
username: str | None = None,
) -> str:
@ -550,7 +551,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
csv_data.append(filepath)
headers.append("flag")
csv_data.append(flag_option if flag_option is not None else "")
csv_data.append(flag_option)
# Creates metadata dict from row data and dumps it
metadata_dict = {
@ -575,13 +576,34 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
class FlagMethod:
"""
Helper class that contains the flagging button option and callback
Helper class that contains the flagging options and calls the flagging method. Also
provides visual feedback to the user when flag is clicked.
"""
def __init__(self, flagging_callback: FlaggingCallback, flag_option=None):
def __init__(
self,
flagging_callback: FlaggingCallback,
label: str,
value: str,
visual_feedback: bool = True,
):
self.flagging_callback = flagging_callback
self.flag_option = flag_option
self.label = label
self.value = value
self.__name__ = "Flag"
self.visual_feedback = visual_feedback
def __call__(self, *flag_data):
self.flagging_callback.flag(list(flag_data), flag_option=self.flag_option)
try:
self.flagging_callback.flag(list(flag_data), flag_option=self.value)
except Exception as e:
print("Error while flagging: {}".format(e))
if self.visual_feedback:
return "Error!"
if not self.visual_feedback:
return
time.sleep(0.8) # to provide enough time for the user to observe button change
return self.reset()
def reset(self):
return gr.Button.update(value=self.label, interactive=True)

View File

@ -135,7 +135,7 @@ class Interface(Blocks):
theme: str = "default",
css: str | None = None,
allow_flagging: str | None = None,
flagging_options: List[str] | None = None,
flagging_options: List[str] | List[Tuple[str, str]] | None = None,
flagging_dir: str = "flagged",
flagging_callback: FlaggingCallback = CSVLogger(),
analytics_enabled: bool | None = None,
@ -162,7 +162,7 @@ class Interface(Blocks):
theme: Theme to use - right now, only "default" is supported. Can be set with the GRADIO_THEME environment variable.
css: custom css or path to custom css file to use with interface.
allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual".
flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual". Can either be a list of tuples of the form (label, value), where label is the string that will be displayed on the button and value is the string that will be stored in the flagging CSV; or it can be a list of strings ["X", "Y"], in which case the values will be the list of strings and the labels will ["Flag as X", "Flag as Y"], etc.
flagging_dir: what to name the directory where flagged data is stored.
flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
analytics_enabled: Whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable if defined, or default to True.
@ -352,7 +352,21 @@ class Interface(Blocks):
"Must be: 'auto', 'manual', or 'never'."
)
self.flagging_options = flagging_options
if flagging_options is None:
self.flagging_options = [("Flag", "")]
elif not (isinstance(flagging_options, list)):
raise ValueError(
"flagging_options must be a list of strings or list of (string, string) tuples."
)
elif all([isinstance(x, str) for x in flagging_options]):
self.flagging_options = [(f"Flag as {x}", x) for x in flagging_options]
elif all([isinstance(x, tuple) for x in flagging_options]):
self.flagging_options = flagging_options
else:
raise ValueError(
"flagging_options must be a list of strings or list of (string, string) tuples."
)
self.flagging_callback = flagging_callback
self.flagging_dir = flagging_dir
self.batch = batch
@ -463,7 +477,7 @@ class Interface(Blocks):
interpret_component_column,
)
self.render_flagging_buttons(flag_btns)
self.attach_flagging_events(flag_btns, clear_btn)
self.render_examples()
self.render_article()
@ -479,17 +493,8 @@ class Interface(Blocks):
if self.description:
Markdown(self.description)
def render_flag_btns(self) -> List[Tuple[Button, str | None]]:
if self.flagging_options is None:
return [(Button("Flag"), None)]
else:
return [
(
Button("Flag as " + flag_option),
flag_option,
)
for flag_option in self.flagging_options
]
def render_flag_btns(self) -> List[Button]:
return [Button(label) for label, _ in self.flagging_options]
def render_input_column(
self,
@ -497,7 +502,7 @@ class Interface(Blocks):
Button | None,
Button | None,
Button | None,
List | None,
List[Button] | None,
Column,
Column | None,
List[Interpretation] | None,
@ -539,7 +544,7 @@ class Interface(Blocks):
if self.allow_flagging == "manual":
flag_btns = self.render_flag_btns()
elif self.allow_flagging == "auto":
flag_btns = [(submit_btn, None)]
flag_btns = [submit_btn]
return (
submit_btn,
clear_btn,
@ -576,7 +581,7 @@ class Interface(Blocks):
flag_btns = self.render_flag_btns()
elif self.allow_flagging == "auto":
assert submit_btn is not None, "Submit button not rendered"
flag_btns = [(submit_btn, None)]
flag_btns = [submit_btn]
if self.interpretation:
interpretation_btn = Button("Interpret")
@ -733,29 +738,53 @@ class Interface(Blocks):
preprocess=False,
)
def render_flagging_buttons(self, flag_btns: List | None):
def attach_flagging_events(self, flag_btns: List[Button] | None, clear_btn: Button):
if flag_btns:
if self.interface_type in [
InterfaceTypes.STANDARD,
InterfaceTypes.OUTPUT_ONLY,
InterfaceTypes.UNIFIED,
]:
if (
self.interface_type == InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
if self.allow_flagging == "auto":
flag_method = FlagMethod(
self.flagging_callback, "", "", visual_feedback=False
)
flag_btns[0].click( # flag_btns[0] is just the "Submit" button
flag_method,
inputs=self.input_components,
outputs=None,
preprocess=False,
queue=False,
)
return
if self.interface_type == InterfaceTypes.UNIFIED:
flag_components = self.input_components
else:
flag_components = self.input_components + self.output_components
for flag_btn, flag_option in flag_btns:
flag_method = FlagMethod(self.flagging_callback, flag_option)
for flag_btn, (label, value) in zip(flag_btns, self.flagging_options):
assert isinstance(value, str)
flag_method = FlagMethod(self.flagging_callback, label, value)
flag_btn.click(
lambda: Button.update(value="Saving...", interactive=False),
None,
flag_btn,
queue=False,
)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=[],
outputs=flag_btn,
preprocess=False,
queue=False,
)
clear_btn.click(
flag_method.reset,
None,
flag_btn,
queue=False,
)
def render_examples(self):
if self.examples:

View File

@ -330,10 +330,10 @@ class TestProcessExamples:
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"fn_index": 3, "data": [0]})
response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]})
assert response.json()["data"] == ["Hello,"]
response = client.post("/api/predict/", json={"fn_index": 3, "data": [1]})
response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]})
assert response.json()["data"] == ["Michael"]
def test_end_to_end_cache_examples(self):
@ -351,8 +351,8 @@ class TestProcessExamples:
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"fn_index": 3, "data": [0]})
response = client.post("/api/predict/", json={"fn_index": 5, "data": [0]})
assert response.json()["data"] == ["Hello,", "World", "Hello, World"]
response = client.post("/api/predict/", json={"fn_index": 3, "data": [1]})
response = client.post("/api/predict/", json={"fn_index": 5, "data": [1]})
assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]

View File

@ -65,7 +65,7 @@ class TestHuggingFaceDatasetSaver:
)
os.mkdir(os.path.join(tmpdirname, "test"))
io.launch(prevent_thread_lock=True)
row_count = io.flagging_callback.flag(["test", "test"])
row_count = io.flagging_callback.flag(["test", "test"], "")
assert row_count == 1 # 2 rows written including header
row_count = io.flagging_callback.flag(["test", "test"])
assert row_count == 2 # 3 rows written including header
@ -124,7 +124,7 @@ class TestDisableFlagging:
io.close()
class TestInterfaceConstructsFlagMethod:
class TestInterfaceSetsUpFlagging:
@pytest.mark.parametrize(
"allow_flagging, called",
[
@ -138,3 +138,15 @@ class TestInterfaceConstructsFlagMethod:
flagging.FlagMethod.__init__.return_value = None
gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
assert flagging.FlagMethod.__init__.called == called
@pytest.mark.parametrize(
"options, processed_options",
[
(None, [("Flag", "")]),
(["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
],
)
def test_flagging_options_processed_correctly(self, options, processed_options):
io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
assert io.flagging_options == processed_options