gradio/test/test_flagging.py

137 lines
5.1 KiB
Python
Raw Normal View History

2022-01-26 13:44:41 +08:00
import os
import tempfile
2022-01-26 13:44:41 +08:00
from unittest.mock import MagicMock
import huggingface_hub
import pytest
import gradio as gr
from gradio import flagging
# Turn Gradio's analytics flag off for the whole test session.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
class TestDefaultFlagging:
    """Flagging through the Interface's default flagging callback."""

    def test_default_flagging_callback(self):
        with tempfile.TemporaryDirectory() as flag_dir:
            interface = gr.Interface(lambda x: x, "text", "text", flagging_dir=flag_dir)
            interface.launch(prevent_thread_lock=True)
            # The default logger writes a header row, so after the n-th flag the
            # reported count is n while n + 1 rows sit on disk.
            for expected_rows in (1, 2):
                assert interface.flagging_callback.flag(["test", "test"]) == expected_rows
            interface.close()
class TestSimpleFlagging:
    """Flagging through ``SimpleCSVLogger``."""

    def test_simple_csv_flagging_callback(self):
        with tempfile.TemporaryDirectory() as flag_dir:
            interface = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=flag_dir,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            interface.launch(prevent_thread_lock=True)
            # SimpleCSVLogger writes no header; flag() is expected to report 0
            # after the first call and 1 after the second.
            for expected_rows in range(2):
                assert interface.flagging_callback.flag(["test", "test"]) == expected_rows
            interface.close()
class TestHuggingFaceDatasetSaver:
    """Tests for ``HuggingFaceDatasetSaver`` with the Hugging Face Hub API mocked."""

    @staticmethod
    def _mock_hub(monkeypatch):
        """Replace ``huggingface_hub.create_repo`` and ``Repository`` with mocks.

        ``monkeypatch`` restores the real attributes at test teardown, so the
        mocks cannot leak into other tests the way a plain module-attribute
        assignment would.

        Returns:
            The ``MagicMock`` standing in for ``create_repo``.
        """
        create_repo = MagicMock()
        monkeypatch.setattr(huggingface_hub, "create_repo", create_repo)
        monkeypatch.setattr(huggingface_hub, "Repository", MagicMock())
        return create_repo

    def test_saver_setup(self, monkeypatch):
        create_repo = self._mock_hub(monkeypatch)
        flagger = flagging.HuggingFaceDatasetSaver("test", "test")
        with tempfile.TemporaryDirectory() as tmpdirname:
            flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
            create_repo.assert_called_once()

    def test_saver_flag(self, monkeypatch):
        self._mock_hub(monkeypatch)
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
            )
            # The mocked Repository clones nothing, so create the local dataset
            # folder by hand (assumed to be <flagging_dir>/<dataset_name> — confirm
            # against HuggingFaceDatasetSaver).
            os.mkdir(os.path.join(tmpdirname, "test"))
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                # Shut the launched server down even if an assertion fails, and
                # before the temporary directory is removed.
                io.close()
class TestHuggingFaceDatasetJSONSaver:
    """Tests for ``HuggingFaceDatasetJSONSaver`` with the Hugging Face Hub API mocked."""

    @staticmethod
    def _mock_hub(monkeypatch):
        """Replace ``huggingface_hub.create_repo`` and ``Repository`` with mocks.

        ``monkeypatch`` restores the real attributes at test teardown, so the
        mocks cannot leak into other tests.

        Returns:
            The ``MagicMock`` standing in for ``create_repo``.
        """
        create_repo = MagicMock()
        monkeypatch.setattr(huggingface_hub, "create_repo", create_repo)
        monkeypatch.setattr(huggingface_hub, "Repository", MagicMock())
        return create_repo

    def test_saver_setup(self, monkeypatch):
        create_repo = self._mock_hub(monkeypatch)
        flagger = flagging.HuggingFaceDatasetJSONSaver("test", "test")
        with tempfile.TemporaryDirectory() as tmpdirname:
            flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
            create_repo.assert_called_once()

    def test_saver_flag(self, monkeypatch):
        self._mock_hub(monkeypatch)
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetJSONSaver("test", "test"),
            )
            # The mocked Repository clones nothing, so create the local dataset
            # folder by hand (assumed to be <flagging_dir>/<dataset_name> — confirm).
            test_dir = os.path.join(tmpdirname, "test")
            os.mkdir(test_dir)
            io.launch(prevent_thread_lock=True)
            try:
                row_unique_name = io.flagging_callback.flag(["test", "test"])
                # The saver is expected to write metadata.jsonl inside a folder
                # named by the value flag() returns for this example.
                assert os.path.isfile(
                    os.path.join(test_dir, row_unique_name, "metadata.jsonl")
                )
            finally:
                # Shut the launched server down even if the assertion fails.
                io.close()
class TestDisableFlagging:
    """With flagging disabled, an unusable flagging_dir must not break launch()."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            os.chmod(tmpdirname, 0o444)  # Make directory read-only
            nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                allow_flagging="never",
                flagging_dir=nonwritable_path,
            )
            try:
                io.launch(prevent_thread_lock=True)
            except PermissionError:
                # This is a plain pytest class, not a unittest.TestCase, so it has
                # no self.fail(); pytest.fail reports a proper test failure.
                pytest.fail("launch() raised a PermissionError unexpectedly")
            io.close()
class TestInterfaceConstructsFlagMethod:
    """Interface should construct a FlagMethod unless flagging is "never"."""

    @pytest.mark.parametrize(
        "allow_flagging, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, allow_flagging, called, monkeypatch):
        # Patch __init__ through monkeypatch so the real constructor is restored
        # after each parametrized case. Assigning the class attribute directly
        # would break FlagMethod for every test that runs afterwards.
        init_mock = MagicMock(return_value=None)  # __init__ must return None
        monkeypatch.setattr(flagging.FlagMethod, "__init__", init_mock)
        gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
        assert init_mock.called == called