2021-11-02 00:40:51 +08:00
|
|
|
import gradio as gr
|
2021-11-13 14:33:59 +08:00
|
|
|
from gradio import flagging
|
2021-11-02 00:40:51 +08:00
|
|
|
import tempfile
|
|
|
|
import unittest
|
|
|
|
import unittest.mock as mock
|
|
|
|
|
|
|
|
|
|
|
|
class TestFlagging(unittest.TestCase):
    """Integration tests for the flagging callbacks attached to ``gr.Interface``.

    Each test launches a text-to-text echo Interface without blocking the test
    thread, flags the same input/output pair twice, and checks the count
    returned by ``io.flagging_callback.flag``.
    """

    def _launch_echo_interface(self, flagging_dir, **interface_kwargs):
        """Create and launch an echo Interface that flags into *flagging_dir*.

        Extra keyword arguments (e.g. ``flagging_callback``) are forwarded to
        ``gr.Interface``. The caller must call ``close()`` on the result.
        """
        io = gr.Interface(
            lambda x: x, "text", "text", flagging_dir=flagging_dir, **interface_kwargs
        )
        # prevent_thread_lock=True keeps launch() from blocking the test thread.
        io.launch(prevent_thread_lock=True)
        return io

    def test_default_flagging_handler(self):
        """The default callback writes a header row, which is not counted."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = self._launch_echo_interface(tmpdirname)
            try:
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 1)  # 2 rows written including header
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 2)  # 3 rows written including header
            finally:
                # Always shut the server down, even when an assertion fails,
                # and do it before the temporary flagging dir is deleted.
                io.close()

    def test_simple_csv_flagging_handler(self):
        """SimpleCSVLogger writes no header row.

        For the same number of flags, its returned counts are one lower than
        the default handler's.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = self._launch_echo_interface(
                tmpdirname, flagging_callback=flagging.SimpleCSVLogger()
            )
            try:
                # NOTE(review): these values look like "rows written minus one";
                # confirm against SimpleCSVLogger.flag.
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 0)  # no header
                row_count = io.flagging_callback.flag(io, ["test"], ["test"])
                self.assertEqual(row_count, 1)  # no header
            finally:
                # Always shut the server down, even when an assertion fails,
                # and do it before the temporary flagging dir is deleted.
                io.close()
|
2021-11-02 00:40:51 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: discover and run every TestCase defined in this module.
    unittest.main()
|