mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
added flagging tests
This commit is contained in:
parent
e76aa4a1d9
commit
82cea220fa
@ -18,7 +18,7 @@ jobs:
|
||||
. venv/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -r gradio.egg-info/requires.txt
|
||||
pip install shap IPython comet_ml wandb mlflow tensorflow transformers
|
||||
pip install shap IPython comet_ml wandb mlflow tensorflow transformers huggingface_hub
|
||||
pip install selenium==4.0.0a6.post2 coverage scikit-image
|
||||
- run:
|
||||
command: |
|
||||
|
@ -1,12 +1,16 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
import gradio as gr
|
||||
from gradio import flagging
|
||||
|
||||
|
||||
class TestDefaultFlagging(unittest.TestCase):
|
||||
def test_default_flagging_handler(self):
|
||||
def test_default_flagging_callback(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
@ -18,7 +22,7 @@ class TestDefaultFlagging(unittest.TestCase):
|
||||
|
||||
|
||||
class TestSimpleFlagging(unittest.TestCase):
|
||||
def test_simple_csv_flagging_handler(self):
|
||||
def test_simple_csv_flagging_callback(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
@ -29,11 +33,39 @@ class TestSimpleFlagging(unittest.TestCase):
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 0) # no header
|
||||
self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 1) # no header
|
||||
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
|
||||
io.close()
|
||||
|
||||
|
||||
class TestHuggingFaceDatasetSaver(unittest.TestCase):
|
||||
def test_saver_setup(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
flagger.setup(tmpdirname)
|
||||
huggingface_hub.create_repo.assert_called_once()
|
||||
|
||||
def test_saver_flag(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"text",
|
||||
"text",
|
||||
flagging_dir=tmpdirname,
|
||||
flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
|
||||
)
|
||||
os.mkdir(os.path.join(tmpdirname, "test"))
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user