fix: avoid unnecessary folders creation when flagging (#6245)

* fix: avoid unnecessary folders creation when flagging

* test: flagging_does_not_create_unnecessary_directories

---------

Co-authored-by: Egon Ferri <egon.ferri@immobiliare.it>
This commit is contained in:
Egon Ferri 2023-11-02 20:28:38 +01:00 committed by GitHub
parent 61c155e9ba
commit a6bb7222ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 1 deletions

View File

@ -273,6 +273,7 @@ class Component(ComponentBase, Block):
"""
if self.data_model:
payload = self.data_model.from_json(payload)
Path(flag_dir).mkdir(exist_ok=True)
return payload.copy_to_dir(flag_dir).model_dump_json()
return payload

View File

@ -166,7 +166,6 @@ class CSVLogger(FlaggingCallback):
) / client_utils.strip_invalid_filename_characters(
getattr(component, "label", None) or f"component {idx}"
)
save_dir.mkdir(exist_ok=True)
if utils.is_update(sample):
csv_data.append(str(sample))
else:

View File

@ -21,6 +21,13 @@ class TestDefaultFlagging:
assert row_count == 2 # 3 rows written including header
io.close()
def test_flagging_does_not_create_unnecessary_directories(self):
with tempfile.TemporaryDirectory() as tmpdirname:
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
io.launch(prevent_thread_lock=True)
io.flagging_callback.flag(["test", "test"])
assert os.listdir(tmpdirname) == ["log.csv"]
class TestSimpleFlagging:
def test_simple_csv_flagging_callback(self):