mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-27 01:40:20 +08:00
fixed some flagging tests
This commit is contained in:
parent
d4c357b837
commit
704073d274
@ -80,8 +80,6 @@ class CSVLogger(FlaggingCallback):
|
||||
|
||||
def flag(self, interface, input_data, output_data, flag_option=None, flag_index=None, username=None):
|
||||
flagging_dir = self.flagging_dir
|
||||
log_filepath = "{}/log.csv".format(flagging_dir)
|
||||
|
||||
log_fp = "{}/log.csv".format(flagging_dir)
|
||||
encryption_key = interface.encryption_key if interface.encrypt else None
|
||||
is_new = not os.path.exists(log_fp)
|
||||
|
@ -10,20 +10,20 @@ class TestFlagging(unittest.TestCase):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_handler.flag(io, ["test"], ["test"])
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_handler.flag(io, ["test"], ["test"])
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
io.close()
|
||||
|
||||
def test_simple_csv_flagging_handler(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname, flagging_handler=flagging.SimpleCSVLogger())
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname, flagging_callback=flagging.SimpleCSVLogger())
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_handler.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_handler.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 0) # no header
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 1) # no header
|
||||
io.close()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user