mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
added code for previewing audio
This commit is contained in:
parent
6765928355
commit
b8b02b65da
@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
import csv
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
@ -245,6 +246,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
path_to_dataset_repo = huggingface_hub.create_repo(
|
||||
name=self.dataset_name, token=self.hf_foken,
|
||||
private=self.dataset_private, repo_type="dataset", exist_ok=True)
|
||||
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
|
||||
self.flagging_dir = flagging_dir
|
||||
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
|
||||
self.repo = huggingface_hub.Repository(
|
||||
@ -254,6 +256,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
|
||||
#Should filename be user-specified?
|
||||
self.log_file = os.path.join(self.dataset_dir, "data.csv")
|
||||
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
|
||||
|
||||
def flag(
|
||||
self,
|
||||
@ -265,20 +268,34 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
username: Optional[str] = None
|
||||
) -> int:
|
||||
is_new = not os.path.exists(self.log_file)
|
||||
infos = {"flagged": {"features": {}}}
|
||||
|
||||
with open(self.log_file, "a", newline="") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
|
||||
# File previews for certain input and output types
|
||||
file_preview_types = (gr.inputs.Audio, gr.outputs.Audio)
|
||||
file_preview_types = {
|
||||
gr.inputs.Audio: "Audio",
|
||||
gr.outputs.Audio: "Audio"
|
||||
}
|
||||
|
||||
# Generate the headers
|
||||
if is_new:
|
||||
headers = []
|
||||
for component in interface.input_components + interface.output_components:
|
||||
headers.append(component.label)
|
||||
infos["flagged"]["features"][component.label] = {
|
||||
"dtype": "string",
|
||||
"_type": "Value"
|
||||
}
|
||||
if isinstance(component, file_preview_types):
|
||||
headers.append(interface.label + " file")
|
||||
for _component, _type in file_preview_types.items():
|
||||
if isinstance(component, _component):
|
||||
infos["flagged"]["features"][component.label] = {
|
||||
"_type": _type
|
||||
}
|
||||
break
|
||||
if interface.flagging_options is not None:
|
||||
headers.append("flag")
|
||||
writer.writerow(headers)
|
||||
@ -289,16 +306,23 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
|
||||
filepath = input.save_flagged(self.dataset_dir, input.label, input_data[i], None)
|
||||
csv_data.append(filepath)
|
||||
if isinstance(component, file_preview_types):
|
||||
csv_data.append("https://huggingface.co/datasets/abidlabs/test-audio-1/resolve/main/x/0.wav")
|
||||
csv_data.append("{}/resolve/main/{}".format(
|
||||
self.path_to_dataset_repo, filepath))
|
||||
for i, output in enumerate(interface.output_components):
|
||||
csv_data.append(output.save_flagged(self.dataset_dir, interface.config["output_components"][i]["label"], output_data[i], None) if
|
||||
output_data[i] is not None else "")
|
||||
if isinstance(component, file_preview_types):
|
||||
csv_data.append("{}/resolve/main/{}".format(
|
||||
self.path_to_dataset_repo, filepath))
|
||||
if flag_option is not None:
|
||||
csv_data.append(flag_option)
|
||||
|
||||
# Write the rows
|
||||
writer.writerow(csv_data)
|
||||
|
||||
if is_new:
|
||||
json.dump(infos, open(self.infos_file, "w"))
|
||||
|
||||
# return number of samples in dataset
|
||||
with open(self.log_file, "r") as csvfile:
|
||||
line_count = len([None for row in csv.reader(csvfile)]) - 1
|
||||
|
Loading…
Reference in New Issue
Block a user