mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-19 12:00:39 +08:00
Enable streaming audio in python client (#5248)
* Add code * Remove file * add changeset * add changeset * Update chilly-fans-make.md * lint * Lint * Add ffmpeg * Lint * Cleaner way to handle stream change * Fix windows test --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
a2f42e28bd
commit
390624d8ad
30
.changeset/chilly-fans-make.md
Normal file
30
.changeset/chilly-fans-make.md
Normal file
@ -0,0 +1,30 @@
|
||||
---
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
highlight:
|
||||
|
||||
#### Enable streaming audio in python client
|
||||
|
||||
The `gradio_client` now supports streaming file outputs 🌊
|
||||
|
||||
No new syntax! Connect to a gradio demo that supports streaming file outputs and call `predict` or `submit` as you normally would.
|
||||
|
||||
```python
|
||||
import gradio_client as grc
|
||||
client = grc.Client("gradio/stream_audio_out")
|
||||
|
||||
# Get the entire generated audio as a local file
|
||||
client.predict("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict")
|
||||
|
||||
job = client.submit("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict")
|
||||
|
||||
# Get the entire generated audio as a local file
|
||||
job.result()
|
||||
|
||||
# Each individual chunk
|
||||
job.outputs()
|
||||
```
|
||||
|
||||
|
2
.github/workflows/backend.yml
vendored
2
.github/workflows/backend.yml
vendored
@ -98,6 +98,8 @@ jobs:
|
||||
run: |
|
||||
. venv/bin/activate
|
||||
python -m pip install -r client/python/test/requirements.txt
|
||||
- name: Install ffmpeg
|
||||
uses: FedericoCarboni/setup-ffmpeg@v2
|
||||
- name: Install Gradio and Client Libraries Locally (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
|
@ -13,3 +13,5 @@ class FileData(TypedDict):
|
||||
bool
|
||||
] # whether the data corresponds to a file or base64 encoded data
|
||||
orig_name: NotRequired[str] # original filename
|
||||
mime_type: NotRequired[str]
|
||||
is_stream: NotRequired[bool]
|
||||
|
@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -204,6 +206,11 @@ class ImgSerializable(Serializable):
|
||||
class FileSerializable(Serializable):
|
||||
"""Expects a dict with base64 representation of object as input/output which is serialized to a filepath."""
|
||||
|
||||
def __init__(self) -> None:
    """Initialize streaming state for file (de)serialization.

    `stream` holds the lazily-opened byte-stream generator for a streamed
    file output (see `_setup_stream`), and `stream_name` records which
    server-side stream it was opened for, so a new stream name triggers a
    fresh connection.
    """
    self.stream = None
    self.stream_name = None
    super().__init__()
|
||||
|
||||
def serialized_info(self):
    """Return serialization schema info, delegating to the single-file variant."""
    return self._single_file_serialized_info()
|
||||
|
||||
@ -268,6 +275,9 @@ class FileSerializable(Serializable):
|
||||
"size": size,
|
||||
}
|
||||
|
||||
def _setup_stream(self, url, hf_token):
    """Open a byte-stream generator over `url`.

    Thin wrapper around `utils.download_byte_stream`; split out so tests
    and subclasses can override how the stream is created.
    `hf_token` is forwarded for authenticated requests (may be None).
    """
    return utils.download_byte_stream(url, hf_token)
|
||||
|
||||
def _deserialize_single(
|
||||
self,
|
||||
x: str | FileData | None,
|
||||
@ -291,6 +301,19 @@ class FileSerializable(Serializable):
|
||||
)
|
||||
else:
|
||||
file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir)
|
||||
elif x.get("is_stream"):
|
||||
assert x["name"] and root_url and save_dir
|
||||
if not self.stream or self.stream_name != x["name"]:
|
||||
self.stream = self._setup_stream(
|
||||
root_url + "stream/" + x["name"], hf_token=hf_token
|
||||
)
|
||||
self.stream_name = x["name"]
|
||||
chunk = next(self.stream)
|
||||
path = Path(save_dir or tempfile.gettempdir()) / secrets.token_hex(20)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = path / x.get("orig_name", "output")
|
||||
path.write_bytes(chunk)
|
||||
file_name = str(path)
|
||||
else:
|
||||
data = x.get("data")
|
||||
assert data is not None, f"The 'data' field is missing in {x}"
|
||||
|
@ -387,6 +387,16 @@ def encode_url_or_file_to_base64(path: str | Path):
|
||||
return encode_file_to_base64(path)
|
||||
|
||||
|
||||
def download_byte_stream(url: str, hf_token=None):
    """Stream the contents of `url`, yielding each chunk as it arrives.

    Yields every raw chunk in order and, after the response is exhausted,
    yields one final value: the accumulated bytes of the entire payload.
    `hf_token`, when given, is sent as a Bearer authorization header.
    """
    headers = {"Authorization": "Bearer " + hf_token} if hf_token else {}
    accumulated = bytearray()
    with httpx.stream("GET", url, headers=headers) as response:
        for chunk in response.iter_bytes():
            accumulated.extend(chunk)
            yield chunk
        # Final yield: the whole payload, for callers that want the full file.
        yield accumulated
