added flagging tests

This commit is contained in:
Abubakar Abid 2022-01-25 23:44:41 -06:00
parent e76aa4a1d9
commit 82cea220fa
2 changed files with 37 additions and 5 deletions

View File

@ -18,7 +18,7 @@ jobs:
. venv/bin/activate
pip install --upgrade pip
pip install -r gradio.egg-info/requires.txt
pip install shap IPython comet_ml wandb mlflow tensorflow transformers
pip install shap IPython comet_ml wandb mlflow tensorflow transformers huggingface_hub
pip install selenium==4.0.0a6.post2 coverage scikit-image
- run:
command: |

View File

@ -1,12 +1,16 @@
import os
import tempfile
import unittest
from unittest.mock import MagicMock
import huggingface_hub
import gradio as gr
from gradio import flagging
class TestDefaultFlagging(unittest.TestCase):
def test_default_flagging_handler(self):
def test_default_flagging_callback(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
io.launch(prevent_thread_lock=True)
@ -18,7 +22,7 @@ class TestDefaultFlagging(unittest.TestCase):
class TestSimpleFlagging(unittest.TestCase):
def test_simple_csv_flagging_handler(self):
def test_simple_csv_flagging_callback(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
lambda x: x,
@ -29,11 +33,39 @@ class TestSimpleFlagging(unittest.TestCase):
)
io.launch(prevent_thread_lock=True)
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
self.assertEqual(row_count, 0) # no header
self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
self.assertEqual(row_count, 1) # no header
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
io.close()
class TestHuggingFaceDatasetSaver(unittest.TestCase):
def test_saver_setup(self):
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
with tempfile.TemporaryDirectory() as tmpdirname:
flagger.setup(tmpdirname)
huggingface_hub.create_repo.assert_called_once()
def test_saver_flag(self):
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
lambda x: x,
"text",
"text",
flagging_dir=tmpdirname,
flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
)
os.mkdir(os.path.join(tmpdirname, "test"))
io.launch(prevent_thread_lock=True)
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
self.assertEqual(row_count, 1) # 2 rows written including header
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
self.assertEqual(row_count, 2) # 3 rows written including header
if __name__ == "__main__":
unittest.main()