# Source: gradio/test/test_flagging.py (40 lines, 1.4 KiB, Python)

import tempfile
import unittest
import gradio as gr
from gradio import flagging
2022-01-24 12:54:48 +08:00
class TestDefaultFlagging(unittest.TestCase):
    """Tests for the flagging callback used when none is passed explicitly."""

    def test_default_flagging_handler(self):
        """Flag the same sample twice and check the reported row count.

        The interface is built without a ``flagging_callback`` argument and
        writes into a temporary directory that is removed when the test ends.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 1)  # 2 rows written including header
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 2)  # 3 rows written including header
            finally:
                # Always shut the launched interface down, even when an
                # assertion above fails, so no server is left running.
                io.close()
2022-01-24 12:54:48 +08:00
class TestSimpleFlagging(unittest.TestCase):
    """Tests for flagging through ``flagging.SimpleCSVLogger``."""

    def test_simple_csv_flagging_handler(self):
        """Flag the same sample twice and check the reported row count.

        The interface is given an explicit ``SimpleCSVLogger`` callback and
        writes into a temporary directory that is removed when the test ends.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 0)  # no header
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 1)  # no header
            finally:
                # Always shut the launched interface down, even when an
                # assertion above fails, so no server is left running.
                io.close()
2022-01-24 12:54:48 +08:00
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()