From b8b02b65da8e08e0bcc63664e7bdfd643b695914 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 19 Jan 2022 10:40:28 -0600 Subject: [PATCH] added code for previewing audio --- gradio/flagging.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/gradio/flagging.py b/gradio/flagging.py index a71b32e2c8..f13d1ff2ed 100644 --- a/gradio/flagging.py +++ b/gradio/flagging.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod import csv import datetime import io +import json import os from typing import Callable, Any, List, Optional, Tuple, TYPE_CHECKING @@ -245,6 +246,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): path_to_dataset_repo = huggingface_hub.create_repo( name=self.dataset_name, token=self.hf_foken, private=self.dataset_private, repo_type="dataset", exist_ok=True) + self.path_to_dataset_repo = path_to_dataset_repo # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10" self.flagging_dir = flagging_dir self.dataset_dir = os.path.join(flagging_dir, self.dataset_name) self.repo = huggingface_hub.Repository( @@ -254,6 +256,7 @@ class HuggingFaceDatasetSaver(FlaggingCallback): #Should filename be user-specified? self.log_file = os.path.join(self.dataset_dir, "data.csv") + self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json") def flag( self, @@ -265,20 +268,34 @@ class HuggingFaceDatasetSaver(FlaggingCallback): username: Optional[str] = None ) -> int: is_new = not os.path.exists(self.log_file) + infos = {"flagged": {"features": {}}} with open(self.log_file, "a", newline="") as csvfile: writer = csv.writer(csvfile) # File previews for certain input and output types - file_preview_types = (gr.inputs.Audio, gr.outputs.Audio) + file_preview_types = { + gr.inputs.Audio: "Audio", + gr.outputs.Audio: "Audio" + } # Generate the headers if is_new: headers = [] for component in interface.input_components + interface.output_components: headers.append(component.label) + infos["flagged"]["features"][component.label] = { + "dtype": "string", + "_type": "Value" + } if isinstance(component, file_preview_types): headers.append(interface.label + " file") + for _component, _type in file_preview_types.items(): + if isinstance(component, _component): + infos["flagged"]["features"][component.label] = { + "_type": _type + } + break if interface.flagging_options is not None: headers.append("flag") writer.writerow(headers) @@ -289,16 +306,23 @@ class HuggingFaceDatasetSaver(FlaggingCallback): filepath = input.save_flagged(self.dataset_dir, input.label, input_data[i], None) csv_data.append(filepath) if isinstance(component, file_preview_types): - csv_data.append("https://huggingface.co/datasets/abidlabs/test-audio-1/resolve/main/x/0.wav") + csv_data.append("{}/resolve/main/{}".format( + self.path_to_dataset_repo, filepath)) for i, output in enumerate(interface.output_components): csv_data.append(output.save_flagged(self.dataset_dir, interface.config["output_components"][i]["label"], output_data[i], None) if output_data[i] is not None else "") + if isinstance(component, file_preview_types): + csv_data.append("{}/resolve/main/{}".format( + self.path_to_dataset_repo, filepath)) if flag_option is not None: csv_data.append(flag_option) # Write the rows writer.writerow(csv_data) + if is_new: + json.dump(infos, open(self.infos_file, "w")) + # return number of samples in dataset with open(self.log_file, "r") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1