|
||||
|
||||
|
||||
def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]:
|
||||
extension = get_extension(encoding)
|
||||
data = encoding.rsplit(",", 1)[-1]
|
||||
|
@ -4,6 +4,7 @@ import time
|
||||
|
||||
import gradio as gr
|
||||
import pytest
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
@ -297,6 +298,31 @@ def hello_world_with_state_and_accordion():
|
||||
return demo
|
||||
|
||||
|
||||
@pytest.fixture
def stream_audio():
    """Gradio Interface fixture that streams an mp3 back as successive wav chunks."""
    import pathlib
    import tempfile

    def _stream_audio(audio_file):
        # Slice the decoded audio into fixed-size (millisecond) windows and
        # yield each window as a temporary wav file.
        segment = AudioSegment.from_mp3(audio_file)
        chunk_ms = 3000
        index = 0

        while index * chunk_ms < len(segment):
            piece = segment[index * chunk_ms : (index + 1) * chunk_ms]
            index += 1
            if piece:
                out_file = str(pathlib.Path(tempfile.gettempdir()) / f"{index}.wav")
                piece.export(out_file, format="wav")
                yield out_file

    return gr.Interface(
        fn=_stream_audio,
        inputs=gr.Audio(type="filepath", label="Audio file to stream"),
        outputs=gr.Audio(autoplay=True, streaming=True),
    ).queue()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def all_components():
|
||||
classes_to_check = gr.components.Component.__subclasses__()
|
||||
|
@ -4,3 +4,4 @@ pytest==7.1.2
|
||||
ruff==0.0.264
|
||||
pyright==1.1.305
|
||||
gradio
|
||||
pydub==0.25.1
|
||||
|
@ -268,6 +268,21 @@ class TestClientPredictions:
|
||||
assert job2.status().code == Status.FINISHED
|
||||
assert len(job2.outputs()) == 4
|
||||
|
||||
def test_stream_audio(self, stream_audio):
    """Streamed file outputs resolve to real local files for both result() and outputs()."""
    with connect(stream_audio) as client:
        first = client.submit(
            "https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4",
            api_name="/predict",
        )
        # result() returns the final, fully-assembled audio file.
        assert Path(first.result()).exists()

        second = client.submit(
            "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav",
            api_name="/predict",
        )
        assert Path(second.result()).exists()
        # Every intermediate streamed chunk must also exist on disk.
        assert all(Path(chunk).exists() for chunk in second.outputs())
|
||||
|
||||
@pytest.mark.flaky
|
||||
def test_upload_file_private_space(self):
|
||||
client = Client(
|
||||
|
@ -1360,10 +1360,10 @@ Received outputs:
|
||||
if run not in self.pending_streams[session_hash]:
|
||||
self.pending_streams[session_hash][run] = defaultdict(list)
|
||||
self.pending_streams[session_hash][run][output_id].append(stream)
|
||||
data[i] = {
|
||||
"name": f"{session_hash}/{run}/{output_id}",
|
||||
"is_stream": True,
|
||||
}
|
||||
if data[i]:
|
||||
data[i]["is_file"] = False
|
||||
data[i]["name"] = f"{session_hash}/{run}/{output_id}"
|
||||
data[i]["is_stream"] = True
|
||||
return data
|
||||
|
||||
async def process_api(
|
||||
|
@ -357,7 +357,12 @@ class Audio(
|
||||
self.temp_files.add(file_path)
|
||||
else:
|
||||
file_path = self.make_temp_copy_if_needed(y)
|
||||
return {"name": file_path, "data": None, "is_file": True}
|
||||
return {
|
||||
"name": file_path,
|
||||
"data": None,
|
||||
"is_file": True,
|
||||
"orig_name": Path(file_path).name,
|
||||
}
|
||||
|
||||
def stream_output(self, y):
|
||||
if y is None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user