diff --git a/gradio/flagging.py b/gradio/flagging.py index 9e2e5f3a95..4d66b0c80e 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -73,7 +73,7 @@ class SimpleCSVLogger(FlaggingCallback): username: Optional[str] = None, ) -> int: flagging_dir = self.flagging_dir - log_filepath = "{}/log.csv".format(flagging_dir) + log_filepath = os.path.join(flagging_dir, "log.csv") csv_data = [] for component, sample in zip(self.components, flag_data): @@ -120,8 +120,8 @@ class CSVLogger(FlaggingCallback): username: Optional[str] = None, ) -> int: flagging_dir = self.flagging_dir - log_fp = "{}/log.csv".format(flagging_dir) - is_new = not os.path.exists(log_fp) + log_filepath = os.path.join(flagging_dir, "log.csv") + is_new = not os.path.exists(log_filepath) if flag_index is None: csv_data = [] @@ -160,7 +160,7 @@ class CSVLogger(FlaggingCallback): if self.encryption_key: output = io.StringIO() if not is_new: - with open(log_fp, "rb") as csvfile: + with open(log_filepath, "rb") as csvfile: encrypted_csv = csvfile.read() decrypted_csv = encryptor.decrypt( self.encryption_key, encrypted_csv @@ -174,26 +174,26 @@ class CSVLogger(FlaggingCallback): if is_new: writer.writerow(headers) writer.writerow(csv_data) - with open(log_fp, "wb") as csvfile: + with open(log_filepath, "wb") as csvfile: csvfile.write( encryptor.encrypt(self.encryption_key, output.getvalue().encode()) ) else: if flag_index is None: - with open(log_fp, "a", newline="") as csvfile: + with open(log_filepath, "a", newline="") as csvfile: writer = csv.writer(csvfile) if is_new: writer.writerow(headers) writer.writerow(csv_data) else: - with open(log_fp) as csvfile: + with open(log_filepath) as csvfile: file_content = csvfile.read() file_content = replace_flag_at_index(file_content) with open( - log_fp, "w", newline="" + log_filepath, "w", newline="" ) as csvfile: # newline parameter needed for Windows csvfile.write(file_content) - with open(log_fp, "r") as csvfile: + with open(log_filepath, "r") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 return line_count diff --git a/gradio/interface.py b/gradio/interface.py index 5df36c8335..71c8cf018f 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -492,8 +492,8 @@ class Interface(Blocks): for component in self.input_components: component.render() with Row(): - submit_btn = Button("Submit") clear_btn = Button("Clear") + submit_btn = Button("Submit") with Column( css={ "background-color": "rgb(249,250,251)", @@ -502,7 +502,7 @@ class Interface(Blocks): } ): for component in self.output_components: - Block.__init__(component) + component.render() with Row(): flag_btn = Button("Flag") submit_btn.click( diff --git a/gradio/templates/frontend/index.html b/gradio/templates/frontend/index.html index 7d1146f619..1a320c6b09 100644 --- a/gradio/templates/frontend/index.html +++ b/gradio/templates/frontend/index.html @@ -45,9 +45,9 @@ Gradio - + - + diff --git a/test/test_context.py b/test/test_context.py deleted file mode 100644 index 00e7a7e406..0000000000 --- a/test/test_context.py +++ /dev/null @@ -1,17 +0,0 @@ -# import unittest - -# from gradio.context import Context - - -# class TestContext(unittest.TestCase): -# def test_context(self): -# self.assertEqual(Context.id, 0) -# Context.id += 1 -# self.assertEqual(Context.id, 1) -# Context.root_block = {} -# Context.root_block["1"] = 1 -# self.assertEqual(Context.root_block, {"1": 1}) - - -# if __name__ == "__main__": -# unittest.main() diff --git a/test/test_flagging.py b/test/test_flagging.py index 1eddd26fd5..c1c41b08b3 100644 --- a/test/test_flagging.py +++ b/test/test_flagging.py @@ -8,62 +8,62 @@ import huggingface_hub import gradio as gr from gradio import flagging -# class TestDefaultFlagging(unittest.TestCase): -# def test_default_flagging_callback(self): -# with tempfile.TemporaryDirectory() as tmpdirname: -# io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname) -# io.launch(prevent_thread_lock=True) -# row_count = io.flagging_callback.flag(io, ["test"], ["test"]) -# self.assertEqual(row_count, 1) # 2 rows written including header -# row_count = io.flagging_callback.flag(io, ["test"], ["test"]) -# self.assertEqual(row_count, 2) # 3 rows written including header -# io.close() +class TestDefaultFlagging(unittest.TestCase): + def test_default_flagging_callback(self): + with tempfile.TemporaryDirectory() as tmpdirname: + io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname) + io.launch(prevent_thread_lock=True) + row_count = io.flagging_callback.flag(["test", "test"]) + self.assertEqual(row_count, 1) # 2 rows written including header + row_count = io.flagging_callback.flag(["test", "test"]) + self.assertEqual(row_count, 2) # 3 rows written including header + io.close() -# class TestSimpleFlagging(unittest.TestCase): -# def test_simple_csv_flagging_callback(self): -# with tempfile.TemporaryDirectory() as tmpdirname: -# io = gr.Interface( -# lambda x: x, -# "text", -# "text", -# flagging_dir=tmpdirname, -# flagging_callback=flagging.SimpleCSVLogger(), -# ) -# io.launch(prevent_thread_lock=True) -# row_count = io.flagging_callback.flag(io, ["test"], ["test"]) -# self.assertEqual(row_count, 0) # no header in SimpleCSVLogger -# row_count = io.flagging_callback.flag(io, ["test"], ["test"]) -# self.assertEqual(row_count, 1) # no header in SimpleCSVLogger -# io.close() +class TestSimpleFlagging(unittest.TestCase): + def test_simple_csv_flagging_callback(self): + with tempfile.TemporaryDirectory() as tmpdirname: + io = gr.Interface( + lambda x: x, + "text", + "text", + flagging_dir=tmpdirname, + flagging_callback=flagging.SimpleCSVLogger(), + ) + io.launch(prevent_thread_lock=True) + row_count = io.flagging_callback.flag(["test", "test"]) + self.assertEqual(row_count, 0) # no header in SimpleCSVLogger + row_count = io.flagging_callback.flag(["test", "test"]) + self.assertEqual(row_count, 1) # no header in SimpleCSVLogger + io.close() -# class TestHuggingFaceDatasetSaver(unittest.TestCase): -# def test_saver_setup(self): -# huggingface_hub.create_repo = MagicMock() -# huggingface_hub.Repository = MagicMock() -# flagger = flagging.HuggingFaceDatasetSaver("test", "test") -# with tempfile.TemporaryDirectory() as tmpdirname: -# flagger.setup(tmpdirname) -# huggingface_hub.create_repo.assert_called_once() +class TestHuggingFaceDatasetSaver(unittest.TestCase): + def test_saver_setup(self): + huggingface_hub.create_repo = MagicMock() + huggingface_hub.Repository = MagicMock() + flagger = flagging.HuggingFaceDatasetSaver("test", "test") + with tempfile.TemporaryDirectory() as tmpdirname: + flagger.setup([gr.Audio, gr.Textbox], tmpdirname) + huggingface_hub.create_repo.assert_called_once() -# def test_saver_flag(self): -# huggingface_hub.create_repo = MagicMock() -# huggingface_hub.Repository = MagicMock() -# with tempfile.TemporaryDirectory() as tmpdirname: -# io = gr.Interface( -# lambda x: x, -# "text", -# "text", -# flagging_dir=tmpdirname, -# flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"), -# ) -# os.mkdir(os.path.join(tmpdirname, "test")) -# io.launch(prevent_thread_lock=True) -# row_count = io.flagging_callback.flag(io, ["test"], ["test"]) -# self.assertEqual(row_count, 1) # 2 rows written including header -# row_count = io.flagging_callback.flag(io, ["test"], ["test"]) -# self.assertEqual(row_count, 2) # 3 rows written including header + def test_saver_flag(self): + huggingface_hub.create_repo = MagicMock() + huggingface_hub.Repository = MagicMock() + with tempfile.TemporaryDirectory() as tmpdirname: + io = gr.Interface( + lambda x: x, + "text", + "text", + flagging_dir=tmpdirname, + flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"), + ) + os.mkdir(os.path.join(tmpdirname, "test")) + io.launch(prevent_thread_lock=True) + row_count = io.flagging_callback.flag(["test", "test"]) + self.assertEqual(row_count, 1) # 2 rows written including header + row_count = io.flagging_callback.flag(["test", "test"]) + self.assertEqual(row_count, 2) # 3 rows written including header class TestDisableFlagging(unittest.TestCase):