mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-12 12:40:29 +08:00
Live audio streaming output (#5077)
* changes * add changeset * changes * changes * changes * changes * changes * changes * add changeset * changes * changes * changes * changes * changes * changes --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
This commit is contained in:
parent
cd1353fa3e
commit
667875b244
40
.changeset/famous-rice-taste.md
Normal file
40
.changeset/famous-rice-taste.md
Normal file
@ -0,0 +1,40 @@
|
||||
---
|
||||
"@gradio/upload": patch
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
fix:Live audio streaming output
|
||||
|
||||
highlight:
|
||||
|
||||
#### Now supports loading streamed outputs
|
||||
|
||||
Allows users to use generators to stream audio out, yielding consecutive chunks of audio. Requires `streaming=True` to be set on the output audio.
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
from pydub import AudioSegment
|
||||
|
||||
def stream_audio(audio_file):
|
||||
audio = AudioSegment.from_mp3(audio_file)
|
||||
i = 0
|
||||
chunk_size = 3000
|
||||
|
||||
while chunk_size*i < len(audio):
|
||||
chunk = audio[chunk_size*i:chunk_size*(i+1)]
|
||||
i += 1
|
||||
if chunk:
|
||||
file = f"/tmp/{i}.mp3"
|
||||
chunk.export(file, format="mp3")
|
||||
yield file
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=stream_audio,
|
||||
inputs=gr.Audio(type="filepath", label="Audio file to stream"),
|
||||
outputs=gr.Audio(autoplay=True, streaming=True),
|
||||
)
|
||||
|
||||
demo.queue().launch()
|
||||
```
|
||||
|
||||
From the backend, streamed outputs are served from the `/stream/` endpoint instead of the `/file/` endpoint. Currently just used to serve audio streaming output. The output JSON will have `is_stream`: `true`, instead of `is_file`: `true` in the file data object.
|
1
demo/stream_audio_out/run.ipynb
Normal file
1
demo/stream_audio_out/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "\n", "def stream_audio(audio_file):\n", " audio = AudioSegment.from_mp3(audio_file)\n", " i = 0\n", " chunk_size = 3000\n", " \n", " while chunk_size*i < len(audio):\n", " chunk = audio[chunk_size*i:chunk_size*(i+1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.mp3\"\n", " chunk.export(file, format=\"mp3\") \n", " yield file\n", " \n", "demo = gr.Interface(\n", " fn=stream_audio,\n", " inputs=gr.Audio(type=\"filepath\", label=\"Audio file to stream\"),\n", " outputs=gr.Audio(autoplay=True, streaming=True),\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
24
demo/stream_audio_out/run.py
Normal file
24
demo/stream_audio_out/run.py
Normal file
@ -0,0 +1,24 @@
|
||||
import gradio as gr
|
||||
from pydub import AudioSegment
|
||||
|
||||
def stream_audio(audio_file):
|
||||
audio = AudioSegment.from_mp3(audio_file)
|
||||
i = 0
|
||||
chunk_size = 3000
|
||||
|
||||
while chunk_size*i < len(audio):
|
||||
chunk = audio[chunk_size*i:chunk_size*(i+1)]
|
||||
i += 1
|
||||
if chunk:
|
||||
file = f"/tmp/{i}.mp3"
|
||||
chunk.export(file, format="mp3")
|
||||
yield file
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=stream_audio,
|
||||
inputs=gr.Audio(type="filepath", label="Audio file to stream"),
|
||||
outputs=gr.Audio(autoplay=True, streaming=True),
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.queue().launch()
|
@ -12,6 +12,7 @@ import time
|
||||
import warnings
|
||||
import webbrowser
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Literal, cast
|
||||
@ -707,6 +708,7 @@ class Blocks(BlockContext):
|
||||
self.share = False
|
||||
self.enable_queue = None
|
||||
self.max_threads = 40
|
||||
self.pending_streams = defaultdict(dict)
|
||||
self.show_error = True
|
||||
if css is not None and os.path.exists(css):
|
||||
with open(css) as css_file:
|
||||
@ -1333,6 +1335,27 @@ Received outputs:
|
||||
|
||||
return output
|
||||
|
||||
def handle_streaming_outputs(
|
||||
self, fn_index: int, data: list, session_hash: str | None, run: int | None
|
||||
) -> list:
|
||||
if session_hash is None or run is None:
|
||||
return data
|
||||
|
||||
from gradio.events import StreamableOutput
|
||||
|
||||
for i, output_id in enumerate(self.dependencies[fn_index]["outputs"]):
|
||||
block = self.blocks[output_id]
|
||||
if isinstance(block, StreamableOutput) and block.streaming:
|
||||
stream = block.stream_output(data[i])
|
||||
if run not in self.pending_streams[session_hash]:
|
||||
self.pending_streams[session_hash][run] = defaultdict(list)
|
||||
self.pending_streams[session_hash][run][output_id].append(stream)
|
||||
data[i] = {
|
||||
"name": f"{session_hash}/{run}/{output_id}",
|
||||
"is_stream": True,
|
||||
}
|
||||
return data
|
||||
|
||||
async def process_api(
|
||||
self,
|
||||
fn_index: int,
|
||||
@ -1340,6 +1363,7 @@ Received outputs:
|
||||
state: dict[int, Any],
|
||||
request: routes.Request | list[routes.Request] | None = None,
|
||||
iterators: dict[int, Any] | None = None,
|
||||
session_hash: str | None = None,
|
||||
event_id: str | None = None,
|
||||
event_data: EventData | None = None,
|
||||
) -> dict[str, Any]:
|
||||
@ -1391,10 +1415,15 @@ Received outputs:
|
||||
else:
|
||||
inputs = self.preprocess_data(fn_index, inputs, state)
|
||||
iterator = iterators.get(fn_index, None) if iterators else None
|
||||
was_generating = iterator is not None
|
||||
result = await self.call_function(
|
||||
fn_index, inputs, iterator, request, event_id, event_data
|
||||
)
|
||||
data = self.postprocess_data(fn_index, result["prediction"], state)
|
||||
if result["is_generating"] or was_generating:
|
||||
data = self.handle_streaming_outputs(
|
||||
fn_index, data, session_hash, id(iterator)
|
||||
)
|
||||
is_generating, iterator = result["is_generating"], result["iterator"]
|
||||
|
||||
block_fn.total_runtime += result["duration"]
|
||||
|
@ -7,6 +7,7 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from gradio_client import media_data
|
||||
from gradio_client import utils as client_utils
|
||||
from gradio_client.documentation import document, set_documentation_group
|
||||
@ -20,6 +21,7 @@ from gradio.events import (
|
||||
Playable,
|
||||
Recordable,
|
||||
Streamable,
|
||||
StreamableOutput,
|
||||
Uploadable,
|
||||
)
|
||||
from gradio.interpretation import TokenInterpretable
|
||||
@ -34,6 +36,7 @@ class Audio(
|
||||
Playable,
|
||||
Recordable,
|
||||
Streamable,
|
||||
StreamableOutput,
|
||||
Uploadable,
|
||||
IOComponent,
|
||||
FileSerializable,
|
||||
@ -52,7 +55,7 @@ class Audio(
|
||||
self,
|
||||
value: str | Path | tuple[int, np.ndarray] | Callable | None = None,
|
||||
*,
|
||||
source: Literal["upload", "microphone"] = "upload",
|
||||
source: Literal["upload", "microphone"] | None = None,
|
||||
type: Literal["numpy", "filepath"] = "numpy",
|
||||
label: str | None = None,
|
||||
every: float | None = None,
|
||||
@ -84,7 +87,7 @@ class Audio(
|
||||
min_width: minimum pixel width, will wrap if not sufficient screen space to satisfy this value. If a certain scale value results in this Component being narrower than min_width, the min_width parameter will be respected first.
|
||||
interactive: if True, will allow users to upload and edit a audio file; if False, can only be used to play audio. If not provided, this is inferred based on whether the component is used as an input or output.
|
||||
visible: If False, component will be hidden.
|
||||
streaming: If set to True when used in a `live` interface, will automatically stream webcam feed. Only valid is source is 'microphone'.
|
||||
streaming: If set to True when used in a `live` interface as an input, will automatically stream webcam feed. When used set as an output, takes audio chunks yield from the backend and combines them into one streaming audio output.
|
||||
elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
|
||||
format: The file format to save audio files. Either 'wav' or 'mp3'. wav files are lossless but will tend to be larger files. mp3 files tend to be smaller. Default is wav. Applies both when this component is used as an input (when `type` is "format") and when this component is used as an output.
|
||||
@ -93,6 +96,7 @@ class Audio(
|
||||
show_share_button: If True, will show a share icon in the corner of the component that allows user to share outputs to Hugging Face Spaces Discussions. If False, icon does not appear. If set to None (default behavior), then the icon appears if this Gradio app is launched on Spaces, but not otherwise.
|
||||
"""
|
||||
valid_sources = ["upload", "microphone"]
|
||||
source = source if source else ("microphone" if streaming else "upload")
|
||||
if source not in valid_sources:
|
||||
raise ValueError(
|
||||
f"Invalid value for parameter `source`: {source}. Please choose from one of: {valid_sources}"
|
||||
@ -105,7 +109,7 @@ class Audio(
|
||||
)
|
||||
self.type = type
|
||||
self.streaming = streaming
|
||||
if streaming and source != "microphone":
|
||||
if streaming and source == "upload":
|
||||
raise ValueError(
|
||||
"Audio streaming only available if source is 'microphone'."
|
||||
)
|
||||
@ -340,6 +344,18 @@ class Audio(
|
||||
file_path = self.make_temp_copy_if_needed(y)
|
||||
return {"name": file_path, "data": None, "is_file": True}
|
||||
|
||||
def stream_output(self, y):
|
||||
if y is None:
|
||||
return None
|
||||
if client_utils.is_http_url_like(y["name"]):
|
||||
response = requests.get(y["name"])
|
||||
bytes = response.content
|
||||
else:
|
||||
file_path = y["name"]
|
||||
with open(file_path, "rb") as f:
|
||||
bytes = f.read()
|
||||
return bytes
|
||||
|
||||
def check_streamable(self):
|
||||
if self.source != "microphone":
|
||||
raise ValueError(
|
||||
|
@ -270,6 +270,14 @@ class Streamable(EventListener):
|
||||
pass
|
||||
|
||||
|
||||
class StreamableOutput(EventListener):
|
||||
def __init__(self):
|
||||
self.streaming: bool
|
||||
|
||||
def stream_output(self, y) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@document("*start_recording", "*stop_recording", inherit=True)
|
||||
class Recordable(EventListener):
|
||||
def __init__(self):
|
||||
|
@ -17,6 +17,7 @@ import os
|
||||
import posixpath
|
||||
import secrets
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from asyncio import TimeoutError as AsyncTimeOutError
|
||||
from collections import defaultdict
|
||||
@ -386,6 +387,41 @@ class App(FastAPI):
|
||||
return response
|
||||
return FileResponse(abs_path, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
@app.get(
|
||||
"/stream/{session_hash}/{run}/{component_id}",
|
||||
dependencies=[Depends(login_check)],
|
||||
)
|
||||
async def stream(
|
||||
session_hash: str, run: int, component_id: int, request: fastapi.Request
|
||||
):
|
||||
stream: list = (
|
||||
app.get_blocks()
|
||||
.pending_streams[session_hash]
|
||||
.get(run, {})
|
||||
.get(component_id, None)
|
||||
)
|
||||
if stream is None:
|
||||
raise HTTPException(404, "Stream not found.")
|
||||
|
||||
def stream_wrapper():
|
||||
check_stream_rate = 0.01
|
||||
max_wait_time = 120 # maximum wait between yields - assume generator thread has crashed otherwise.
|
||||
wait_time = 0
|
||||
while True:
|
||||
if len(stream) == 0:
|
||||
if wait_time > max_wait_time:
|
||||
return
|
||||
wait_time += check_stream_rate
|
||||
time.sleep(check_stream_rate)
|
||||
continue
|
||||
wait_time = 0
|
||||
next_stream = stream.pop(0)
|
||||
if next_stream is None:
|
||||
return
|
||||
yield next_stream
|
||||
|
||||
return StreamingResponse(stream_wrapper())
|
||||
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
async def file_deprecated(path: str, request: fastapi.Request):
|
||||
return await file(path, request)
|
||||
@ -406,24 +442,25 @@ class App(FastAPI):
|
||||
fn_index_inferred: int,
|
||||
):
|
||||
fn_index = body.fn_index
|
||||
if hasattr(body, "session_hash"):
|
||||
if body.session_hash not in app.state_holder:
|
||||
app.state_holder[body.session_hash] = {
|
||||
session_hash = getattr(body, "session_hash", None)
|
||||
if session_hash is not None:
|
||||
if session_hash not in app.state_holder:
|
||||
app.state_holder[session_hash] = {
|
||||
_id: deepcopy(getattr(block, "value", None))
|
||||
for _id, block in app.get_blocks().blocks.items()
|
||||
if getattr(block, "stateful", False)
|
||||
}
|
||||
session_state = app.state_holder[body.session_hash]
|
||||
session_state = app.state_holder[session_hash]
|
||||
# The should_reset set keeps track of the fn_indices
|
||||
# that have been cancelled. When a job is cancelled,
|
||||
# the /reset route will mark the jobs as having been reset.
|
||||
# That way if the cancel job finishes BEFORE the job being cancelled
|
||||
# the job being cancelled will not overwrite the state of the iterator.
|
||||
if fn_index in app.iterators_to_reset[body.session_hash]:
|
||||
if fn_index in app.iterators_to_reset[session_hash]:
|
||||
iterators = {}
|
||||
app.iterators_to_reset[body.session_hash].remove(fn_index)
|
||||
app.iterators_to_reset[session_hash].remove(fn_index)
|
||||
else:
|
||||
iterators = app.iterators[body.session_hash]
|
||||
iterators = app.iterators[session_hash]
|
||||
else:
|
||||
session_state = {}
|
||||
iterators = {}
|
||||
@ -448,6 +485,7 @@ class App(FastAPI):
|
||||
request=request,
|
||||
state=session_state,
|
||||
iterators=iterators,
|
||||
session_hash=session_hash,
|
||||
event_id=event_id,
|
||||
event_data=event_data,
|
||||
)
|
||||
@ -457,6 +495,15 @@ class App(FastAPI):
|
||||
if isinstance(output, Error):
|
||||
raise output
|
||||
except BaseException as error:
|
||||
iterator = iterators.get(fn_index, None)
|
||||
if iterator is not None: # close off any streams that are still open
|
||||
run_id = id(iterator)
|
||||
pending_streams: dict[int, list] = (
|
||||
app.get_blocks().pending_streams[session_hash].get(run_id, {})
|
||||
)
|
||||
for stream in pending_streams.values():
|
||||
stream.append(None)
|
||||
|
||||
show_error = app.get_blocks().show_error or isinstance(error, Error)
|
||||
traceback.print_exc()
|
||||
return JSONResponse(
|
||||
|
@ -19,4 +19,8 @@ The difference between `gr.Audio(source='microphone')` and `gr.Audio(source='mic
|
||||
|
||||
Here is example code of streaming images from the webcam.
|
||||
|
||||
$code_stream_frames
|
||||
$code_stream_frames
|
||||
|
||||
Streaming can also be done in an output component. A `gr.Audio(streaming=True)` output component can take a stream of audio data yielded piece-wise by a generator function and combines them into a single audio file.
|
||||
|
||||
$code_stream_audio_out
|
@ -5,6 +5,7 @@ export interface FileData {
|
||||
data: string;
|
||||
blob?: File;
|
||||
is_file?: boolean;
|
||||
is_stream?: boolean;
|
||||
mime_type?: string;
|
||||
alt_text?: string;
|
||||
}
|
||||
|
@ -47,6 +47,12 @@ export function normalise_file(
|
||||
} else {
|
||||
file.data = "/proxy=" + root_url + "file=" + file.name;
|
||||
}
|
||||
} else if (file.is_stream) {
|
||||
if (root_url == null) {
|
||||
file.data = root + "/stream/" + file.name;
|
||||
} else {
|
||||
file.data = "/proxy=" + root_url + "stream/" + file.name;
|
||||
}
|
||||
}
|
||||
return file;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user