From 55c69891e6466fedb08e8b70294963d61f126704 Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Wed, 9 Mar 2022 22:46:52 +0000 Subject: [PATCH] changes --- demo/streaming_stt/run.py | 17 ++++---------- gradio/routes.py | 12 ++++++---- ui/packages/app/src/Interface.svelte | 22 ++++++++++++------- .../src/components/input/Audio/Audio.svelte | 22 ++++++++++++++----- 4 files changed, 42 insertions(+), 31 deletions(-) diff --git a/demo/streaming_stt/run.py b/demo/streaming_stt/run.py index 9efa21c521..da14e22e7a 100644 --- a/demo/streaming_stt/run.py +++ b/demo/streaming_stt/run.py @@ -19,14 +19,12 @@ def reformat_freq(sr, y): if sr not in ( 48000, 16000, - ): # Deepspeech only supports 16k, (we hackily convert 48k -> 16k) + ): # Deepspeech only supports 16k, (we convert 48k -> 16k) raise ValueError("Unsupported rate", sr) if sr == 48000: y = ( ((y / max(np.max(y), 1)) * 32767) - .reshape((-1, 4)) - .mean(axis=1) - .reshape((-1, 4)) + .reshape((-1, 3)) .mean(axis=1) .astype("int16") ) @@ -34,7 +32,7 @@ def reformat_freq(sr, y): return sr, y -def transcribe_stream(speech, stream): +def transcribe(speech, stream): _, y = reformat_freq(*speech) if stream is None: stream = model.createStream() @@ -42,11 +40,4 @@ def transcribe_stream(speech, stream): text = stream.intermediateDecode() return text, stream -def transcribe(speech): - _, y = reformat_freq(*speech) - stream = model.createStream() - stream.feedAudioContent(y) - text = stream.intermediateDecode() - return text - -gr.Interface(transcribe, ["microphone"], ["text"]).launch() +gr.Interface(transcribe, ["microphone", "state"], ["text", "state"], live=True).launch() diff --git a/gradio/routes.py b/gradio/routes.py index 36385ce4e1..9399fd2a87 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -72,6 +72,7 @@ class PredictBody(BaseModel): data: List[Any] state: Optional[Any] fn_index: Optional[int] + cleared: Optional[bool] class FlagData(BaseModel): @@ -252,10 +253,13 @@ def api_docs(request: Request): async def predict(body: PredictBody, username: str = Depends(get_current_user)): if app.launchable.stateful: session_hash = body.session_hash - state = app.state_holder.get( - (session_hash, "state"), app.launchable.state_default - ) - body.state = state + if body.cleared: + body.state = None + else: + state = app.state_holder.get( + (session_hash, "state"), app.launchable.state_default + ) + body.state = state try: output = await run_in_threadpool(app.launchable.process_api, body, username) if app.launchable.stateful: diff --git a/ui/packages/app/src/Interface.svelte b/ui/packages/app/src/Interface.svelte index 8165c6e0e5..eb2eb920d5 100644 --- a/ui/packages/app/src/Interface.svelte +++ b/ui/packages/app/src/Interface.svelte @@ -35,6 +35,7 @@ let queue_index: number | null = null; let initial_queue_index: number | null = null; let just_flagged: boolean = false; + let cleared_since_last_submit = false; const default_inputs: Array = input_components.map((component) => "default" in component ? component.default : null @@ -107,7 +108,7 @@ try { output = await fn( "predict", - { data: input_values }, + { data: input_values, cleared: cleared_since_last_submit }, queue, queueCallback ); @@ -128,6 +129,7 @@ } stopTimer(); output_values = output["data"]; + cleared_since_last_submit = false; for (let [i, value] of output_values.entries()) { if (output_components[i].name === "state") { for (let [j, input_component] of input_components.entries()) { @@ -158,6 +160,7 @@ output_values = deepCopy(default_outputs); interpret_mode = false; state = "START"; + cleared_since_last_submit = true; stopTimer(); }; const flag = (flag_option: string | null) => { @@ -223,6 +226,7 @@ {...input_component} {theme} {static_src} + {live} value={input_values[i]} interpretation={interpret_mode ? interpretation_values[i] @@ -240,19 +244,21 @@ > {$_("interface.clear")} - + {#if !live} + + {/if}
{#if state !== "START"}
diff --git a/ui/packages/app/src/components/input/Audio/Audio.svelte b/ui/packages/app/src/components/input/Audio/Audio.svelte index 5ec6c7f6e6..9f52fd09ea 100644 --- a/ui/packages/app/src/components/input/Audio/Audio.svelte +++ b/ui/packages/app/src/components/input/Audio/Audio.svelte @@ -25,7 +25,8 @@ let player; let inited = false; let crop_values = [0, 100]; - let converting_blob = false; + let submitting_data = false; + let record_interval; async function generate_data(): Promise<{ data: string; @@ -54,16 +55,15 @@ recorder.addEventListener("dataavailable", async (event) => { audio_chunks.push(event.data); - if (live && !converting_blob) { - converting_blob = true; + if (live && !submitting_data) { + submitting_data = true; await setValue(await generate_data()); - converting_blob = false; + submitting_data = false; + audio_chunks = []; } }); recorder.addEventListener("stop", async () => { - recording = false; - if (!live) { setValue(await generate_data()); } @@ -77,6 +77,12 @@ if (!inited) await prepare_audio(); recorder.start(); + if (live) { + record_interval = setInterval(() => { + recorder.stop(); + recorder.start(); + }, 1000) + } } onDestroy(() => { @@ -86,7 +92,11 @@ }); const stop = () => { + recording = false; recorder.stop(); + if (live) { + clearInterval(record_interval); + } }; function clear() {