mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-18 12:50:30 +08:00
Fix to HuggingFaceDatasetSaver (#3025)
* hf flag fix * fixed huggingface hub params * formatting * fix flagging tests * add a try / catch
This commit is contained in:
parent
862a8c7c71
commit
61d2f15562
@ -40,6 +40,7 @@ By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3011](https://git
|
||||
* Fix bug where the Image component could not serialize image urls by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2957](https://github.com/gradio-app/gradio/pull/2957)
|
||||
* Fix forwarding for guides after SEO renaming by [@aliabd](https://github.com/aliabd) in [PR 3017](https://github.com/gradio-app/gradio/pull/3017)
|
||||
* Switch all pages on the website to use latest stable gradio by [@aliabd](https://github.com/aliabd) in [PR 3016](https://github.com/gradio-app/gradio/pull/3016)
|
||||
* Fix bug related to deprecated parameters in `huggingface_hub` for the HuggingFaceDatasetSaver in [PR 3025](https://github.com/gradio-app/gradio/pull/3025)
|
||||
* Added better support for symlinks in the way absolute paths are resolved by [@abidlabs](https://github.com/abidlabs) in [PR 3037](https://github.com/gradio-app/gradio/pull/3037)
|
||||
* Fix several minor frontend bugs (loading animation, examples as gallery) frontend [@aliabid94](https://github.com/3026) in [PR 2961](https://github.com/gradio-app/gradio/pull/3026).
|
||||
|
||||
|
@ -7,9 +7,12 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from distutils.version import StrictVersion
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
|
||||
import pkg_resources
|
||||
|
||||
import gradio as gr
|
||||
from gradio import encryptor, utils
|
||||
from gradio.documentation import document, set_documentation_group
|
||||
@ -326,8 +329,20 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
"Package `huggingface_hub` not found is needed "
|
||||
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
|
||||
)
|
||||
hh_version = pkg_resources.get_distribution("huggingface_hub").version
|
||||
try:
|
||||
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
|
||||
raise ImportError(
|
||||
"The `huggingface_hub` package must be version 0.6.0 or higher"
|
||||
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub --upgrade'."
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
repo_id = huggingface_hub.get_full_repo_name(
|
||||
self.dataset_name, token=self.hf_token
|
||||
)
|
||||
path_to_dataset_repo = huggingface_hub.create_repo(
|
||||
name=self.dataset_name,
|
||||
repo_id=repo_id,
|
||||
token=self.hf_token,
|
||||
private=self.dataset_private,
|
||||
repo_type="dataset",
|
||||
@ -409,7 +424,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hf_foken: str,
|
||||
hf_token: str,
|
||||
dataset_name: str,
|
||||
organization: str | None = None,
|
||||
private: bool = False,
|
||||
@ -428,7 +443,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
verbose (bool): Whether to print out the status of the dataset
|
||||
creation.
|
||||
"""
|
||||
self.hf_foken = hf_foken
|
||||
self.hf_token = hf_token
|
||||
self.dataset_name = dataset_name
|
||||
self.organization_name = organization
|
||||
self.dataset_private = private
|
||||
@ -448,9 +463,21 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
"Package `huggingface_hub` not found is needed "
|
||||
"for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
|
||||
)
|
||||
hh_version = pkg_resources.get_distribution("huggingface_hub").version
|
||||
try:
|
||||
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
|
||||
raise ImportError(
|
||||
"The `huggingface_hub` package must be version 0.6.0 or higher"
|
||||
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub --upgrade'."
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
repo_id = huggingface_hub.get_full_repo_name(
|
||||
self.dataset_name, token=self.hf_token
|
||||
)
|
||||
path_to_dataset_repo = huggingface_hub.create_repo(
|
||||
name=self.dataset_name,
|
||||
token=self.hf_foken,
|
||||
repo_id=repo_id,
|
||||
token=self.hf_token,
|
||||
private=self.dataset_private,
|
||||
repo_type="dataset",
|
||||
exist_ok=True,
|
||||
@ -462,7 +489,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
|
||||
self.repo = huggingface_hub.Repository(
|
||||
local_dir=str(self.dataset_dir),
|
||||
clone_from=path_to_dataset_repo,
|
||||
use_auth_token=self.hf_foken,
|
||||
use_auth_token=self.hf_token,
|
||||
)
|
||||
self.repo.git_pull(lfs=True)
|
||||
|
||||
|
@ -43,6 +43,7 @@ class TestSimpleFlagging:
|
||||
|
||||
class TestHuggingFaceDatasetSaver:
|
||||
def test_saver_setup(self):
|
||||
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
|
||||
@ -51,6 +52,7 @@ class TestHuggingFaceDatasetSaver:
|
||||
huggingface_hub.create_repo.assert_called_once()
|
||||
|
||||
def test_saver_flag(self):
|
||||
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@ -71,6 +73,7 @@ class TestHuggingFaceDatasetSaver:
|
||||
|
||||
class TestHuggingFaceDatasetJSONSaver:
|
||||
def test_saver_setup(self):
|
||||
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
flagger = flagging.HuggingFaceDatasetJSONSaver("test", "test")
|
||||
@ -79,6 +82,7 @@ class TestHuggingFaceDatasetJSONSaver:
|
||||
huggingface_hub.create_repo.assert_called_once()
|
||||
|
||||
def test_saver_flag(self):
|
||||
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
Loading…
x
Reference in New Issue
Block a user