Added code for previewing flagged audio files in HuggingFaceDatasetSaver

This commit is contained in:
Abubakar Abid 2022-01-19 10:40:28 -06:00
parent 6765928355
commit b8b02b65da

View File

@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
import csv
import datetime
import io
import json
import os
from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING
@ -245,6 +246,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
path_to_dataset_repo = huggingface_hub.create_repo(
name=self.dataset_name, token=self.hf_foken,
private=self.dataset_private, repo_type="dataset", exist_ok=True)
self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
self.flagging_dir = flagging_dir
self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
self.repo = huggingface_hub.Repository(
@ -254,6 +256,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
#Should filename be user-specified?
self.log_file = os.path.join(self.dataset_dir, "data.csv")
self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")
def flag(
self,
@ -265,20 +268,34 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
username: Optional[str] = None
) -> int:
is_new = not os.path.exists(self.log_file)
infos = {"flagged": {"features": {}}}
with open(self.log_file, "a", newline="") as csvfile:
writer = csv.writer(csvfile)
# File previews for certain input and output types
file_preview_types = (gr.inputs.Audio, gr.outputs.Audio)
file_preview_types = {
gr.inputs.Audio: "Audio",
gr.outputs.Audio: "Audio"
}
# Generate the headers
if is_new:
headers = []
for component in interface.input_components + interface.output_components:
headers.append(component.label)
infos["flagged"]["features"][component.label] = {
"dtype": "string",
"_type": "Value"
}
if isinstance(component, file_preview_types):
headers.append(interface.label + " file")
for _component, _type in file_preview_types.items():
if isinstance(component, _component):
infos["flagged"]["features"][component.label] = {
"_type": _type
}
break
if interface.flagging_options is not None:
headers.append("flag")
writer.writerow(headers)
@ -289,16 +306,23 @@ class HuggingFaceDatasetSaver(FlaggingCallback):
filepath = input.save_flagged(self.dataset_dir, input.label, input_data[i], None)
csv_data.append(filepath)
if isinstance(component, file_preview_types):
csv_data.append("https://huggingface.co/datasets/abidlabs/test-audio-1/resolve/main/x/0.wav")
csv_data.append("{}/resolve/main/{}".format(
self.path_to_dataset_repo, filepath))
for i, output in enumerate(interface.output_components):
csv_data.append(output.save_flagged(self.dataset_dir, interface.config["output_components"][i]["label"], output_data[i], None) if
output_data[i] is not None else "")
if isinstance(component, file_preview_types):
csv_data.append("{}/resolve/main/{}".format(
self.path_to_dataset_repo, filepath))
if flag_option is not None:
csv_data.append(flag_option)
# Write the rows
writer.writerow(csv_data)
if is_new:
json.dump(infos, open(self.infos_file, "w"))
# return number of samples in dataset
with open(self.log_file, "r") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1