flagging tests pass

This commit is contained in:
Abubakar Abid 2022-03-25 11:12:45 -07:00
parent 6a0177337b
commit 6eeaccc11e
5 changed files with 64 additions and 81 deletions

View File

@ -73,7 +73,7 @@ class SimpleCSVLogger(FlaggingCallback):
username: Optional[str] = None,
) -> int:
flagging_dir = self.flagging_dir
log_filepath = "{}/log.csv".format(flagging_dir)
log_filepath = os.path.join(flagging_dir, "log.csv")
csv_data = []
for component, sample in zip(self.components, flag_data):
@ -120,8 +120,8 @@ class CSVLogger(FlaggingCallback):
username: Optional[str] = None,
) -> int:
flagging_dir = self.flagging_dir
log_fp = "{}/log.csv".format(flagging_dir)
is_new = not os.path.exists(log_fp)
log_filepath = os.path.join(flagging_dir, "log.csv")
is_new = not os.path.exists(log_filepath)
if flag_index is None:
csv_data = []
@ -160,7 +160,7 @@ class CSVLogger(FlaggingCallback):
if self.encryption_key:
output = io.StringIO()
if not is_new:
with open(log_fp, "rb") as csvfile:
with open(log_filepath, "rb") as csvfile:
encrypted_csv = csvfile.read()
decrypted_csv = encryptor.decrypt(
self.encryption_key, encrypted_csv
@ -174,26 +174,26 @@ class CSVLogger(FlaggingCallback):
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
with open(log_fp, "wb") as csvfile:
with open(log_filepath, "wb") as csvfile:
csvfile.write(
encryptor.encrypt(self.encryption_key, output.getvalue().encode())
)
else:
if flag_index is None:
with open(log_fp, "a", newline="") as csvfile:
with open(log_filepath, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
if is_new:
writer.writerow(headers)
writer.writerow(csv_data)
else:
with open(log_fp) as csvfile:
with open(log_filepath) as csvfile:
file_content = csvfile.read()
file_content = replace_flag_at_index(file_content)
with open(
log_fp, "w", newline=""
log_filepath, "w", newline=""
) as csvfile: # newline parameter needed for Windows
csvfile.write(file_content)
with open(log_fp, "r") as csvfile:
with open(log_filepath, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
return line_count

View File

@ -492,8 +492,8 @@ class Interface(Blocks):
for component in self.input_components:
component.render()
with Row():
submit_btn = Button("Submit")
clear_btn = Button("Clear")
submit_btn = Button("Submit")
with Column(
css={
"background-color": "rgb(249,250,251)",
@ -502,7 +502,7 @@ class Interface(Blocks):
}
):
for component in self.output_components:
Block.__init__(component)
component.render()
with Row():
flag_btn = Button("Flag")
submit_btn.click(

View File

@ -45,9 +45,9 @@
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
<title>Gradio</title>
<script type="module" crossorigin src="./assets/index.ea63c9ea.js"></script>
<script type="module" crossorigin src="./assets/index.68bdf974.js"></script>
<link rel="modulepreload" href="./assets/vendor.c988cbcf.js">
<link rel="stylesheet" href="./assets/index.778d40cb.css">
<link rel="stylesheet" href="./assets/index.cdf32a5f.css">
</head>
<body style="height: 100%; margin: 0; padding: 0">

View File

@ -1,17 +0,0 @@
# import unittest
# from gradio.context import Context
# class TestContext(unittest.TestCase):
# def test_context(self):
# self.assertEqual(Context.id, 0)
# Context.id += 1
# self.assertEqual(Context.id, 1)
# Context.root_block = {}
# Context.root_block["1"] = 1
# self.assertEqual(Context.root_block, {"1": 1})
# if __name__ == "__main__":
# unittest.main()

View File

@ -8,62 +8,62 @@ import huggingface_hub
import gradio as gr
from gradio import flagging
# class TestDefaultFlagging(unittest.TestCase):
# def test_default_flagging_callback(self):
# with tempfile.TemporaryDirectory() as tmpdirname:
# io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
# io.launch(prevent_thread_lock=True)
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
# self.assertEqual(row_count, 1) # 2 rows written including header
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
# self.assertEqual(row_count, 2) # 3 rows written including header
# io.close()
class TestDefaultFlagging(unittest.TestCase):
def test_default_flagging_callback(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
io.launch(prevent_thread_lock=True)
row_count = io.flagging_callback.flag(["test", "test"])
self.assertEqual(row_count, 1) # 2 rows written including header
row_count = io.flagging_callback.flag(["test", "test"])
self.assertEqual(row_count, 2) # 3 rows written including header
io.close()
# class TestSimpleFlagging(unittest.TestCase):
# def test_simple_csv_flagging_callback(self):
# with tempfile.TemporaryDirectory() as tmpdirname:
# io = gr.Interface(
# lambda x: x,
# "text",
# "text",
# flagging_dir=tmpdirname,
# flagging_callback=flagging.SimpleCSVLogger(),
# )
# io.launch(prevent_thread_lock=True)
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
# self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
# self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
# io.close()
class TestSimpleFlagging(unittest.TestCase):
def test_simple_csv_flagging_callback(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
lambda x: x,
"text",
"text",
flagging_dir=tmpdirname,
flagging_callback=flagging.SimpleCSVLogger(),
)
io.launch(prevent_thread_lock=True)
row_count = io.flagging_callback.flag(["test", "test"])
self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
row_count = io.flagging_callback.flag(["test", "test"])
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
io.close()
# class TestHuggingFaceDatasetSaver(unittest.TestCase):
# def test_saver_setup(self):
# huggingface_hub.create_repo = MagicMock()
# huggingface_hub.Repository = MagicMock()
# flagger = flagging.HuggingFaceDatasetSaver("test", "test")
# with tempfile.TemporaryDirectory() as tmpdirname:
# flagger.setup(tmpdirname)
# huggingface_hub.create_repo.assert_called_once()
class TestHuggingFaceDatasetSaver(unittest.TestCase):
def test_saver_setup(self):
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
with tempfile.TemporaryDirectory() as tmpdirname:
flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
huggingface_hub.create_repo.assert_called_once()
# def test_saver_flag(self):
# huggingface_hub.create_repo = MagicMock()
# huggingface_hub.Repository = MagicMock()
# with tempfile.TemporaryDirectory() as tmpdirname:
# io = gr.Interface(
# lambda x: x,
# "text",
# "text",
# flagging_dir=tmpdirname,
# flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
# )
# os.mkdir(os.path.join(tmpdirname, "test"))
# io.launch(prevent_thread_lock=True)
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
# self.assertEqual(row_count, 1) # 2 rows written including header
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
# self.assertEqual(row_count, 2) # 3 rows written including header
def test_saver_flag(self):
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(
lambda x: x,
"text",
"text",
flagging_dir=tmpdirname,
flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
)
os.mkdir(os.path.join(tmpdirname, "test"))
io.launch(prevent_thread_lock=True)
row_count = io.flagging_callback.flag(["test", "test"])
self.assertEqual(row_count, 1) # 2 rows written including header
row_count = io.flagging_callback.flag(["test", "test"])
self.assertEqual(row_count, 2) # 3 rows written including header
class TestDisableFlagging(unittest.TestCase):