From 390624d8ad2b1308a5bf8384435fd0db98d8e29e Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Mon, 21 Aug 2023 15:15:36 -0400 Subject: [PATCH] Enable streaming audio in python client (#5248) * Add code * Remove file * add changeset * add changeset * Update chilly-fans-make.md * lint * Lint * Add ffmpeg * Lint * Cleaner way to handle stream change * Fix windows test --------- Co-authored-by: gradio-pr-bot Co-authored-by: Abubakar Abid --- .changeset/chilly-fans-make.md | 30 +++++++++++++++++++++ .github/workflows/backend.yml | 2 ++ client/python/gradio_client/data_classes.py | 2 ++ client/python/gradio_client/serializing.py | 23 ++++++++++++++++ client/python/gradio_client/utils.py | 10 +++++++ client/python/test/conftest.py | 26 ++++++++++++++++++ client/python/test/requirements.txt | 1 + client/python/test/test_client.py | 15 +++++++++++ gradio/blocks.py | 8 +++--- gradio/components/audio.py | 7 ++++- 10 files changed, 119 insertions(+), 5 deletions(-) create mode 100644 .changeset/chilly-fans-make.md diff --git a/.changeset/chilly-fans-make.md b/.changeset/chilly-fans-make.md new file mode 100644 index 0000000000..7eb46be563 --- /dev/null +++ b/.changeset/chilly-fans-make.md @@ -0,0 +1,30 @@ +--- +"gradio": minor +"gradio_client": minor +--- + +highlight: + +#### Enable streaming audio in python client + +The `gradio_client` now supports streaming file outputs 🌊 + +No new syntax! Connect to a gradio demo that supports streaming file outputs and call `predict` or `submit` as you normally would. + +```python +import gradio_client as grc +client = grc.Client("gradio/stream_audio_out") + +# Get the entire generated audio as a local file +client.predict("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict") + +job = client.submit("/Users/freddy/Pictures/bark_demo.mp4", api_name="/predict") + +# Get the entire generated audio as a local file +job.result() + +# Each individual chunk +job.outputs() +``` + + diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 417662cd9c..9344e7daff 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -98,6 +98,8 @@ jobs: run: | . venv/bin/activate python -m pip install -r client/python/test/requirements.txt + - name: Install ffmpeg + uses: FedericoCarboni/setup-ffmpeg@v2 - name: Install Gradio and Client Libraries Locally (Linux) if: runner.os == 'Linux' run: | diff --git a/client/python/gradio_client/data_classes.py b/client/python/gradio_client/data_classes.py index 50f22042d3..bfd8665116 100644 --- a/client/python/gradio_client/data_classes.py +++ b/client/python/gradio_client/data_classes.py @@ -13,3 +13,5 @@ class FileData(TypedDict): bool ] # whether the data corresponds to a file or base64 encoded data orig_name: NotRequired[str] # original filename + mime_type: NotRequired[str] + is_stream: NotRequired[bool] diff --git a/client/python/gradio_client/serializing.py b/client/python/gradio_client/serializing.py index a5fb3c9080..f5dbc64fc7 100644 --- a/client/python/gradio_client/serializing.py +++ b/client/python/gradio_client/serializing.py @@ -2,6 +2,8 @@ from __future__ import annotations import json import os +import secrets +import tempfile import uuid from pathlib import Path from typing import Any @@ -204,6 +206,11 @@ class ImgSerializable(Serializable): class FileSerializable(Serializable): """Expects a dict with base64 representation of object as input/output which is serialized to a filepath.""" + def __init__(self) -> None: + self.stream = None + self.stream_name = None + super().__init__() + def serialized_info(self): return self._single_file_serialized_info() @@ -268,6 +275,9 @@ class FileSerializable(Serializable): "size": size, } + def _setup_stream(self, url, hf_token): + return utils.download_byte_stream(url, hf_token) + def _deserialize_single( self, x: str | FileData | None, @@ -291,6 +301,19 @@ class FileSerializable(Serializable): ) else: file_name = utils.create_tmp_copy_of_file(filepath, dir=save_dir) + elif x.get("is_stream"): + assert x["name"] and root_url and save_dir + if not self.stream or self.stream_name != x["name"]: + self.stream = self._setup_stream( + root_url + "stream/" + x["name"], hf_token=hf_token + ) + self.stream_name = x["name"] + chunk = next(self.stream) + path = Path(save_dir or tempfile.gettempdir()) / secrets.token_hex(20) + path.mkdir(parents=True, exist_ok=True) + path = path / x.get("orig_name", "output") + path.write_bytes(chunk) + file_name = str(path) else: data = x.get("data") assert data is not None, f"The 'data' field is missing in {x}" diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index 850e6f8882..f1b6cb5499 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -387,6 +387,16 @@ def encode_url_or_file_to_base64(path: str | Path): return encode_file_to_base64(path) +def download_byte_stream(url: str, hf_token=None): + arr = bytearray() + headers = {"Authorization": "Bearer " + hf_token} if hf_token else {} + with httpx.stream("GET", url, headers=headers) as r: + for data in r.iter_bytes(): + arr += data + yield data + yield arr + + def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]: extension = get_extension(encoding) data = encoding.rsplit(",", 1)[-1] diff --git a/client/python/test/conftest.py b/client/python/test/conftest.py index 1190524d12..13a3d23750 100644 --- a/client/python/test/conftest.py +++ b/client/python/test/conftest.py @@ -4,6 +4,7 @@ import time import gradio as gr import pytest +from pydub import AudioSegment def pytest_configure(config): @@ -297,6 +298,31 @@ def hello_world_with_state_and_accordion(): return demo +@pytest.fixture +def stream_audio(): + import pathlib + import tempfile + + def _stream_audio(audio_file): + audio = AudioSegment.from_mp3(audio_file) + i = 0 + chunk_size = 3000 + + while chunk_size * i < len(audio): + chunk = audio[chunk_size * i : chunk_size * (i + 1)] + i += 1 + if chunk: + file = str(pathlib.Path(tempfile.gettempdir()) / f"{i}.wav") + chunk.export(file, format="wav") + yield file + + return gr.Interface( + fn=_stream_audio, + inputs=gr.Audio(type="filepath", label="Audio file to stream"), + outputs=gr.Audio(autoplay=True, streaming=True), + ).queue() + + @pytest.fixture def all_components(): classes_to_check = gr.components.Component.__subclasses__() diff --git a/client/python/test/requirements.txt b/client/python/test/requirements.txt index 76e2e3bcbe..064b932250 100644 --- a/client/python/test/requirements.txt +++ b/client/python/test/requirements.txt @@ -4,3 +4,4 @@ pytest==7.1.2 ruff==0.0.264 pyright==1.1.305 gradio +pydub==0.25.1 diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index c187ee2eab..eec7db76d8 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -268,6 +268,21 @@ class TestClientPredictions: assert job2.status().code == Status.FINISHED assert len(job2.outputs()) == 4 + def test_stream_audio(self, stream_audio): + with connect(stream_audio) as client: + job1 = client.submit( + "https://gradio-builds.s3.amazonaws.com/demo-files/bark_demo.mp4", + api_name="/predict", + ) + assert Path(job1.result()).exists() + + job2 = client.submit( + "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav", + api_name="/predict", + ) + assert Path(job2.result()).exists() + assert all(Path(p).exists() for p in job2.outputs()) + @pytest.mark.flaky def test_upload_file_private_space(self): client = Client( diff --git a/gradio/blocks.py b/gradio/blocks.py index 98f390d480..e765cc55bb 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1360,10 +1360,10 @@ Received outputs: if run not in self.pending_streams[session_hash]: self.pending_streams[session_hash][run] = defaultdict(list) self.pending_streams[session_hash][run][output_id].append(stream) - data[i] = { - "name": f"{session_hash}/{run}/{output_id}", - "is_stream": True, - } + if data[i]: + data[i]["is_file"] = False + data[i]["name"] = f"{session_hash}/{run}/{output_id}" + data[i]["is_stream"] = True return data async def process_api( diff --git a/gradio/components/audio.py b/gradio/components/audio.py index c63b486005..4f394778c7 100644 --- a/gradio/components/audio.py +++ b/gradio/components/audio.py @@ -357,7 +357,12 @@ class Audio( self.temp_files.add(file_path) else: file_path = self.make_temp_copy_if_needed(y) - return {"name": file_path, "data": None, "is_file": True} + return { + "name": file_path, + "data": None, + "is_file": True, + "orig_name": Path(file_path).name, + } def stream_output(self, y): if y is None: