Enable streaming audio in python client (#5248)

* Add code

* Remove file

* add changeset

* add changeset

* Update chilly-fans-make.md

* lint

* Lint

* Add ffmpeg

* Lint

* Cleaner way to handle stream change

* Fix windows test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-08-21 15:15:36 -04:00 committed by GitHub
parent a2f42e28bd
commit 390624d8ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 119 additions and 5 deletions

View File

@ -0,0 +1,30 @@
---
"gradio": minor
"gradio_client": minor
---
highlight:
#### Enable streaming audio in the Python client
The `gradio_client` now supports streaming file outputs 🌊
No new syntax! Connect to a Gradio demo that supports streaming file outputs and call `predict` or `submit` as you normally would.
```python
import gradio_client as grc
client = grc.Client("gradio/stream_audio_out")
# Get the entire generated audio as a local file
client.predict("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict")
job = client.submit("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict")
# Get the entire generated audio as a local file
job.result()
# Get each individual chunk as a local file
job.outputs()
```

View File

@ -98,6 +98,8 @@ jobs:
run: |
. venv/bin/activate
python -m pip install -r client/python/test/requirements.txt
- name: Install ffmpeg
uses: FedericoCarboni/setup-ffmpeg@v2
- name: Install Gradio and Client Libraries Locally (Linux)
if: runner.os == 'Linux'
run: |

View File

@ -13,3 +13,5 @@ class FileData(TypedDict):
bool
] # whether the data corresponds to a file or base64 encoded data
orig_name: NotRequired[str] # original filename
mime_type: NotRequired[str]
is_stream: NotRequired[bool]

View File

@ -2,6 +2,8 @@ from __future__ import annotations
import json
import os
import secrets
import tempfile
import uuid
from pathlib import Path
from typing import Any
@ -204,6 +206,11 @@ class ImgSerializable(Serializable):
class FileSerializable(Serializable):
"""Expects a dict with base64 representation of object as input/output which is serialized to a filepath."""
def __init__(self) -> None:
    """Start with no active stream, then run the base-class initializer.

    ``stream`` and ``stream_name`` are filled in lazily by the streaming
    branch of ``_deserialize_single`` the first time a streamed output
    is deserialized.
    """
    # Both attributes start out as None: nothing has been opened yet.
    self.stream = self.stream_name = None
    super().__init__()
def serialized_info(self):
return self._single_file_serialized_info()
@ -268,6 +275,9 @@ class FileSerializable(Serializable):
"size": size,
}
def _setup_stream(self, url, hf_token):
    """Open the byte stream located at ``url``.

    Delegates to ``utils.download_byte_stream`` and returns its result
    unchanged; ``hf_token`` is forwarded as the optional auth token.
    """
    byte_stream = utils.download_byte_stream(url, hf_token)
    return byte_stream
def _deserialize_single(
self,
x: str | FileData | None,
@ -291,6 +301,19 @@ class FileSerializable(Serializable):
)
else:
file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir)
elif x.get("is_stream"):
assert x["name"] and root_url and save_dir
if not self.stream or self.stream_name != x["name"]:
self.stream = self._setup_stream(
root_url + "stream/" + x["name"], hf_token=hf_token
)
self.stream_name = x["name"]
chunk = next(self.stream)
path = Path(save_dir or tempfile.gettempdir()) / secrets.token_hex(20)
path.mkdir(parents=True, exist_ok=True)
path = path / x.get("orig_name", "output")
path.write_bytes(chunk)
file_name = str(path)
else:
data = x.get("data")
assert data is not None, f"The 'data' field is missing in {x}"

View File

