Fixes flagging when allow_flagging is set to "auto" (#2695)

* flagging fix

* formatting

* changelog

* auto

* auto
This commit is contained in:
Abubakar Abid 2022-12-20 15:27:14 -06:00 committed by GitHub
parent 96297c0bad
commit 0879f0a296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 46 additions and 16 deletions

View File

@ -23,7 +23,7 @@ gr.LinePlot(stocks,
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2807](https://github.com/gradio-app/gradio/pull/2807)
## Bug Fixes:
No changes to highlight.
* Fixes flagging when `allow_flagging="auto"` in `gr.Interface()` by [@abidlabs](https://github.com/abidlabs) in [PR 2695](https://github.com/gradio-app/gradio/pull/2695)
## Documentation Changes:
* Added a Guide on using BigQuery with Gradio's `DataFrame` and `ScatterPlot` component,

View File

@ -548,3 +548,17 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
def dump_json(self, thing: dict, file_path: str) -> None:
with open(file_path, "w+", encoding="utf8") as f:
json.dump(thing, f)
class FlagMethod:
"""
Helper class that contains the flagging button option and callback
"""
def __init__(self, flagging_callback, flag_option=None):
self.flagging_callback = flagging_callback
self.flag_option = flag_option
self.__name__ = "Flag"
def __call__(self, *flag_data):
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)

View File

@ -32,9 +32,9 @@ from gradio.components import (
)
from gradio.documentation import document, set_documentation_group
from gradio.events import Changeable, Streamable
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
from gradio.layouts import Column, Row, TabItem, Tabs
from gradio.pipelines import load_from_pipeline # type: ignore
from gradio.pipelines import load_from_pipeline
set_documentation_group("interface")
@ -172,7 +172,7 @@ class Interface(Blocks):
thumbnail: path or url to image to use as display image when the web demo is shared on social media.
theme: Theme to use - right now, only "default" is supported. Can be set with the GRADIO_THEME environment variable.
css: custom css or path to custom css file to use with interface.
allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every prediction will be automatically flagged. If "manual", samples are flagged when the user clicks flag button. Can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual".
flagging_dir: what to name the directory where flagged data is stored.
flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
@ -416,7 +416,10 @@ class Interface(Blocks):
component.label = "output " + str(i)
if self.allow_flagging != "never":
if self.interface_type == self.InterfaceTypes.UNIFIED:
if (
self.interface_type == self.InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
self.flagging_callback.setup(self.input_components, self.flagging_dir)
elif self.interface_type == self.InterfaceTypes.INPUT_ONLY:
pass
@ -608,22 +611,18 @@ class Interface(Blocks):
""",
)
class FlagMethod:
def __init__(self, flagging_callback, flag_option=None):
self.flagging_callback = flagging_callback
self.flag_option = flag_option
self.__name__ = "Flag"
def __call__(self, *flag_data):
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
if self.allow_flagging == "manual":
if self.allow_flagging in ["manual", "auto"]:
if self.interface_type in [
self.InterfaceTypes.STANDARD,
self.InterfaceTypes.OUTPUT_ONLY,
self.InterfaceTypes.UNIFIED,
]:
if self.interface_type == self.InterfaceTypes.UNIFIED:
if self.allow_flagging == "auto":
flag_btns = [(submit_btn, None)]
if (
self.interface_type == self.InterfaceTypes.UNIFIED
or self.allow_flagging == "auto"
):
flag_components = self.input_components
else:
flag_components = self.input_components + self.output_components

View File

@ -3,6 +3,7 @@ import tempfile
from unittest.mock import MagicMock
import huggingface_hub
import pytest
import gradio as gr
from gradio import flagging
@ -117,3 +118,19 @@ class TestDisableFlagging:
self.fail("launch() raised a PermissionError unexpectedly")
io.close()
class TestInterfaceConstructsFlagMethod:
@pytest.mark.parametrize(
"allow_flagging, called",
[
("manual", True),
("auto", True),
("never", False),
],
)
def test_flag_method_init_called(self, allow_flagging, called):
flagging.FlagMethod.__init__ = MagicMock()
flagging.FlagMethod.__init__.return_value = None
gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
assert flagging.FlagMethod.__init__.called == called