diff --git a/.changeset/weak-streets-check.md b/.changeset/weak-streets-check.md new file mode 100644 index 0000000000..4ba7eac331 --- /dev/null +++ b/.changeset/weak-streets-check.md @@ -0,0 +1,6 @@ +--- +"@gradio/audio": minor +"gradio": minor +--- + +fix:Add sample rate config option to `gr.Audio()` diff --git a/gradio/components/audio.py b/gradio/components/audio.py index 46b4218b34..589e0d1b4f 100644 --- a/gradio/components/audio.py +++ b/gradio/components/audio.py @@ -30,6 +30,7 @@ class WaveformOptions: show_recording_waveform: Whether to show the waveform when recording audio. Defaults to True. show_controls: Whether to show the standard HTML audio player below the waveform when recording audio or playing recorded audio. Defaults to False. skip_length: The percentage (between 0 and 100) of the audio to skip when clicking on the skip forward / skip backward buttons. Defaults to 5. + sample_rate: The output sample rate (in Hz) of the audio after editing. Defaults to 44100. """ waveform_color: str = "#9ca3af" @@ -37,6 +38,7 @@ class WaveformOptions: show_recording_waveform: bool = True show_controls: bool = False skip_length: int | float = 5 + sample_rate: int = 44100 @document() @@ -161,11 +163,10 @@ class Audio( self.editable = editable if waveform_options is None: self.waveform_options = WaveformOptions() - self.waveform_options = ( - WaveformOptions(**waveform_options) - if isinstance(waveform_options, dict) - else waveform_options - ) + elif isinstance(waveform_options, dict): + self.waveform_options = WaveformOptions(**waveform_options) + else: + self.waveform_options = waveform_options self.min_length = min_length self.max_length = max_length super().__init__( diff --git a/js/audio/Index.svelte b/js/audio/Index.svelte index a34278e524..b3df6b0760 100644 --- a/js/audio/Index.svelte +++ b/js/audio/Index.svelte @@ -104,7 +104,8 @@ dragToSeek: true, normalize: true, minPxPerSec: 20, - mediaControls: waveform_options.show_controls + mediaControls: waveform_options.show_controls, + sampleRate: waveform_options.sample_rate || 44100 }; const trim_region_settings = { diff --git a/js/audio/player/AudioPlayer.svelte b/js/audio/player/AudioPlayer.svelte index 579fc22fab..b7bd5829e0 100644 --- a/js/audio/player/AudioPlayer.svelte +++ b/js/audio/player/AudioPlayer.svelte @@ -112,13 +112,16 @@ mode = ""; const decodedData = waveform?.getDecodedData(); if (decodedData) - await process_audio(decodedData, start, end).then( - async (trimmedBlob: Uint8Array) => { - await dispatch_blob([trimmedBlob], "change"); - waveform?.destroy(); - create_waveform(); - } - ); + await process_audio( + decodedData, + start, + end, + waveform_settings.sampleRate + ).then(async (trimmedBlob: Uint8Array) => { + await dispatch_blob([trimmedBlob], "change"); + waveform?.destroy(); + container.innerHTML = ""; + }); dispatch("edit"); }; diff --git a/js/audio/recorder/AudioRecorder.svelte b/js/audio/recorder/AudioRecorder.svelte index aedc24c419..a0608a9b75 100644 --- a/js/audio/recorder/AudioRecorder.svelte +++ b/js/audio/recorder/AudioRecorder.svelte @@ -82,7 +82,9 @@ timing = false; clearInterval(interval); const array_buffer = await blob.arrayBuffer(); - const context = new AudioContext(); + const context = new AudioContext({ + sampleRate: waveform_settings.sampleRate + }); const audio_buffer = await context.decodeAudioData(array_buffer); if (audio_buffer) diff --git a/js/audio/shared/audioBufferToWav.ts b/js/audio/shared/audioBufferToWav.ts index 83872e6c0f..578821cf04 100644 --- a/js/audio/shared/audioBufferToWav.ts +++ b/js/audio/shared/audioBufferToWav.ts @@ -47,7 +47,9 @@ export function audioBufferToWav(audioBuffer: AudioBuffer): Uint8Array { for (let i = 0; i < audioBuffer.numberOfChannels; i++) { const channel = audioBuffer.getChannelData(i); for (let j = 0; j < channel.length; j++) { - view.setInt16(offset, channel[j] * 0xffff, true); + // Scaling Float32 to Int16 + const sample = Math.max(-1, Math.min(1, channel[j])); + view.setInt16(offset, sample * 0x7fff, true); offset += 2; } } diff --git a/js/audio/shared/types.ts b/js/audio/shared/types.ts index 681571b19b..0c96415536 100644 --- a/js/audio/shared/types.ts +++ b/js/audio/shared/types.ts @@ -5,4 +5,5 @@ export type WaveformOptions = { skip_length?: number; trim_region_color?: string; show_recording_waveform?: boolean; + sample_rate?: number; }; diff --git a/js/audio/shared/utils.ts b/js/audio/shared/utils.ts index 033949766a..b01b46ebfe 100644 --- a/js/audio/shared/utils.ts +++ b/js/audio/shared/utils.ts @@ -1,5 +1,4 @@ import type WaveSurfer from "wavesurfer.js"; -import Regions from "wavesurfer.js/dist/plugins/regions.js"; import { audioBufferToWav } from "./audioBufferToWav"; export interface LoadedParams { @@ -18,11 +17,14 @@ export function blob_to_data_url(blob: Blob): Promise { export const process_audio = async ( audioBuffer: AudioBuffer, start?: number, - end?: number + end?: number, + waveform_sample_rate?: number ): Promise => { - const audioContext = new AudioContext(); + const audioContext = new AudioContext({ + sampleRate: waveform_sample_rate || audioBuffer.sampleRate + }); const numberOfChannels = audioBuffer.numberOfChannels; - const sampleRate = audioBuffer.sampleRate; + const sampleRate = waveform_sample_rate || audioBuffer.sampleRate; let trimmedLength = audioBuffer.length; let startOffset = 0; diff --git a/test/test_components.py b/test/test_components.py index eaf110504d..a0d4765bd2 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -836,7 +836,14 @@ class TestAudio: "streamable": False, "max_length": None, "min_length": None, - "waveform_options": None, + "waveform_options": { + "sample_rate": 44100, + "show_controls": False, + "show_recording_waveform": True, + "skip_length": 5, + "waveform_color": "#9ca3af", + "waveform_progress_color": "#f97316", + }, "_selectable": False, } assert audio_input.preprocess(None) is None @@ -881,7 +888,14 @@ class TestAudio: "format": "wav", "streamable": False, "sources": ["upload", "microphone"], - "waveform_options": None, + "waveform_options": { + "sample_rate": 44100, + "show_controls": False, + "show_recording_waveform": True, + "skip_length": 5, + "waveform_color": "#9ca3af", + "waveform_progress_color": "#f97316", + }, "_selectable": False, }