mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-24 13:01:18 +08:00
Fix dataset features and dataset preview for HuggingFaceDatasetSaver (#5135)
* Add tests * Use hf token * lint * add changeset * Add empty string for None * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
33baf70757
commit
80727bbe2c
5
.changeset/new-icons-refuse.md
Normal file
5
.changeset/new-icons-refuse.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
fix:Fix dataset features and dataset preview for HuggingFaceDatasetSaver
|
@ -248,6 +248,21 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
).repo_id
|
||||
path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
|
||||
huggingface_hub.metadata_update(
|
||||
repo_id=self.dataset_id,
|
||||
repo_type="dataset",
|
||||
metadata={
|
||||
"configs": [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": [{"split": "train", "path": path_glob}],
|
||||
}
|
||||
]
|
||||
},
|
||||
overwrite=True,
|
||||
token=self.hf_token,
|
||||
)
|
||||
|
||||
# Setup flagging dir
|
||||
self.components = components
|
||||
@ -284,7 +299,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
if self.separate_dirs:
|
||||
# JSONL files to support dataset preview on the Hub
|
||||
unique_id = str(uuid.uuid4())
|
||||
components_dir = self.dataset_dir / str(uuid.uuid4())
|
||||
components_dir = self.dataset_dir / unique_id
|
||||
data_file = components_dir / "metadata.jsonl"
|
||||
path_in_repo = unique_id # upload in sub folder (safer for concurrency)
|
||||
else:
|
||||
@ -416,28 +431,33 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
features[label] = {"dtype": "string", "_type": "Value"}
|
||||
try:
|
||||
assert Path(deserialized).exists()
|
||||
row.append(Path(deserialized).name)
|
||||
row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
|
||||
except (AssertionError, TypeError, ValueError):
|
||||
row.append(str(deserialized))
|
||||
deserialized = "" if deserialized is None else str(deserialized)
|
||||
row.append(deserialized)
|
||||
|
||||
# If component is eligible for a preview, add the URL of the file
|
||||
# Be mindful that images and audio can be None
|
||||
if isinstance(component, tuple(file_preview_types)): # type: ignore
|
||||
for _component, _type in file_preview_types.items():
|
||||
if isinstance(component, _component):
|
||||
features[label + " file"] = {"_type": _type}
|
||||
break
|
||||
path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
|
||||
Path(deserialized).relative_to(self.dataset_dir)
|
||||
).replace(
|
||||
"\\", "/"
|
||||
)
|
||||
row.append(
|
||||
huggingface_hub.hf_hub_url(
|
||||
repo_id=self.dataset_id,
|
||||
filename=path_in_repo,
|
||||
repo_type="dataset",
|
||||
if deserialized:
|
||||
path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
|
||||
Path(deserialized).relative_to(self.dataset_dir)
|
||||
).replace(
|
||||
"\\", "/"
|
||||
)
|
||||
)
|
||||
row.append(
|
||||
huggingface_hub.hf_hub_url(
|
||||
repo_id=self.dataset_id,
|
||||
filename=path_in_repo,
|
||||
repo_type="dataset",
|
||||
)
|
||||
)
|
||||
else:
|
||||
row.append("")
|
||||
features["flag"] = {"dtype": "string", "_type": "Value"}
|
||||
features["username"] = {"dtype": "string", "_type": "Value"}
|
||||
row.append(flag_option)
|
||||
|
@ -46,7 +46,8 @@ class TestHuggingFaceDatasetSaver:
|
||||
return_value=MagicMock(repo_id="gradio-tests/test"),
|
||||
)
|
||||
@patch("huggingface_hub.hf_hub_download")
|
||||
def test_saver_setup(self, mock_download, mock_create):
|
||||
@patch("huggingface_hub.metadata_update")
|
||||
def test_saver_setup(self, metadata_update, mock_download, mock_create):
|
||||
flagger = flagging.HuggingFaceDatasetSaver("test_token", "test")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
flagger.setup([gr.Audio, gr.Textbox], tmpdirname)
|
||||
@ -60,8 +61,9 @@ class TestHuggingFaceDatasetSaver:
|
||||
@patch("huggingface_hub.hf_hub_download")
|
||||
@patch("huggingface_hub.upload_folder")
|
||||
@patch("huggingface_hub.upload_file")
|
||||
@patch("huggingface_hub.metadata_update")
|
||||
def test_saver_flag_same_dir(
|
||||
self, mock_upload_file, mock_upload, mock_download, mock_create
|
||||
self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
@ -89,8 +91,9 @@ class TestHuggingFaceDatasetSaver:
|
||||
@patch("huggingface_hub.hf_hub_download")
|
||||
@patch("huggingface_hub.upload_folder")
|
||||
@patch("huggingface_hub.upload_file")
|
||||
@patch("huggingface_hub.metadata_update")
|
||||
def test_saver_flag_separate_dirs(
|
||||
self, mock_upload_file, mock_upload, mock_download, mock_create
|
||||
self, metadata_update, mock_upload_file, mock_upload, mock_download, mock_create
|
||||
):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
|
Loading…
x
Reference in New Issue
Block a user