mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
61d2f15562
* hf flag fix * fixed huggingface hub params * formatting * fix flagging tests * add a try / catch
141 lines
5.4 KiB
Python
141 lines
5.4 KiB
Python
import os
import tempfile
from unittest.mock import MagicMock, patch

import huggingface_hub
import pytest

import gradio as gr
from gradio import flagging
|
|
|
|
# Disable Gradio analytics for this test module via its environment flag.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")
|
|
|
|
|
|
class TestDefaultFlagging:
    """Tests for the flagging callback an Interface uses when none is given."""

    def test_default_flagging_callback(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                # The default callback writes a header row, so the returned
                # count is the number of flagged data rows (header excluded).
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                # Shut the launched app down even when an assertion fails.
                io.close()
|
|
|
|
|
|
class TestSimpleFlagging:
    """Tests for ``flagging.SimpleCSVLogger`` used as an Interface's callback."""

    def test_simple_csv_flagging_callback(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            io.launch(prevent_thread_lock=True)
            try:
                # SimpleCSVLogger writes no header row, so the first flagged
                # sample is reported as row 0.
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 0  # no header in SimpleCSVLogger
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # no header in SimpleCSVLogger
            finally:
                # Shut the launched app down even when an assertion fails.
                io.close()
|
|
|
|
|
|
class TestHuggingFaceDatasetSaver:
    """Tests for ``flagging.HuggingFaceDatasetSaver`` with the Hub client mocked.

    The ``huggingface_hub`` functions are replaced with ``patch.multiple`` so the
    originals are restored when each test ends. Assigning to the module's
    attributes directly would leak the mocks into every test that runs later.
    """

    def test_saver_setup(self):
        create_repo = MagicMock()
        with patch.multiple(
            huggingface_hub,
            get_full_repo_name=MagicMock(return_value="test/test"),
            create_repo=create_repo,
            Repository=MagicMock(),
        ):
            flagger = flagging.HuggingFaceDatasetSaver("test", "test")
            with tempfile.TemporaryDirectory() as tmpdirname:
                flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
        create_repo.assert_called_once()

    def test_saver_flag(self):
        with patch.multiple(
            huggingface_hub,
            get_full_repo_name=MagicMock(return_value="test/test"),
            create_repo=MagicMock(),
            Repository=MagicMock(),
        ), tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
            )
            # The mocked Repository creates nothing on disk, so make the local
            # dataset directory by hand (assumed to be <flagging_dir>/<dataset
            # name> -- TODO confirm against HuggingFaceDatasetSaver.setup).
            os.mkdir(os.path.join(tmpdirname, "test"))
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                # The original test never closed the app; always shut it down.
                io.close()
|
|
|
|
|
|
class TestHuggingFaceDatasetJSONSaver:
    """Tests for ``flagging.HuggingFaceDatasetJSONSaver`` with the Hub client mocked.

    The ``huggingface_hub`` functions are replaced with ``patch.multiple`` so the
    originals are restored when each test ends. Assigning to the module's
    attributes directly would leak the mocks into every test that runs later.
    """

    def test_saver_setup(self):
        create_repo = MagicMock()
        with patch.multiple(
            huggingface_hub,
            get_full_repo_name=MagicMock(return_value="test/test"),
            create_repo=create_repo,
            Repository=MagicMock(),
        ):
            flagger = flagging.HuggingFaceDatasetJSONSaver("test", "test")
            with tempfile.TemporaryDirectory() as tmpdirname:
                flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
        create_repo.assert_called_once()

    def test_saver_flag(self):
        with patch.multiple(
            huggingface_hub,
            get_full_repo_name=MagicMock(return_value="test/test"),
            create_repo=MagicMock(),
            Repository=MagicMock(),
        ), tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.HuggingFaceDatasetJSONSaver("test", "test"),
            )
            # The mocked Repository creates nothing on disk, so make the local
            # dataset directory by hand.
            test_dir = os.path.join(tmpdirname, "test")
            os.mkdir(test_dir)
            io.launch(prevent_thread_lock=True)
            try:
                # flag() returns the name of the sub-directory holding this sample.
                row_unique_name = io.flagging_callback.flag(["test", "test"])
                # Test existence of metadata.jsonl file for that example
                assert os.path.isfile(
                    os.path.join(test_dir, row_unique_name, "metadata.jsonl")
                )
            finally:
                # The original test never closed the app; always shut it down.
                io.close()
|
|
|
|
|
|
class TestDisableFlagging:
    """With ``allow_flagging="never"``, launching must not need a writable flagging dir."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Make directory read-only. Note: root bypasses these permission
            # bits, so this test only exercises the failure path for non-root users.
            os.chmod(tmpdirname, 0o444)
            nonwritable_path = os.path.join(tmpdirname, "flagging_dir")

            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                allow_flagging="never",
                flagging_dir=nonwritable_path,
            )
            try:
                io.launch(prevent_thread_lock=True)
            except PermissionError:
                # This is a plain pytest class, not unittest.TestCase, so
                # ``self.fail`` does not exist. Use pytest's API instead.
                pytest.fail("launch() raised a PermissionError unexpectedly")

            io.close()
|
|
|
|
|
|
class TestInterfaceConstructsFlagMethod:
    """Check whether an Interface instantiates ``flagging.FlagMethod`` per flagging mode."""

    @pytest.mark.parametrize(
        "allow_flagging, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, allow_flagging, called):
        # Patch ``__init__`` only for the duration of this test. Assigning a
        # MagicMock directly would leave FlagMethod broken for every later test.
        # ``__init__`` must return None, hence return_value=None.
        with patch.object(flagging.FlagMethod, "__init__", return_value=None) as init:
            gr.Interface(lambda x: x, "text", "text", allow_flagging=allow_flagging)
        assert init.called == called
|