2022-01-26 13:44:41 +08:00
|
|
|
import os
|
2021-11-02 00:40:51 +08:00
|
|
|
import tempfile
|
2023-05-02 00:59:41 +08:00
|
|
|
from unittest.mock import MagicMock, patch
|
2022-01-26 13:44:41 +08:00
|
|
|
|
2022-12-21 05:27:14 +08:00
|
|
|
import pytest
|
2021-11-02 00:40:51 +08:00
|
|
|
|
2022-01-21 21:44:12 +08:00
|
|
|
import gradio as gr
|
|
|
|
from gradio import flagging
|
|
|
|
|
2022-08-26 04:23:28 +08:00
|
|
|
# Disable Gradio analytics for this test module by setting GRADIO_ANALYTICS_ENABLED.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
|
|
|
|
|
2022-03-26 02:14:42 +08:00
|
|
|
|
2022-11-08 08:37:55 +08:00
|
|
|
class TestDefaultFlagging:
    """Tests for the flagging callback an Interface uses when none is passed."""

    def test_default_flagging_callback(self):
        """Each ``flag`` call returns the running count of data rows written."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            # Close the interface even when an assertion fails, so the launched
            # app does not outlive this test while its flagging dir is deleted.
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                io.close()
|
2021-11-02 00:40:51 +08:00
|
|
|
|
2022-01-24 12:54:48 +08:00
|
|
|
|
2022-11-08 08:37:55 +08:00
|
|
|
class TestSimpleFlagging:
    """Tests for ``flagging.SimpleCSVLogger`` used as an Interface's callback."""

    def test_simple_csv_flagging_callback(self):
        """``flag`` returns a zero-based row index because no header is written."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            io.launch(prevent_thread_lock=True)
            # Close the interface even when an assertion fails, so the launched
            # app does not outlive this test while its flagging dir is deleted.
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 0  # no header in SimpleCSVLogger
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # no header in SimpleCSVLogger
            finally:
                io.close()
|
2021-11-02 00:40:51 +08:00
|
|
|
|
2022-02-09 02:56:13 +08:00
|
|
|
|
2022-11-08 08:37:55 +08:00
|
|
|
class TestHuggingFaceDatasetSaver:
    """Tests for ``flagging.HuggingFaceDatasetSaver`` with huggingface_hub mocked."""

    @staticmethod
    def _assert_only_files(directory, allowed_names):
        """Assert every file under ``directory`` is allowed or is a ``.lock`` file."""
        for _, _, filenames in os.walk(directory):
            # os.walk already yields bare file names, so basename() is unnecessary.
            for fname in filenames:
                assert fname in allowed_names or fname.endswith(".lock"), fname

    @patch(
        "huggingface_hub.create_repo",
        return_value=MagicMock(repo_id="gradio-tests/test"),
    )
    @patch("huggingface_hub.hf_hub_download")
    def test_saver_setup(self, mock_download, mock_create):
        flagger = flagging.HuggingFaceDatasetSaver("test_token", "test")
        with tempfile.TemporaryDirectory() as tmpdirname:
            flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
        mock_create.assert_called_once()
        mock_download.assert_called()

    @patch(
        "huggingface_hub.create_repo",
        return_value=MagicMock(repo_id="gradio-tests/test"),
    )
    @patch("huggingface_hub.hf_hub_download")
    @patch("huggingface_hub.upload_folder")
    @patch("huggingface_hub.upload_file")
    def test_saver_flag_same_dir(
        self, mock_upload_file, mock_upload, mock_download, mock_create
    ):
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
            )
            row_count = io.flagging_callback.flag(["test", "test"], "")
            assert row_count == 1  # 2 rows written including header
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 2  # 3 rows written including header
            self._assert_only_files(tmpdirname, ["data.csv", "dataset_info.json"])

    @patch(
        "huggingface_hub.create_repo",
        return_value=MagicMock(repo_id="gradio-tests/test"),
    )
    @patch("huggingface_hub.hf_hub_download")
    @patch("huggingface_hub.upload_folder")
    @patch("huggingface_hub.upload_file")
    def test_saver_flag_separate_dirs(
        self, mock_upload_file, mock_upload, mock_download, mock_create
    ):
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetSaver(
                    "test", "test", separate_dirs=True
                ),
            )
            row_count = io.flagging_callback.flag(["test", "test"], "")
            assert row_count == 1  # 2 rows written including header
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 2  # 3 rows written including header
            self._assert_only_files(
                tmpdirname, ["metadata.jsonl", "dataset_info.json"]
            )
|
2022-08-24 07:01:37 +08:00
|
|
|
|
|
|
|
|
2022-11-08 08:37:55 +08:00
|
|
|
class TestDisableFlagging:
    """Tests for an Interface created with ``allow_flagging="never"``."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        """Launching succeeds even though ``flagging_dir`` is not writable."""
        # TemporaryDirectory (instead of mkdtemp) guarantees the dir is removed.
        with tempfile.TemporaryDirectory() as tmpdirname:
            # NOTE(review): chmod does not restrict access when running as root.
            os.chmod(tmpdirname, 0o444)  # Make directory read-only
            try:
                nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
                io = gr.Interface(
                    lambda x: x,
                    "text",
                    "text",
                    allow_flagging="never",
                    flagging_dir=nonwritable_path,
                )
                io.launch(prevent_thread_lock=True)
                io.close()
            finally:
                # Restore permissions so the directory can be deleted on exit.
                os.chmod(tmpdirname, 0o700)
|
2022-12-21 05:27:14 +08:00
|
|
|
|
|
|
|
|
2023-03-01 02:29:34 +08:00
|
|
|
class TestInterfaceSetsUpFlagging:
    """Tests for how ``gr.Interface`` configures flagging from its arguments."""

    @pytest.mark.parametrize(
        "allow_flagging, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, allow_flagging, called):
        """``FlagMethod`` is constructed only when flagging is not disabled."""
        # Scope the patch to this test. Assigning a MagicMock directly to
        # ``FlagMethod.__init__`` would leave it patched for every later test.
        # ``return_value=None`` is required because ``__init__`` must return None.
        with patch.object(
            flagging.FlagMethod, "__init__", return_value=None
        ) as mock_init:
            gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
        assert mock_init.called == called

    @pytest.mark.parametrize(
        "options, processed_options",
        [
            (None, [("Flag", "")]),
            (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
            ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
        ],
    )
    def test_flagging_options_processed_correctly(self, options, processed_options):
        """``flagging_options`` is normalized to a list of (label, value) pairs."""
        io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
        assert io.flagging_options == processed_options
|