mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Allow caching examples with streamed output (#5295)
* changes * changes * add changeset * add changeset * chages * Update silver-clowns-brush.md * changes * chagers * changes * Update silver-clowns-brush.md * change * change * change * changes * chages * changes * add changeset * changes * changes * changes --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: pngwn <hello@pngwn.io>
This commit is contained in:
parent
a0f22626f2
commit
7b8fa8aa58
6
.changeset/full-melons-pick.md
Normal file
6
.changeset/full-melons-pick.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
fix:Allow caching examples with streamed output
|
@ -49,7 +49,7 @@ class Serializable:
|
||||
types = api_info.get("serialized_output", [api_info["info"]["type"]] * 2) # type: ignore
|
||||
return (types[0], types[1])
|
||||
|
||||
def serialize(self, x: Any, load_dir: str | Path = ""):
|
||||
def serialize(self, x: Any, load_dir: str | Path = "", allow_links: bool = False):
|
||||
"""
|
||||
Convert data from human-readable format to serialized format for a browser.
|
||||
"""
|
||||
@ -167,6 +167,7 @@ class ImgSerializable(Serializable):
|
||||
self,
|
||||
x: str | None,
|
||||
load_dir: str | Path = "",
|
||||
allow_links: bool = False,
|
||||
) -> str | None:
|
||||
"""
|
||||
Convert from human-friendly version of a file (string filepath) to a serialized
|
||||
@ -257,7 +258,10 @@ class FileSerializable(Serializable):
|
||||
}
|
||||
|
||||
def _serialize_single(
|
||||
self, x: str | FileData | None, load_dir: str | Path = ""
|
||||
self,
|
||||
x: str | FileData | None,
|
||||
load_dir: str | Path = "",
|
||||
allow_links: bool = False,
|
||||
) -> FileData | None:
|
||||
if x is None or isinstance(x, dict):
|
||||
return x
|
||||
@ -269,9 +273,11 @@ class FileSerializable(Serializable):
|
||||
size = Path(filename).stat().st_size
|
||||
return {
|
||||
"name": filename,
|
||||
"data": utils.encode_url_or_file_to_base64(filename),
|
||||
"data": None
|
||||
if allow_links
|
||||
else utils.encode_url_or_file_to_base64(filename),
|
||||
"orig_name": Path(filename).name,
|
||||
"is_file": False,
|
||||
"is_file": allow_links,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
@ -328,6 +334,7 @@ class FileSerializable(Serializable):
|
||||
self,
|
||||
x: str | FileData | None | list[str | FileData | None],
|
||||
load_dir: str | Path = "",
|
||||
allow_links: bool = False,
|
||||
) -> FileData | None | list[FileData | None]:
|
||||
"""
|
||||
Convert from human-friendly version of a file (string filepath) to a
|
||||
@ -335,13 +342,14 @@ class FileSerializable(Serializable):
|
||||
Parameters:
|
||||
x: String path to file to serialize
|
||||
load_dir: Path to directory containing x
|
||||
allow_links: Will allow path returns instead of raw file content
|
||||
"""
|
||||
if x is None or x == "":
|
||||
return None
|
||||
if isinstance(x, list):
|
||||
return [self._serialize_single(f, load_dir=load_dir) for f in x]
|
||||
return [self._serialize_single(f, load_dir, allow_links) for f in x]
|
||||
else:
|
||||
return self._serialize_single(x, load_dir=load_dir)
|
||||
return self._serialize_single(x, load_dir, allow_links)
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
@ -390,9 +398,9 @@ class VideoSerializable(FileSerializable):
|
||||
}
|
||||
|
||||
def serialize(
|
||||
self, x: str | None, load_dir: str | Path = ""
|
||||
self, x: str | None, load_dir: str | Path = "", allow_links: bool = False
|
||||
) -> tuple[FileData | None, None]:
|
||||
return (super().serialize(x, load_dir), None) # type: ignore
|
||||
return (super().serialize(x, load_dir, allow_links), None) # type: ignore
|
||||
|
||||
def deserialize(
|
||||
self,
|
||||
@ -432,6 +440,7 @@ class JSONSerializable(Serializable):
|
||||
self,
|
||||
x: str | None,
|
||||
load_dir: str | Path = "",
|
||||
allow_links: bool = False,
|
||||
) -> dict | list | None:
|
||||
"""
|
||||
Convert from a a human-friendly version (string path to json file) to a
|
||||
@ -488,7 +497,7 @@ class GallerySerializable(Serializable):
|
||||
}
|
||||
|
||||
def serialize(
|
||||
self, x: str | None, load_dir: str | Path = ""
|
||||
self, x: str | None, load_dir: str | Path = "", allow_links: bool = False
|
||||
) -> list[list[str | None]] | None:
|
||||
if x is None or x == "":
|
||||
return None
|
||||
@ -497,7 +506,7 @@ class GallerySerializable(Serializable):
|
||||
with captions_file.open("r") as captions_json:
|
||||
captions = json.load(captions_json)
|
||||
for file_name, caption in captions.items():
|
||||
img = FileSerializable().serialize(file_name)
|
||||
img = FileSerializable().serialize(file_name, allow_links=allow_links)
|
||||
files.append([img, caption])
|
||||
return files
|
||||
|
||||
|
BIN
demo/stream_audio_out/audio/cantina.wav
Normal file
BIN
demo/stream_audio_out/audio/cantina.wav
Normal file
Binary file not shown.
@ -1 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "from time import sleep\n", "\n", "with gr.Blocks() as demo:\n", " input_audio = gr.Audio(label=\"Input Audio\", type=\"filepath\", format=\"mp3\")\n", " with gr.Row():\n", " with gr.Column():\n", " stream_as_file_btn = gr.Button(\"Stream as File\")\n", " format = gr.Radio([\"wav\", \"mp3\"], value=\"wav\", label=\"Format\")\n", " stream_as_file_output = gr.Audio(streaming=True)\n", "\n", " def stream_file(audio_file, format):\n", " audio = AudioSegment.from_file(audio_file)\n", " i = 0\n", " chunk_size = 3000 \n", " while chunk_size*i < len(audio):\n", " chunk = audio[chunk_size*i:chunk_size*(i+1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.{format}\"\n", " chunk.export(file, format=format)\n", " yield file\n", " sleep(1)\n", " \n", " stream_as_file_btn.click(stream_file, [input_audio, format], stream_as_file_output)\n", "\n", " with gr.Column():\n", " stream_as_bytes_btn = gr.Button(\"Stream as Bytes\")\n", " stream_as_bytes_output = gr.Audio(format=\"bytes\", streaming=True)\n", "\n", " def stream_bytes(audio_file):\n", " chunk_size = 20_000\n", " with open(audio_file, \"rb\") as f:\n", " while True:\n", " chunk = f.read(chunk_size)\n", " if chunk:\n", " yield chunk\n", " sleep(1)\n", " else:\n", " break\n", " \n", " stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
||||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('audio')\n", "!wget -q -O audio/cantina.wav https://github.com/gradio-app/gradio/raw/main/demo/stream_audio_out/audio/cantina.wav"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "from time import sleep\n", "\n", "with gr.Blocks() as demo:\n", " input_audio = gr.Audio(label=\"Input Audio\", type=\"filepath\", format=\"mp3\")\n", " with gr.Row():\n", " with gr.Column():\n", " stream_as_file_btn = gr.Button(\"Stream as File\")\n", " format = gr.Radio([\"wav\", \"mp3\"], value=\"wav\", label=\"Format\")\n", " stream_as_file_output = gr.Audio(streaming=True)\n", "\n", " def stream_file(audio_file, format):\n", " audio = AudioSegment.from_file(audio_file)\n", " i = 0\n", " chunk_size = 1000\n", " while chunk_size * i < len(audio):\n", " chunk = audio[chunk_size * i : chunk_size * (i + 1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.{format}\"\n", " chunk.export(file, format=format)\n", " yield file\n", " sleep(0.5)\n", "\n", " stream_as_file_btn.click(\n", " stream_file, [input_audio, format], stream_as_file_output\n", " )\n", "\n", " gr.Examples(\n", " [[\"audio/cantina.wav\", \"wav\"], [\"audio/cantina.wav\", \"mp3\"]],\n", " [input_audio, format],\n", " fn=stream_file,\n", " outputs=stream_as_file_output,\n", " cache_examples=True,\n", " )\n", "\n", " with gr.Column():\n", " stream_as_bytes_btn = gr.Button(\"Stream as Bytes\")\n", " stream_as_bytes_output = gr.Audio(format=\"bytes\", streaming=True)\n", "\n", " def stream_bytes(audio_file):\n", " chunk_size = 20_000\n", " with open(audio_file, \"rb\") as f:\n", " while True:\n", " chunk = f.read(chunk_size)\n", " if chunk:\n", " yield chunk\n", " sleep(1)\n", " else:\n", " break\n", " stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
@ -13,17 +13,27 @@ with gr.Blocks() as demo:
|
||||
def stream_file(audio_file, format):
|
||||
audio = AudioSegment.from_file(audio_file)
|
||||
i = 0
|
||||
chunk_size = 3000
|
||||
while chunk_size*i < len(audio):
|
||||
chunk = audio[chunk_size*i:chunk_size*(i+1)]
|
||||
chunk_size = 1000
|
||||
while chunk_size * i < len(audio):
|
||||
chunk = audio[chunk_size * i : chunk_size * (i + 1)]
|
||||
i += 1
|
||||
if chunk:
|
||||
file = f"/tmp/{i}.{format}"
|
||||
chunk.export(file, format=format)
|
||||
yield file
|
||||
sleep(1)
|
||||
|
||||
stream_as_file_btn.click(stream_file, [input_audio, format], stream_as_file_output)
|
||||
sleep(0.5)
|
||||
|
||||
stream_as_file_btn.click(
|
||||
stream_file, [input_audio, format], stream_as_file_output
|
||||
)
|
||||
|
||||
gr.Examples(
|
||||
[["audio/cantina.wav", "wav"], ["audio/cantina.wav", "mp3"]],
|
||||
[input_audio, format],
|
||||
fn=stream_file,
|
||||
outputs=stream_as_file_output,
|
||||
cache_examples=True,
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
stream_as_bytes_btn = gr.Button("Stream as Bytes")
|
||||
@ -39,7 +49,6 @@ with gr.Blocks() as demo:
|
||||
sleep(1)
|
||||
else:
|
||||
break
|
||||
|
||||
stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -291,12 +291,14 @@ class Examples:
|
||||
print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
|
||||
cache_logger = CSVLogger()
|
||||
|
||||
generated_values = []
|
||||
if inspect.isgeneratorfunction(self.fn):
|
||||
|
||||
def get_final_item(*args): # type: ignore
|
||||
x = None
|
||||
generated_values.clear()
|
||||
for x in self.fn(*args): # noqa: B007 # type: ignore
|
||||
pass
|
||||
generated_values.append(x)
|
||||
return x
|
||||
|
||||
fn = get_final_item
|
||||
@ -304,13 +306,15 @@ class Examples:
|
||||
|
||||
async def get_final_item(*args):
|
||||
x = None
|
||||
generated_values.clear()
|
||||
async for x in self.fn(*args): # noqa: B007 # type: ignore
|
||||
pass
|
||||
generated_values.append(x)
|
||||
return x
|
||||
|
||||
fn = get_final_item
|
||||
else:
|
||||
fn = self.fn
|
||||
|
||||
# create a fake dependency to process the examples and get the predictions
|
||||
dependency, fn_index = Context.root_block.set_event_trigger(
|
||||
event_name="fake_event",
|
||||
@ -337,6 +341,11 @@ class Examples:
|
||||
state={},
|
||||
)
|
||||
output = prediction["data"]
|
||||
if len(generated_values):
|
||||
output = merge_generated_values_into_output(
|
||||
self.outputs, generated_values, output
|
||||
)
|
||||
|
||||
if self.batch:
|
||||
output = [value[0] for value in output]
|
||||
cache_logger.flag(output)
|
||||
@ -395,13 +404,48 @@ class Examples:
|
||||
except (ValueError, TypeError, SyntaxError, AssertionError):
|
||||
output.append(
|
||||
component.serialize(
|
||||
value_to_use,
|
||||
self.cached_folder,
|
||||
value_to_use, self.cached_folder, allow_links=True
|
||||
)
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def merge_generated_values_into_output(
|
||||
components: list[IOComponent], generated_values: list, output: list
|
||||
):
|
||||
from gradio.events import StreamableOutput
|
||||
|
||||
for output_index, output_component in enumerate(components):
|
||||
if (
|
||||
isinstance(output_component, StreamableOutput)
|
||||
and output_component.streaming
|
||||
):
|
||||
binary_chunks = []
|
||||
for i, chunk in enumerate(generated_values):
|
||||
if len(components) > 1:
|
||||
chunk = chunk[output_index]
|
||||
processed_chunk = output_component.postprocess(chunk)
|
||||
binary_chunks.append(
|
||||
output_component.stream_output(processed_chunk, "", i == 0)[0]
|
||||
)
|
||||
binary_data = b"".join(binary_chunks)
|
||||
tempdir = os.environ.get("GRADIO_TEMP_DIR") or str(
|
||||
Path(tempfile.gettempdir()) / "gradio"
|
||||
)
|
||||
os.makedirs(tempdir, exist_ok=True)
|
||||
temp_file = tempfile.NamedTemporaryFile(dir=tempdir, delete=False)
|
||||
with open(temp_file.name, "wb") as f:
|
||||
f.write(binary_data)
|
||||
|
||||
output[output_index] = {
|
||||
"name": temp_file.name,
|
||||
"is_file": True,
|
||||
"data": None,
|
||||
}
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TrackedIterable:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -9,7 +9,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
from gradio_client import media_data
|
||||
from gradio_client import media_data, utils
|
||||
from pydub import AudioSegment
|
||||
from starlette.testclient import TestClient
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -93,7 +94,8 @@ class TestExamples:
|
||||
)
|
||||
|
||||
prediction = await examples.load_from_cache(0)
|
||||
assert prediction[0][0][0]["data"] == media_data.BASE64_IMAGE
|
||||
file = prediction[0][0][0]["name"]
|
||||
assert utils.encode_url_or_file_to_base64(file) == media_data.BASE64_IMAGE
|
||||
|
||||
|
||||
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
|
||||
@ -152,7 +154,10 @@ class TestProcessExamples:
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
assert prediction[0]["data"].startswith("data:audio/wav;base64,UklGRgA/")
|
||||
file = prediction[0]["name"]
|
||||
assert utils.encode_url_or_file_to_base64(file).startswith(
|
||||
"data:audio/wav;base64,UklGRgA/"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_update(self):
|
||||
@ -220,6 +225,29 @@ class TestProcessExamples:
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
assert prediction[0] == "Your output: abcdef"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_generators_and_streamed_output(self):
|
||||
file_dir = Path(Path(__file__).parent, "test_files")
|
||||
audio = str(file_dir / "audio_sample.wav")
|
||||
|
||||
def test_generator(x):
|
||||
for y in range(int(x)):
|
||||
yield audio, y * 5
|
||||
|
||||
io = gr.Interface(
|
||||
test_generator,
|
||||
"number",
|
||||
[gr.Audio(streaming=True), "number"],
|
||||
examples=[3],
|
||||
cache_examples=True,
|
||||
)
|
||||
prediction = await io.examples_handler.load_from_cache(0)
|
||||
len_input_audio = len(AudioSegment.from_wav(audio))
|
||||
len_output_audio = len(AudioSegment.from_wav(prediction[0]["name"]))
|
||||
length_ratio = len_output_audio / len_input_audio
|
||||
assert round(length_ratio, 1) == 3.0 # might not be exactly 3x
|
||||
assert float(prediction[1]) == 10.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caching_with_async_generators(self):
|
||||
async def test_generator(x):
|
||||
|
Loading…
Reference in New Issue
Block a user