Live audio streaming output (#5077)

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
aliabid94 2023-08-08 15:08:28 -07:00 committed by GitHub
parent cd1353fa3e
commit 667875b244
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 187 additions and 11 deletions

View File

@ -0,0 +1,40 @@
---
"@gradio/upload": patch
"gradio": patch
---
fix:Live audio streaming output
highlight:
#### Now supports loading streamed outputs
Allows users to use generators to stream audio out, yielding consecutive chunks of audio. Requires `streaming=True` to be set on the output audio.
```python
import gradio as gr
from pydub import AudioSegment
def stream_audio(audio_file):
audio = AudioSegment.from_mp3(audio_file)
i = 0
chunk_size = 3000
while chunk_size*i < len(audio):
chunk = audio[chunk_size*i:chunk_size*(i+1)]
i += 1
if chunk:
file = f"/tmp/{i}.mp3"
chunk.export(file, format="mp3")
yield file
demo = gr.Interface(
fn=stream_audio,
inputs=gr.Audio(type="filepath", label="Audio file to stream"),
outputs=gr.Audio(autoplay=True, streaming=True),
)
demo.queue().launch()
```
From the backend, streamed outputs are served from the `/stream/` endpoint instead of the `/file/` endpoint. Currently just used to serve audio streaming output. The output JSON will have `is_stream`: `true`, instead of `is_file`: `true` in the file data object.

View File

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "\n", "def stream_audio(audio_file):\n", " audio = AudioSegment.from_mp3(audio_file)\n", " i = 0\n", " chunk_size = 3000\n", " \n", " while chunk_size*i < len(audio):\n", " chunk = audio[chunk_size*i:chunk_size*(i+1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.mp3\"\n", " chunk.export(file, format=\"mp3\") \n", " yield file\n", " \n", "demo = gr.Interface(\n", " fn=stream_audio,\n", " inputs=gr.Audio(type=\"filepath\", label=\"Audio file to stream\"),\n", " outputs=gr.Audio(autoplay=True, streaming=True),\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -0,0 +1,24 @@
import gradio as gr
from pydub import AudioSegment
def stream_audio(audio_file):
    """Yield file paths of consecutive MP3 chunks cut from ``audio_file``.

    Parameters:
        audio_file: path to an MP3 file (the input component is configured with type="filepath").
    Yields:
        the path of each exported, non-empty chunk, in playback order.
    """
    # Local imports keep this demo block self-contained.
    import os
    import tempfile

    audio = AudioSegment.from_mp3(audio_file)
    chunk_size = 3000  # slice width; pydub slices by milliseconds, so presumably 3 s per chunk

    # Each call writes into its own directory. A fixed "/tmp/{i}.mp3" path is
    # overwritten when two streams run at once and does not exist on Windows.
    chunk_dir = tempfile.mkdtemp(prefix="stream_audio_")

    for i, start in enumerate(range(0, len(audio), chunk_size), start=1):
        chunk = audio[start:start + chunk_size]
        if not chunk:
            continue
        file = os.path.join(chunk_dir, f"{i}.mp3")
        chunk.export(file, format="mp3")
        yield file
# Wire the generator into an Interface. The input hands the function a file
# path (type="filepath"); the output must set streaming=True so the chunks
# yielded by `stream_audio` are played back as one continuous stream.
demo = gr.Interface(
    fn=stream_audio,
    inputs=gr.Audio(type="filepath", label="Audio file to stream"),
    outputs=gr.Audio(autoplay=True, streaming=True),
)

# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    demo.queue().launch()

View File

@ -12,6 +12,7 @@ import time
import warnings
import webbrowser
from abc import abstractmethod
from collections import defaultdict
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Literal, cast
@ -707,6 +708,7 @@ class Blocks(BlockContext):
self.share = False
self.enable_queue = None
self.max_threads = 40
self.pending_streams = defaultdict(dict)
self.show_error = True
if css is not None and os.path.exists(css):
with open(css) as css_file:
@ -1333,6 +1335,27 @@ Received outputs:
return output
def handle_streaming_outputs(
    self, fn_index: int, data: list, session_hash: str | None, run: int | None
) -> list:
    """Queue streamed output chunks and replace them with stream references.

    For every output of dependency `fn_index` whose block is a
    `StreamableOutput` with `streaming` enabled, the value is converted with
    `block.stream_output`, appended to
    `self.pending_streams[session_hash][run][output_id]`, and replaced in
    `data` by a dict naming that stream.

    Parameters:
        fn_index: index into `self.dependencies` of the function that produced `data`.
        data: postprocessed output values, one per output component; modified in place.
        session_hash: session identifier; if None, `data` is returned untouched.
        run: identifier of the current generator run; if None, `data` is returned untouched.
    Returns:
        the same `data` list, with streamed entries replaced.
    """
    if session_hash is None or run is None:
        return data

    # Imported lazily, after the early return, exactly as the original did.
    from gradio.events import StreamableOutput

    target_ids = self.dependencies[fn_index]["outputs"]
    for position, component_id in enumerate(target_ids):
        component = self.blocks[component_id]
        if not (isinstance(component, StreamableOutput) and component.streaming):
            continue

        chunk = component.stream_output(data[position])

        session_streams = self.pending_streams[session_hash]
        if run not in session_streams:
            session_streams[run] = defaultdict(list)
        session_streams[run][component_id].append(chunk)

        # The client resolves this reference against the /stream/ endpoint.
        data[position] = {
            "name": f"{session_hash}/{run}/{component_id}",
            "is_stream": True,
        }
    return data
async def process_api(
self,
fn_index: int,
@ -1340,6 +1363,7 @@ Received outputs:
state: dict[int, Any],
request: routes.Request | list[routes.Request] | None = None,
iterators: dict[int, Any] | None = None,
session_hash: str | None = None,
event_id: str | None = None,
event_data: EventData | None = None,
) -> dict[str, Any]:
@ -1391,10 +1415,15 @@ Received outputs:
else:
inputs = self.preprocess_data(fn_index, inputs, state)
iterator = iterators.get(fn_index, None) if iterators else None
was_generating = iterator is not None
result = await self.call_function(
fn_index, inputs, iterator, request, event_id, event_data
)
data = self.postprocess_data(fn_index, result["prediction"], state)
if result["is_generating"] or was_generating:
data = self.handle_streaming_outputs(
fn_index, data, session_hash, id(iterator)
)
is_generating, iterator = result["is_generating"], result["iterator"]
block_fn.total_runtime += result["duration"]

View File

@ -7,6 +7,7 @@ from pathlib import Path
from typing import Any, Callable, Literal
import numpy as np
import requests
from gradio_client import media_data
from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
@ -20,6 +21,7 @@ from gradio.events import (
Playable,
Recordable,
Streamable,
StreamableOutput,
Uploadable,
)
from gradio.interpretation import TokenInterpretable
@ -34,6 +36,7 @@ class Audio(
Playable,
Recordable,
Streamable,
StreamableOutput,
Uploadable,
IOComponent,
FileSerializable,
@ -52,7 +55,7 @@ class Audio(
self,
value: str | Path | tuple[int, np.ndarray] | Callable | None = None,
*,
source: Literal["upload", "microphone"] = "upload",
source: Literal["upload", "microphone"] | None = None,
type: Literal["numpy", "filepath"] = "numpy",
label: str | None = None,
every: float | None = None,
@ -84,7 +87,7 @@ class Audio(
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
interactive: if True, will allow users to upload and edit a audio file; if False, can only be used to play audio. If not provided, this is inferred based on whether the component is used as an input or output.
visible: If False, component will be hidden.
streaming: If set to True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'microphone'.
streaming: If set to True when used in a `live` interface as an input, will automatically stream microphone audio. When used as an output, takes audio chunks yielded from the backend and combines them into one streaming audio output.
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
format: The file format to save audio files. Either 'wav' or 'mp3'. wav files are lossless but will tend to be larger files. mp3 files tend to be smaller. Default is wav. Applies both when this component is used as an input (when `type` is "format") and when this component is used as an output.
@ -93,6 +96,7 @@ class Audio(
show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.
"""
valid_sources = ["upload", "microphone"]
source = source if source else ("microphone" if streaming else "upload")
if source not in valid_sources:
raise ValueError(
f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}"
@ -105,7 +109,7 @@ class Audio(
)
self.type = type
self.streaming = streaming
if streaming and source != "microphone":
if streaming and source == "upload":
raise ValueError(
"Audio streaming only available if source is 'microphone'."
)
@ -340,6 +344,18 @@ class Audio(
file_path = self.make_temp_copy_if_needed(y)
return {"name": file_path, "data": None, "is_file": True}
def stream_output(self, y):
    """Return the raw bytes of one audio chunk so it can be streamed.

    Parameters:
        y: file-data dict whose "name" entry is either an HTTP-like URL or a local file path, or None.
    Returns:
        the chunk's bytes, or None when `y` is None.
    Raises:
        requests.HTTPError: if a remote chunk is answered with an error status.
    """
    if y is None:
        return None

    location = y["name"]
    if client_utils.is_http_url_like(location):
        response = requests.get(location)
        # Fail loudly instead of streaming an HTTP error page as audio data.
        response.raise_for_status()
        return response.content

    # Local file: read it whole. The result is kept in a plain local name so
    # the builtin `bytes` is not shadowed.
    with open(location, "rb") as chunk_file:
        return chunk_file.read()
def check_streamable(self):
if self.source != "microphone":
raise ValueError(

View File

@ -270,6 +270,14 @@ class Streamable(EventListener):
pass
class StreamableOutput(EventListener):
    """Interface for output components whose value can be served as a stream of byte chunks."""

    def __init__(self):
        # Annotation only, no value is assigned here: implementing components
        # are expected to set `streaming` themselves.
        self.streaming: bool

    def stream_output(self, y) -> bytes:
        """Convert one postprocessed output value `y` into the bytes to stream.

        Must be overridden by implementing components; this base version
        always raises NotImplementedError.
        """
        # NOTE(review): Audio.stream_output returns None when `y` is None, so
        # the effective return type is wider than `bytes` — confirm.
        raise NotImplementedError
@document("*start_recording", "*stop_recording", inherit=True)
class Recordable(EventListener):
def __init__(self):

View File

@ -17,6 +17,7 @@ import os
import posixpath
import secrets
import tempfile
import time
import traceback
from asyncio import TimeoutError as AsyncTimeOutError
from collections import defaultdict
@ -386,6 +387,41 @@ class App(FastAPI):
return response
return FileResponse(abs_path, headers={"Accept-Ranges": "bytes"})
@app.get(
    "/stream/{session_hash}/{run}/{component_id}",
    dependencies=[Depends(login_check)],
)
async def stream(
    session_hash: str, run: int, component_id: int, request: fastapi.Request
):
    """Serve the pending chunks of one streamed output as a streaming response.

    Parameters:
        session_hash: session that owns the stream.
        run: identifier of the generator run that produces the chunks.
        component_id: id of the output component being streamed.
        request: the incoming request (unused in the body).
    Returns:
        a StreamingResponse that yields chunks as they are queued.
    Raises:
        HTTPException: 404 if no stream is registered under the given ids.
    """
    # Use .get() at every level: `pending_streams` is a defaultdict, so
    # indexing it with an unknown session hash would insert an empty entry
    # for every bogus URL that is requested.
    chunks: list | None = (
        app.get_blocks()
        .pending_streams.get(session_hash, {})
        .get(run, {})
        .get(component_id, None)
    )
    if chunks is None:
        raise HTTPException(404, "Stream not found.")

    def stream_wrapper():
        check_stream_rate = 0.01  # seconds between polls of the chunk list
        max_wait_time = 120  # maximum wait between yields - assume generator thread has crashed otherwise.
        wait_time = 0
        while True:
            if not chunks:
                # Nothing queued yet: poll until a chunk arrives or we time out.
                if wait_time > max_wait_time:
                    return
                wait_time += check_stream_rate
                time.sleep(check_stream_rate)
                continue
            wait_time = 0
            next_chunk = chunks.pop(0)
            if next_chunk is None:
                # None is the end-of-stream sentinel.
                return
            yield next_chunk

    return StreamingResponse(stream_wrapper())
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
async def file_deprecated(path: str, request: fastapi.Request):
return await file(path, request)
@ -406,24 +442,25 @@ class App(FastAPI):
fn_index_inferred: int,
):
fn_index = body.fn_index
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
session_hash = getattr(body, "session_hash", None)
if session_hash is not None:
if session_hash not in app.state_holder:
app.state_holder[session_hash] = {
_id: deepcopy(getattr(block, "value", None))
for _id, block in app.get_blocks().blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body.session_hash]
session_state = app.state_holder[session_hash]
# The should_reset set keeps track of the fn_indices
# that have been cancelled. When a job is cancelled,
# the /reset route will mark the jobs as having been reset.
# That way if the cancel job finishes BEFORE the job being cancelled
# the job being cancelled will not overwrite the state of the iterator.
if fn_index in app.iterators_to_reset[body.session_hash]:
if fn_index in app.iterators_to_reset[session_hash]:
iterators = {}
app.iterators_to_reset[body.session_hash].remove(fn_index)
app.iterators_to_reset[session_hash].remove(fn_index)
else:
iterators = app.iterators[body.session_hash]
iterators = app.iterators[session_hash]
else:
session_state = {}
iterators = {}
@ -448,6 +485,7 @@ class App(FastAPI):
request=request,
state=session_state,
iterators=iterators,
session_hash=session_hash,
event_id=event_id,
event_data=event_data,
)
@ -457,6 +495,15 @@ class App(FastAPI):
if isinstance(output, Error):
raise output
except BaseException as error:
iterator = iterators.get(fn_index, None)
if iterator is not None: # close off any streams that are still open
run_id = id(iterator)
pending_streams: dict[int, list] = (
app.get_blocks().pending_streams[session_hash].get(run_id, {})
)
for stream in pending_streams.values():
stream.append(None)
show_error = app.get_blocks().show_error or isinstance(error, Error)
traceback.print_exc()
return JSONResponse(

View File

@ -19,4 +19,8 @@ The difference between `gr.Audio(source='microphone')` and `gr.Audio(source='mic
Here is example code of streaming images from the webcam.
$code_stream_frames
$code_stream_frames
Streaming can also be done in an output component. A `gr.Audio(streaming=True)` output component can take a stream of audio data yielded piece-wise by a generator function and combine the pieces into a single audio file.
$code_stream_audio_out

View File

@ -5,6 +5,7 @@ export interface FileData {
data: string;
blob?: File;
is_file?: boolean;
is_stream?: boolean;
mime_type?: string;
alt_text?: string;
}

View File

@ -47,6 +47,12 @@ export function normalise_file(
} else {
file.data = "/proxy=" + root_url + "file=" + file.name;
}
} else if (file.is_stream) {
if (root_url == null) {
file.data = root + "/stream/" + file.name;
} else {
file.data = "/proxy=" + root_url + "stream/" + file.name;
}
}
return file;
}