mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
flagging tests pass
This commit is contained in:
parent
6a0177337b
commit
6eeaccc11e
@ -73,7 +73,7 @@ class SimpleCSVLogger(FlaggingCallback):
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_filepath = "{}/log.csv".format(flagging_dir)
|
||||
log_filepath = os.path.join(flagging_dir, "log.csv")
|
||||
|
||||
csv_data = []
|
||||
for component, sample in zip(self.components, flag_data):
|
||||
@ -120,8 +120,8 @@ class CSVLogger(FlaggingCallback):
|
||||
username: Optional[str] = None,
|
||||
) -> int:
|
||||
flagging_dir = self.flagging_dir
|
||||
log_fp = "{}/log.csv".format(flagging_dir)
|
||||
is_new = not os.path.exists(log_fp)
|
||||
log_filepath = os.path.join(flagging_dir, "log.csv")
|
||||
is_new = not os.path.exists(log_filepath)
|
||||
|
||||
if flag_index is None:
|
||||
csv_data = []
|
||||
@ -160,7 +160,7 @@ class CSVLogger(FlaggingCallback):
|
||||
if self.encryption_key:
|
||||
output = io.StringIO()
|
||||
if not is_new:
|
||||
with open(log_fp, "rb") as csvfile:
|
||||
with open(log_filepath, "rb") as csvfile:
|
||||
encrypted_csv = csvfile.read()
|
||||
decrypted_csv = encryptor.decrypt(
|
||||
self.encryption_key, encrypted_csv
|
||||
@ -174,26 +174,26 @@ class CSVLogger(FlaggingCallback):
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
with open(log_fp, "wb") as csvfile:
|
||||
with open(log_filepath, "wb") as csvfile:
|
||||
csvfile.write(
|
||||
encryptor.encrypt(self.encryption_key, output.getvalue().encode())
|
||||
)
|
||||
else:
|
||||
if flag_index is None:
|
||||
with open(log_fp, "a", newline="") as csvfile:
|
||||
with open(log_filepath, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
if is_new:
|
||||
writer.writerow(headers)
|
||||
writer.writerow(csv_data)
|
||||
else:
|
||||
with open(log_fp) as csvfile:
|
||||
with open(log_filepath) as csvfile:
|
||||
file_content = csvfile.read()
|
||||
file_content = replace_flag_at_index(file_content)
|
||||
with open(
|
||||
log_fp, "w", newline=""
|
||||
log_filepath, "w", newline=""
|
||||
) as csvfile: # newline parameter needed for Windows
|
||||
csvfile.write(file_content)
|
||||
with open(log_fp, "r") as csvfile:
|
||||
with open(log_filepath, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
return line_count
|
||||
|
||||
|
@ -492,8 +492,8 @@ class Interface(Blocks):
|
||||
for component in self.input_components:
|
||||
component.render()
|
||||
with Row():
|
||||
submit_btn = Button("Submit")
|
||||
clear_btn = Button("Clear")
|
||||
submit_btn = Button("Submit")
|
||||
with Column(
|
||||
css={
|
||||
"background-color": "rgb(249,250,251)",
|
||||
@ -502,7 +502,7 @@ class Interface(Blocks):
|
||||
}
|
||||
):
|
||||
for component in self.output_components:
|
||||
Block.__init__(component)
|
||||
component.render()
|
||||
with Row():
|
||||
flag_btn = Button("Flag")
|
||||
submit_btn.click(
|
||||
|
@ -45,9 +45,9 @@
|
||||
</script>
|
||||
<script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script>
|
||||
<title>Gradio</title>
|
||||
<script type="module" crossorigin src="./assets/index.ea63c9ea.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index.68bdf974.js"></script>
|
||||
<link rel="modulepreload" href="./assets/vendor.c988cbcf.js">
|
||||
<link rel="stylesheet" href="./assets/index.778d40cb.css">
|
||||
<link rel="stylesheet" href="./assets/index.cdf32a5f.css">
|
||||
</head>
|
||||
|
||||
<body style="height: 100%; margin: 0; padding: 0">
|
||||
|
@ -1,17 +0,0 @@
|
||||
# import unittest
|
||||
|
||||
# from gradio.context import Context
|
||||
|
||||
|
||||
# class TestContext(unittest.TestCase):
|
||||
# def test_context(self):
|
||||
# self.assertEqual(Context.id, 0)
|
||||
# Context.id += 1
|
||||
# self.assertEqual(Context.id, 1)
|
||||
# Context.root_block = {}
|
||||
# Context.root_block["1"] = 1
|
||||
# self.assertEqual(Context.root_block, {"1": 1})
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# unittest.main()
|
@ -8,62 +8,62 @@ import huggingface_hub
|
||||
import gradio as gr
|
||||
from gradio import flagging
|
||||
|
||||
# class TestDefaultFlagging(unittest.TestCase):
|
||||
# def test_default_flagging_callback(self):
|
||||
# with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
|
||||
# io.launch(prevent_thread_lock=True)
|
||||
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
# self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
# self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
# io.close()
|
||||
class TestDefaultFlagging(unittest.TestCase):
|
||||
def test_default_flagging_callback(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(["test", "test"])
|
||||
self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_callback.flag(["test", "test"])
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
io.close()
|
||||
|
||||
|
||||
# class TestSimpleFlagging(unittest.TestCase):
|
||||
# def test_simple_csv_flagging_callback(self):
|
||||
# with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# io = gr.Interface(
|
||||
# lambda x: x,
|
||||
# "text",
|
||||
# "text",
|
||||
# flagging_dir=tmpdirname,
|
||||
# flagging_callback=flagging.SimpleCSVLogger(),
|
||||
# )
|
||||
# io.launch(prevent_thread_lock=True)
|
||||
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
# self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
|
||||
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
# self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
|
||||
# io.close()
|
||||
class TestSimpleFlagging(unittest.TestCase):
|
||||
def test_simple_csv_flagging_callback(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"text",
|
||||
"text",
|
||||
flagging_dir=tmpdirname,
|
||||
flagging_callback=flagging.SimpleCSVLogger(),
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(["test", "test"])
|
||||
self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
|
||||
row_count = io.flagging_callback.flag(["test", "test"])
|
||||
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
|
||||
io.close()
|
||||
|
||||
|
||||
# class TestHuggingFaceDatasetSaver(unittest.TestCase):
|
||||
# def test_saver_setup(self):
|
||||
# huggingface_hub.create_repo = MagicMock()
|
||||
# huggingface_hub.Repository = MagicMock()
|
||||
# flagger = flagging.HuggingFaceDatasetSaver("test", "test")
|
||||
# with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# flagger.setup(tmpdirname)
|
||||
# huggingface_hub.create_repo.assert_called_once()
|
||||
class TestHuggingFaceDatasetSaver(unittest.TestCase):
|
||||
def test_saver_setup(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
|
||||
huggingface_hub.create_repo.assert_called_once()
|
||||
|
||||
# def test_saver_flag(self):
|
||||
# huggingface_hub.create_repo = MagicMock()
|
||||
# huggingface_hub.Repository = MagicMock()
|
||||
# with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# io = gr.Interface(
|
||||
# lambda x: x,
|
||||
# "text",
|
||||
# "text",
|
||||
# flagging_dir=tmpdirname,
|
||||
# flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
|
||||
# )
|
||||
# os.mkdir(os.path.join(tmpdirname, "test"))
|
||||
# io.launch(prevent_thread_lock=True)
|
||||
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
# self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
# row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
# self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
def test_saver_flag(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"text",
|
||||
"text",
|
||||
flagging_dir=tmpdirname,
|
||||
flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
|
||||
)
|
||||
os.mkdir(os.path.join(tmpdirname, "test"))
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(["test", "test"])
|
||||
self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_callback.flag(["test", "test"])
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
|
||||
|
||||
class TestDisableFlagging(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user