mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
Fixes flagging when allow_flagging
is set to "auto" (#2695)
* flagging fix * formatting * changelog * auto * auto
This commit is contained in:
parent
96297c0bad
commit
0879f0a296
@ -23,7 +23,7 @@ gr.LinePlot(stocks,
|
||||
By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2807](https://github.com/gradio-app/gradio/pull/2807)
|
||||
|
||||
## Bug Fixes:
|
||||
No changes to highlight.
|
||||
* Fixes flagging when `allow_flagging="auto"` in `gr.Interface()` by [@abidlabs](https://github.com/abidlabs) in [PR 2695](https://github.com/gradio-app/gradio/pull/2695)
|
||||
|
||||
## Documentation Changes:
|
||||
* Added a Guide on using BigQuery with Gradio's `DataFrame` and `ScatterPlot` component,
|
||||
|
@ -548,3 +548,17 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
def dump_json(self, thing: dict, file_path: str) -> None:
|
||||
with open(file_path, "w+", encoding="utf8") as f:
|
||||
json.dump(thing, f)
|
||||
|
||||
|
||||
class FlagMethod:
|
||||
"""
|
||||
Helper class that contains the flagging button option and callback
|
||||
"""
|
||||
|
||||
def __init__(self, flagging_callback, flag_option=None):
|
||||
self.flagging_callback = flagging_callback
|
||||
self.flag_option = flag_option
|
||||
self.__name__ = "Flag"
|
||||
|
||||
def __call__(self, *flag_data):
|
||||
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
|
||||
|
@ -32,9 +32,9 @@ from gradio.components import (
|
||||
)
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
from gradio.events import Changeable, Streamable
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback, FlagMethod
|
||||
from gradio.layouts import Column, Row, TabItem, Tabs
|
||||
from gradio.pipelines import load_from_pipeline # type: ignore
|
||||
from gradio.pipelines import load_from_pipeline
|
||||
|
||||
set_documentation_group("interface")
|
||||
|
||||
@ -172,7 +172,7 @@ class Interface(Blocks):
|
||||
thumbnail: path or url to image to use as display image when the web demo is shared on social media.
|
||||
theme: Theme to use - right now, only "default" is supported. Can be set with the GRADIO_THEME environment variable.
|
||||
css: custom css or path to custom css file to use with interface.
|
||||
allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every prediction will be automatically flagged. If "manual", samples are flagged when the user clicks flag button. Can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
|
||||
allow_flagging: one of "never", "auto", or "manual". If "never" or "auto", users will not see a button to flag an input and output. If "manual", users will see a button to flag. If "auto", every input the user submits will be automatically flagged (outputs are not flagged). If "manual", both the input and outputs are flagged when the user clicks flag button. This parameter can be set with environmental variable GRADIO_ALLOW_FLAGGING; otherwise defaults to "manual".
|
||||
flagging_options: if provided, allows user to select from the list of options when flagging. Only applies if allow_flagging is "manual".
|
||||
flagging_dir: what to name the directory where flagged data is stored.
|
||||
flagging_callback: An instance of a subclass of FlaggingCallback which will be called when a sample is flagged. By default logs to a local CSV file.
|
||||
@ -416,7 +416,10 @@ class Interface(Blocks):
|
||||
component.label = "output " + str(i)
|
||||
|
||||
if self.allow_flagging != "never":
|
||||
if self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
if (
|
||||
self.interface_type == self.InterfaceTypes.UNIFIED
|
||||
or self.allow_flagging == "auto"
|
||||
):
|
||||
self.flagging_callback.setup(self.input_components, self.flagging_dir)
|
||||
elif self.interface_type == self.InterfaceTypes.INPUT_ONLY:
|
||||
pass
|
||||
@ -608,22 +611,18 @@ class Interface(Blocks):
|
||||
""",
|
||||
)
|
||||
|
||||
class FlagMethod:
|
||||
def __init__(self, flagging_callback, flag_option=None):
|
||||
self.flagging_callback = flagging_callback
|
||||
self.flag_option = flag_option
|
||||
self.__name__ = "Flag"
|
||||
|
||||
def __call__(self, *flag_data):
|
||||
self.flagging_callback.flag(flag_data, flag_option=self.flag_option)
|
||||
|
||||
if self.allow_flagging == "manual":
|
||||
if self.allow_flagging in ["manual", "auto"]:
|
||||
if self.interface_type in [
|
||||
self.InterfaceTypes.STANDARD,
|
||||
self.InterfaceTypes.OUTPUT_ONLY,
|
||||
self.InterfaceTypes.UNIFIED,
|
||||
]:
|
||||
if self.interface_type == self.InterfaceTypes.UNIFIED:
|
||||
if self.allow_flagging == "auto":
|
||||
flag_btns = [(submit_btn, None)]
|
||||
if (
|
||||
self.interface_type == self.InterfaceTypes.UNIFIED
|
||||
or self.allow_flagging == "auto"
|
||||
):
|
||||
flag_components = self.input_components
|
||||
else:
|
||||
flag_components = self.input_components + self.output_components
|
||||
|
@ -3,6 +3,7 @@ import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import huggingface_hub
|
||||
import pytest
|
||||
|
||||
import gradio as gr
|
||||
from gradio import flagging
|
||||
@ -117,3 +118,19 @@ class TestDisableFlagging:
|
||||
self.fail("launch() raised a PermissionError unexpectedly")
|
||||
|
||||
io.close()
|
||||
|
||||
|
||||
class TestInterfaceConstructsFlagMethod:
|
||||
@pytest.mark.parametrize(
|
||||
"allow_flagging, called",
|
||||
[
|
||||
("manual", True),
|
||||
("auto", True),
|
||||
("never", False),
|
||||
],
|
||||
)
|
||||
def test_flag_method_init_called(self, allow_flagging, called):
|
||||
flagging.FlagMethod.__init__ = MagicMock()
|
||||
flagging.FlagMethod.__init__.return_value = None
|
||||
gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
|
||||
assert flagging.FlagMethod.__init__.called == called
|
||||
|
Loading…
Reference in New Issue
Block a user