@ -387,6 +387,16 @@ def encode_url_or_file_to_base64(path: str | Path):
return encode_file_to_base64(path)
def download_byte_stream(url: str, hf_token=None):
    """Stream the body of ``url`` chunk by chunk.

    Issues a streaming GET request and yields every chunk produced by
    ``iter_bytes()`` as it arrives. Once the body is exhausted, one final
    item is yielded: a ``bytearray`` holding all chunks concatenated.

    Parameters:
        url: Address to fetch.
        hf_token: Optional token; when truthy it is sent as an
            ``Authorization: Bearer <token>`` header.
    """
    request_headers = {}
    if hf_token:
        request_headers["Authorization"] = "Bearer " + hf_token

    # Keep a running copy of everything received so the complete payload
    # can be handed out as the last item.
    received = bytearray()
    with httpx.stream("GET", url, headers=request_headers) as response:
        for piece in response.iter_bytes():
            received.extend(piece)
            yield piece
    yield received
def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]:
extension = get_extension(encoding)
data = encoding.rsplit(",", 1)[-1]

View File

@ -4,6 +4,7 @@ import time
import gradio as gr
import pytest
from pydub import AudioSegment
def pytest_configure(config):
@ -297,6 +298,31 @@ def hello_world_with_state_and_accordion():
return demo
@pytest.fixture
def stream_audio():
    """Fixture returning a queued Interface that streams audio back in slices.

    The wrapped generator loads the input with pydub, cuts it into
    consecutive 3000 ms slices, exports each non-empty slice as
    ``<n>.wav`` (1-based) in the system temp directory and yields the
    path of each exported file.
    """
    import pathlib
    import tempfile

    def _stream_audio(audio_file):
        segment = AudioSegment.from_mp3(audio_file)
        slice_ms = 3000
        slice_starts = range(0, len(segment), slice_ms)
        # File names are numbered from 1, matching the slice ordinal.
        for ordinal, start in enumerate(slice_starts, start=1):
            piece = segment[start : start + slice_ms]
            if not piece:
                continue
            target = str(pathlib.Path(tempfile.gettempdir()) / f"{ordinal}.wav")
            piece.export(target, format="wav")
            yield target

    interface = gr.Interface(
        fn=_stream_audio,
        inputs=gr.Audio(type="filepath", label="Audio file to stream"),
        outputs=gr.Audio(autoplay=True, streaming=True),
    )
    return interface.queue()
@pytest.fixture
def all_components():
classes_to_check = gr.components.Component.__subclasses__()

View File

@ -4,3 +4,4 @@ pytest==7.1.2
ruff==0.0.264
pyright==1.1.305
gradio
pydub==0.25.1

View File

@ -268,6 +268,21 @@ class TestClientPredictions:
assert job2.status().code == Status.FINISHED
assert len(job2.outputs()) == 4
def test_stream_audio(self, stream_audio):
    """Streamed audio outputs are materialised as local files by the client.

    Submits two remote audio files in turn and checks that the final
    result of each job exists on disk, and that every intermediate
    output of the second job exists as well.
    """
    with connect(stream_audio) as client:
        # First submission: only the final result is checked.
        mp4_job = client.submit(
            "https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4",
            api_name="/predict",
        )
        mp4_result = Path(mp4_job.result())
        assert mp4_result.exists()

        # Second submission: check the final result and every chunk.
        wav_job = client.submit(
            "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav",
            api_name="/predict",
        )
        wav_result = Path(wav_job.result())
        assert wav_result.exists()
        assert all(Path(chunk_path).exists() for chunk_path in wav_job.outputs())
@pytest.mark.flaky
def test_upload_file_private_space(self):
client = Client(

View File

@ -1360,10 +1360,10 @@ Received outputs:
if run not in self.pending_streams[session_hash]:
self.pending_streams[session_hash][run] = defaultdict(list)
self.pending_streams[session_hash][run][output_id].append(stream)
data[i] = {
"name": f"{session_hash}/{run}/{output_id}",
"is_stream": True,
}
if data[i]:
data[i]["is_file"] = False
data[i]["name"] = f"{session_hash}/{run}/{output_id}"
data[i]["is_stream"] = True
return data
async def process_api(

View File

@ -357,7 +357,12 @@ class Audio(
self.temp_files.add(file_path)
else:
file_path = self.make_temp_copy_if_needed(y)
return {"name": file_path, "data": None, "is_file": True}
return {
"name": file_path,
"data": None,
"is_file": True,
"orig_name": Path(file_path).name,
}
def stream_output(self, y):
if y is None: