Add format argument to Audio (#4178)

* experimental

* Add test

* Rename to format

* Rename

* CHANGELOG

* Add to docstring

* Update gradio/components.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-05-12 16:56:07 -04:00 committed by GitHub
parent aecf8feb1d
commit 8deab23623
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 8 deletions

View File

@ -2,6 +2,7 @@
## New Features:
- Added `format` argument to `Audio` component by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 4178](https://github.com/gradio-app/gradio/pull/4178)
- Add JS client code snippets to use via api page by [@aliabd](https://github.com/aliabd) in [PR 3927](https://github.com/gradio-app/gradio/pull/3927).
No changes to highlight.

View File

@ -2306,6 +2306,7 @@ class Audio(
streaming: bool = False,
elem_id: str | None = None,
elem_classes: list[str] | str | None = None,
format: Literal["wav", "mp3"] = "wav",
**kwargs,
):
"""
@ -2321,6 +2322,7 @@ class Audio(
streaming: If set to True when used in a `live` interface, will automatically stream audio from the microphone. Only valid if source is 'microphone'.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
format: The file format to save audio files. Either 'wav' or 'mp3'. wav files are lossless but will tend to be larger files. mp3 files tend to be smaller. Default is wav. Applies both when this component is used as an input (when `type` is "filepath") and when this component is used as an output.
"""
valid_sources = ["upload", "microphone"]
if source not in valid_sources:
@ -2352,6 +2354,7 @@ class Audio(
**kwargs,
)
TokenInterpretable.__init__(self)
self.format = format
def get_config(self):
return {
@ -2427,8 +2430,11 @@ class Audio(
if self.type == "numpy":
return sample_rate, data
elif self.type == "filepath":
processing_utils.audio_to_file(sample_rate, data, output_file_name)
return output_file_name
output_file = str(Path(output_file_name).with_suffix(f".{self.format}"))
processing_utils.audio_to_file(
sample_rate, data, output_file, format=self.format
)
return output_file
else:
raise ValueError(
"Unknown type: "
@ -2527,8 +2533,10 @@ class Audio(
return {"name": y, "data": None, "is_file": True}
if isinstance(y, tuple):
sample_rate, data = y
file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
processing_utils.audio_to_file(sample_rate, data, file.name)
file = tempfile.NamedTemporaryFile(suffix=f".{self.format}", delete=False)
processing_utils.audio_to_file(
sample_rate, data, file.name, format=self.format
)
file_path = str(utils.abspath(file.name))
self.temp_files.add(file_path)
else:

View File

@ -724,7 +724,7 @@ def make_waveform(
audio = processing_utils.audio_from_file(audio)
else:
tmp_wav = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name)
processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name, format="wav")
audio_file = tmp_wav.name
duration = round(len(audio[1]) / audio[0], 4)

View File

@ -159,15 +159,16 @@ def audio_from_file(filename, crop_min=0, crop_max=100):
return audio.frame_rate, data
def audio_to_file(sample_rate, data, filename):
data = convert_to_16_bit_wav(data)
def audio_to_file(sample_rate, data, filename, format="wav"):
if format == "wav":
data = convert_to_16_bit_wav(data)
audio = AudioSegment(
data.tobytes(),
frame_rate=sample_rate,
sample_width=data.dtype.itemsize,
channels=(1 if len(data.shape) == 1 else data.shape[1]),
)
file = audio.export(filename, format="wav")
file = audio.export(filename, format=format)
file.close() # type: ignore

View File

@ -911,6 +911,16 @@ class TestAudio:
output = audio_input.preprocess(x_wav)
wavfile.read(output)
def test_prepost_process_to_mp3(self):
x_wav = deepcopy(media_data.BASE64_MICROPHONE)
audio_input = gr.Audio(type="filepath", format="mp3")
output = audio_input.preprocess(x_wav)
assert output.endswith("mp3")
output = audio_input.postprocess(
(48000, np.random.randint(-256, 256, (5, 3)).astype(np.int16))
)
assert output["name"].endswith("mp3")
class TestFile:
def test_component_functions(self):