2
0
mirror of https://github.com/gradio-app/gradio.git synced 2025-04-18 12:50:30 +08:00

Fix to HuggingFaceDatasetSaver (#3025)

* hf flag fix

* fixed huggingface hub params

* formatting

* fix flagging tests

* add a try / catch
This commit is contained in:
Abubakar Abid 2023-01-20 20:15:45 -08:00 committed by GitHub
parent 862a8c7c71
commit 61d2f15562
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 6 deletions

@ -40,6 +40,7 @@ By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3011](https://git
* Fix bug where the Image component could not serialize image urls by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2957](https://github.com/gradio-app/gradio/pull/2957)
* Fix forwarding for guides after SEO renaming by [@aliabd](https://github.com/aliabd) in [PR 3017](https://github.com/gradio-app/gradio/pull/3017)
* Switch all pages on the website to use latest stable gradio by [@aliabd](https://github.com/aliabd) in [PR 3016](https://github.com/gradio-app/gradio/pull/3016)
* Fix bug related to deprecated parameters in `huggingface_hub` for the HuggingFaceDatasetSaver in [PR 3025](https://github.com/gradio-app/gradio/pull/3025)
* Added better support for symlinks in the way absolute paths are resolved by [@abidlabs](https://github.com/abidlabs) in [PR 3037](https://github.com/gradio-app/gradio/pull/3037)
* Fix several minor frontend bugs (loading animation, examples as gallery) by [@aliabid94](https://github.com/aliabid94) in [PR 3026](https://github.com/gradio-app/gradio/pull/3026)

@ -7,9 +7,12 @@ import json
import os
import uuid
from abc import ABC, abstractmethod
from distutils.version import StrictVersion
from pathlib import Path
from typing import TYPE_CHECKING, Any, List
import pkg_resources
import gradio as gr
from gradio import encryptor, utils
from gradio.documentation import document, set_documentation_group
@ -326,8 +329,20 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
"Package `huggingface_hub` not found is needed "
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
)
hh_version = pkg_resources.get_distribution("huggingface_hub").version
try:
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
raise ImportError(
"The `huggingface_hub` package must be version 0.6.0 or higher"
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub --upgrade'."
)
except ValueError:
pass
repo_id = huggingface_hub.get_full_repo_name(
self.dataset_name, token=self.hf_token
)
path_to_dataset_repo = huggingface_hub.create_repo(
name=self.dataset_name,
repo_id=repo_id,
token=self.hf_token,
private=self.dataset_private,
repo_type="dataset",
@ -409,7 +424,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
def __init__(
self,
hf_foken: str,
hf_token: str,
dataset_name: str,
organization: str | None = None,
private: bool = False,
@ -428,7 +443,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
verbose (bool): Whether to print out the status of the dataset
creation.
"""
self.hf_foken = hf_foken
self.hf_token = hf_token
self.dataset_name = dataset_name
self.organization_name = organization
self.dataset_private = private
@ -448,9 +463,21 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
"Package `huggingface_hub` not found is needed "
"for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
)
hh_version = pkg_resources.get_distribution("huggingface_hub").version
try:
if StrictVersion(hh_version) < StrictVersion("0.6.0"):
raise ImportError(
"The `huggingface_hub` package must be version 0.6.0 or higher"
"for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub --upgrade'."
)
except ValueError:
pass
repo_id = huggingface_hub.get_full_repo_name(
self.dataset_name, token=self.hf_token
)
path_to_dataset_repo = huggingface_hub.create_repo(
name=self.dataset_name,
token=self.hf_foken,
repo_id=repo_id,
token=self.hf_token,
private=self.dataset_private,
repo_type="dataset",
exist_ok=True,
@ -462,7 +489,7 @@ class HuggingFaceDatasetJSONSaver(FlaggingCallback):
self.repo = huggingface_hub.Repository(
local_dir=str(self.dataset_dir),
clone_from=path_to_dataset_repo,
use_auth_token=self.hf_foken,
use_auth_token=self.hf_token,
)
self.repo.git_pull(lfs=True)

@ -43,6 +43,7 @@ class TestSimpleFlagging:
class TestHuggingFaceDatasetSaver:
def test_saver_setup(self):
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
@ -51,6 +52,7 @@ class TestHuggingFaceDatasetSaver:
huggingface_hub.create_repo.assert_called_once()
def test_saver_flag(self):
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
with tempfile.TemporaryDirectory() as tmpdirname:
@ -71,6 +73,7 @@ class TestHuggingFaceDatasetSaver:
class TestHuggingFaceDatasetJSONSaver:
def test_saver_setup(self):
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
flagger = flagging.HuggingFaceDatasetJSONSaver("test", "test")
@ -79,6 +82,7 @@ class TestHuggingFaceDatasetJSONSaver:
huggingface_hub.create_repo.assert_called_once()
def test_saver_flag(self):
huggingface_hub.get_full_repo_name = MagicMock(return_value="test/test")
huggingface_hub.create_repo = MagicMock()
huggingface_hub.Repository = MagicMock()
with tempfile.TemporaryDirectory() as tmpdirname: