Mirror of https://github.com/gradio-app/gradio.git
Allow caching examples with streamed output (#5295)
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>

This commit is contained in: parent a0f22626f2, commit 7b8fa8aa58

.changeset/full-melons-pick.md (new file, 6 lines)
@@ -0,0 +1,6 @@
+---
+"gradio": minor
+"gradio_client": minor
+---
+
+fix:Allow caching examples with streamed output
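
For orientation, the sketch below shows the user-facing behavior this changeset describes: an example whose output is produced by a generator feeding a streaming component can now be cached. It is a minimal, hypothetical variant of the stream_audio_out demo updated later in this commit; the toy generator body and the audio path are illustrative, not part of the diff.

```python
import gradio as gr
from time import sleep

def stream_file(audio_file):
    # Toy generator: yield the same file a few times to simulate streamed chunks.
    for _ in range(3):
        yield audio_file
        sleep(0.5)

with gr.Blocks() as demo:
    inp = gr.Audio(label="Input Audio", type="filepath")
    out = gr.Audio(streaming=True)          # streamed output component
    btn = gr.Button("Stream")
    btn.click(stream_file, inp, out)

    gr.Examples(
        [["audio/cantina.wav"]],            # example input; path assumed to exist
        [inp],
        fn=stream_file,
        outputs=out,
        cache_examples=True,                # now works even though fn streams its output
    )

if __name__ == "__main__":
    demo.queue().launch()
```
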
@@ -49,7 +49,7 @@ class Serializable:
         types = api_info.get("serialized_output", [api_info["info"]["type"]] * 2)  # type: ignore
         return (types[0], types[1])
 
-    def serialize(self, x: Any, load_dir: str | Path = ""):
+    def serialize(self, x: Any, load_dir: str | Path = "", allow_links: bool = False):
         """
         Convert data from human-readable format to serialized format for a browser.
         """
@@ -167,6 +167,7 @@ class ImgSerializable(Serializable):
         self,
         x: str | None,
         load_dir: str | Path = "",
+        allow_links: bool = False,
     ) -> str | None:
         """
         Convert from human-friendly version of a file (string filepath) to a serialized
@@ -257,7 +258,10 @@ class FileSerializable(Serializable):
         }
 
     def _serialize_single(
-        self, x: str | FileData | None, load_dir: str | Path = ""
+        self,
+        x: str | FileData | None,
+        load_dir: str | Path = "",
+        allow_links: bool = False,
     ) -> FileData | None:
         if x is None or isinstance(x, dict):
             return x
@@ -269,9 +273,11 @@ class FileSerializable(Serializable):
             size = Path(filename).stat().st_size
         return {
             "name": filename,
-            "data": utils.encode_url_or_file_to_base64(filename),
+            "data": None
+            if allow_links
+            else utils.encode_url_or_file_to_base64(filename),
             "orig_name": Path(filename).name,
-            "is_file": False,
+            "is_file": allow_links,
             "size": size,
         }
 
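
A rough sketch (not part of the diff) of what the new allow_links flag changes in the serialized payload, assuming these classes are importable as gradio_client.serializing and that cantina.wav exists locally: with allow_links=True the file is referenced by path instead of being base64-encoded.

```python
from gradio_client.serializing import FileSerializable

s = FileSerializable()

# Default: the file content is read and base64-encoded into "data".
s.serialize("cantina.wav")
# -> {"name": "cantina.wav", "data": "data:...;base64,...", "orig_name": "cantina.wav",
#     "is_file": False, "size": <size in bytes>}

# With allow_links=True (used when caching examples), only the path is kept.
s.serialize("cantina.wav", allow_links=True)
# -> {"name": "cantina.wav", "data": None, "orig_name": "cantina.wav",
#     "is_file": True, "size": <size in bytes>}
```
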
@@ -328,6 +334,7 @@ class FileSerializable(Serializable):
         self,
         x: str | FileData | None | list[str | FileData | None],
         load_dir: str | Path = "",
+        allow_links: bool = False,
     ) -> FileData | None | list[FileData | None]:
         """
         Convert from human-friendly version of a file (string filepath) to a
@@ -335,13 +342,14 @@ class FileSerializable(Serializable):
         Parameters:
             x: String path to file to serialize
             load_dir: Path to directory containing x
+            allow_links: Will allow path returns instead of raw file content
         """
         if x is None or x == "":
             return None
         if isinstance(x, list):
-            return [self._serialize_single(f, load_dir=load_dir) for f in x]
+            return [self._serialize_single(f, load_dir, allow_links) for f in x]
         else:
-            return self._serialize_single(x, load_dir=load_dir)
+            return self._serialize_single(x, load_dir, allow_links)
 
     def deserialize(
         self,
@@ -390,9 +398,9 @@ class VideoSerializable(FileSerializable):
         }
 
     def serialize(
-        self, x: str | None, load_dir: str | Path = ""
+        self, x: str | None, load_dir: str | Path = "", allow_links: bool = False
     ) -> tuple[FileData | None, None]:
-        return (super().serialize(x, load_dir), None)  # type: ignore
+        return (super().serialize(x, load_dir, allow_links), None)  # type: ignore
 
     def deserialize(
         self,
@@ -432,6 +440,7 @@ class JSONSerializable(Serializable):
         self,
         x: str | None,
         load_dir: str | Path = "",
+        allow_links: bool = False,
     ) -> dict | list | None:
         """
         Convert from a a human-friendly version (string path to json file) to a
@@ -488,7 +497,7 @@ class GallerySerializable(Serializable):
         }
 
     def serialize(
-        self, x: str | None, load_dir: str | Path = ""
+        self, x: str | None, load_dir: str | Path = "", allow_links: bool = False
     ) -> list[list[str | None]] | None:
         if x is None or x == "":
             return None
@@ -497,7 +506,7 @@ class GallerySerializable(Serializable):
         with captions_file.open("r") as captions_json:
             captions = json.load(captions_json)
         for file_name, caption in captions.items():
-            img = FileSerializable().serialize(file_name)
+            img = FileSerializable().serialize(file_name, allow_links=allow_links)
             files.append([img, caption])
         return files
 

demo/stream_audio_out/audio/cantina.wav (new binary file, not shown)

@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "from time import sleep\n", "\n", "with gr.Blocks() as demo:\n", " input_audio = gr.Audio(label=\"Input Audio\", type=\"filepath\", format=\"mp3\")\n", " with gr.Row():\n", " with gr.Column():\n", " stream_as_file_btn = gr.Button(\"Stream as File\")\n", " format = gr.Radio([\"wav\", \"mp3\"], value=\"wav\", label=\"Format\")\n", " stream_as_file_output = gr.Audio(streaming=True)\n", "\n", " def stream_file(audio_file, format):\n", " audio = AudioSegment.from_file(audio_file)\n", " i = 0\n", " chunk_size = 3000 \n", " while chunk_size*i < len(audio):\n", " chunk = audio[chunk_size*i:chunk_size*(i+1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.{format}\"\n", " chunk.export(file, format=format)\n", " yield file\n", " sleep(1)\n", " \n", " stream_as_file_btn.click(stream_file, [input_audio, format], stream_as_file_output)\n", "\n", " with gr.Column():\n", " stream_as_bytes_btn = gr.Button(\"Stream as Bytes\")\n", " stream_as_bytes_output = gr.Audio(format=\"bytes\", streaming=True)\n", "\n", " def stream_bytes(audio_file):\n", " chunk_size = 20_000\n", " with open(audio_file, \"rb\") as f:\n", " while True:\n", " chunk = f.read(chunk_size)\n", " if chunk:\n", " yield chunk\n", " sleep(1)\n", " else:\n", " break\n", " \n", " stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('audio')\n", "!wget -q -O audio/cantina.wav https://github.com/gradio-app/gradio/raw/main/demo/stream_audio_out/audio/cantina.wav"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "from time import sleep\n", "\n", "with gr.Blocks() as demo:\n", " input_audio = gr.Audio(label=\"Input Audio\", type=\"filepath\", format=\"mp3\")\n", " with gr.Row():\n", " with gr.Column():\n", " stream_as_file_btn = gr.Button(\"Stream as File\")\n", " format = gr.Radio([\"wav\", \"mp3\"], value=\"wav\", label=\"Format\")\n", " stream_as_file_output = gr.Audio(streaming=True)\n", "\n", " def stream_file(audio_file, format):\n", " audio = AudioSegment.from_file(audio_file)\n", " i = 0\n", " chunk_size = 1000\n", " while chunk_size * i < len(audio):\n", " chunk = audio[chunk_size * i : chunk_size * (i + 1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.{format}\"\n", " chunk.export(file, format=format)\n", " yield file\n", " sleep(0.5)\n", "\n", " stream_as_file_btn.click(\n", " stream_file, [input_audio, format], stream_as_file_output\n", " )\n", "\n", " gr.Examples(\n", " [[\"audio/cantina.wav\", \"wav\"], [\"audio/cantina.wav\", \"mp3\"]],\n", " [input_audio, format],\n", " fn=stream_file,\n", " outputs=stream_as_file_output,\n", " cache_examples=True,\n", " )\n", "\n", " with gr.Column():\n", " stream_as_bytes_btn = gr.Button(\"Stream as Bytes\")\n", " stream_as_bytes_output = gr.Audio(format=\"bytes\", streaming=True)\n", "\n", " def stream_bytes(audio_file):\n", " chunk_size = 20_000\n", " with open(audio_file, \"rb\") as f:\n", " while True:\n", " chunk = f.read(chunk_size)\n", " if chunk:\n", " yield chunk\n", " sleep(1)\n", " else:\n", " break\n", " stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

@@ -13,17 +13,27 @@ with gr.Blocks() as demo:
             def stream_file(audio_file, format):
                 audio = AudioSegment.from_file(audio_file)
                 i = 0
-                chunk_size = 3000
-                while chunk_size*i < len(audio):
-                    chunk = audio[chunk_size*i:chunk_size*(i+1)]
+                chunk_size = 1000
+                while chunk_size * i < len(audio):
+                    chunk = audio[chunk_size * i : chunk_size * (i + 1)]
                     i += 1
                     if chunk:
                         file = f"/tmp/{i}.{format}"
                         chunk.export(file, format=format)
                         yield file
-                        sleep(1)
+                        sleep(0.5)
 
-            stream_as_file_btn.click(stream_file, [input_audio, format], stream_as_file_output)
+            stream_as_file_btn.click(
+                stream_file, [input_audio, format], stream_as_file_output
+            )
+
+            gr.Examples(
+                [["audio/cantina.wav", "wav"], ["audio/cantina.wav", "mp3"]],
+                [input_audio, format],
+                fn=stream_file,
+                outputs=stream_as_file_output,
+                cache_examples=True,
+            )
 
         with gr.Column():
             stream_as_bytes_btn = gr.Button("Stream as Bytes")
@@ -39,7 +49,6 @@ with gr.Blocks() as demo:
                             sleep(1)
                         else:
                             break
-
             stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)
 
 if __name__ == "__main__":
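
As a standalone illustration of why a streamed-file example can be cached, the chunks that the updated stream_file yields, concatenated in order, reconstruct the original clip. The sketch below uses only pydub and assumes audio/cantina.wav (added in this commit) is available in the working directory.

```python
from pydub import AudioSegment

audio = AudioSegment.from_file("audio/cantina.wav")
chunk_size = 1000  # ms, matching the updated demo

# Slice the clip exactly as stream_file does, then glue the slices back together.
n_chunks = len(audio) // chunk_size + 1
chunks = [audio[chunk_size * i : chunk_size * (i + 1)] for i in range(n_chunks)]
rebuilt = sum(chunks, AudioSegment.empty())

assert abs(len(rebuilt) - len(audio)) <= 1  # within a millisecond of the original
```
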
@@ -291,12 +291,14 @@ class Examples:
             print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
             cache_logger = CSVLogger()
 
+            generated_values = []
             if inspect.isgeneratorfunction(self.fn):
 
                 def get_final_item(*args):  # type: ignore
                     x = None
+                    generated_values.clear()
                     for x in self.fn(*args):  # noqa: B007 # type: ignore
-                        pass
+                        generated_values.append(x)
                     return x
 
                 fn = get_final_item
@@ -304,13 +306,15 @@ class Examples:
 
                 async def get_final_item(*args):
                     x = None
+                    generated_values.clear()
                     async for x in self.fn(*args):  # noqa: B007 # type: ignore
-                        pass
+                        generated_values.append(x)
                     return x
 
                 fn = get_final_item
             else:
                 fn = self.fn
 
             # create a fake dependency to process the examples and get the predictions
             dependency, fn_index = Context.root_block.set_event_trigger(
                 event_name="fake_event",
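
The two hunks above use the same wrapper pattern: when the example function is a (possibly async) generator, cache() still returns the final yielded value for the regular caching path, but now also records every intermediate yield in generated_values. A standalone sketch of that pattern with a toy generator:

```python
generated_values = []

def toy_generator(n):
    for i in range(n):
        yield i

def get_final_item(*args):
    x = None
    generated_values.clear()
    for x in toy_generator(*args):
        generated_values.append(x)
    return x

assert get_final_item(3) == 2             # last yield, used for non-streaming outputs
assert generated_values == [0, 1, 2]      # every chunk, kept for streaming outputs
```
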
@@ -337,6 +341,11 @@ class Examples:
                     state={},
                 )
                 output = prediction["data"]
+                if len(generated_values):
+                    output = merge_generated_values_into_output(
+                        self.outputs, generated_values, output
+                    )
+
                 if self.batch:
                     output = [value[0] for value in output]
                 cache_logger.flag(output)
@@ -395,13 +404,48 @@ class Examples:
             except (ValueError, TypeError, SyntaxError, AssertionError):
                 output.append(
                     component.serialize(
-                        value_to_use,
-                        self.cached_folder,
+                        value_to_use, self.cached_folder, allow_links=True
                     )
                 )
         return output
 
 
+def merge_generated_values_into_output(
+    components: list[IOComponent], generated_values: list, output: list
+):
+    from gradio.events import StreamableOutput
+
+    for output_index, output_component in enumerate(components):
+        if (
+            isinstance(output_component, StreamableOutput)
+            and output_component.streaming
+        ):
+            binary_chunks = []
+            for i, chunk in enumerate(generated_values):
+                if len(components) > 1:
+                    chunk = chunk[output_index]
+                processed_chunk = output_component.postprocess(chunk)
+                binary_chunks.append(
+                    output_component.stream_output(processed_chunk, "", i == 0)[0]
+                )
+            binary_data = b"".join(binary_chunks)
+            tempdir = os.environ.get("GRADIO_TEMP_DIR") or str(
+                Path(tempfile.gettempdir()) / "gradio"
+            )
+            os.makedirs(tempdir, exist_ok=True)
+            temp_file = tempfile.NamedTemporaryFile(dir=tempdir, delete=False)
+            with open(temp_file.name, "wb") as f:
+                f.write(binary_data)
+
+            output[output_index] = {
+                "name": temp_file.name,
+                "is_file": True,
+                "data": None,
+            }
+
+    return output
+
+
 class TrackedIterable:
     def __init__(
         self,
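
The new merge_generated_values_into_output helper turns the recorded chunks of each streaming output into bytes, concatenates them, writes a single temporary file, and swaps that file reference into the cached output. The standalone sketch below mirrors that flow; to_bytes() is a stand-in for the component-specific output_component.stream_output(...) call, and the returned dict follows the shape in the diff above.

```python
import os
import tempfile
from pathlib import Path

def to_bytes(chunk: bytes) -> bytes:
    # Stand-in for the component-specific stream_output() conversion.
    return chunk

def cache_streamed_chunks(chunks: list[bytes]) -> dict:
    binary_data = b"".join(to_bytes(c) for c in chunks)
    tempdir = os.environ.get("GRADIO_TEMP_DIR") or str(Path(tempfile.gettempdir()) / "gradio")
    os.makedirs(tempdir, exist_ok=True)
    temp_file = tempfile.NamedTemporaryFile(dir=tempdir, delete=False)
    with open(temp_file.name, "wb") as f:
        f.write(binary_data)
    # Same shape as the FileData dict stored in the examples cache.
    return {"name": temp_file.name, "is_file": True, "data": None}

cached = cache_streamed_chunks([b"chunk-1", b"chunk-2"])
assert Path(cached["name"]).read_bytes() == b"chunk-1chunk-2"
```
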
@@ -9,7 +9,8 @@ from unittest.mock import patch
 
 import pytest
 import websockets
-from gradio_client import media_data
+from gradio_client import media_data, utils
+from pydub import AudioSegment
 from starlette.testclient import TestClient
 from tqdm import tqdm
 
@@ -93,7 +94,8 @@ class TestExamples:
         )
 
         prediction = await examples.load_from_cache(0)
-        assert prediction[0][0][0]["data"] == media_data.BASE64_IMAGE
+        file = prediction[0][0][0]["name"]
+        assert utils.encode_url_or_file_to_base64(file) == media_data.BASE64_IMAGE
 
 
     @patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
@@ -152,7 +154,10 @@ class TestProcessExamples:
             cache_examples=True,
         )
         prediction = await io.examples_handler.load_from_cache(0)
-        assert prediction[0]["data"].startswith("data:audio/wav;base64,UklGRgA/")
+        file = prediction[0]["name"]
+        assert utils.encode_url_or_file_to_base64(file).startswith(
+            "data:audio/wav;base64,UklGRgA/"
+        )
 
     @pytest.mark.asyncio
     async def test_caching_with_update(self):
@@ -220,6 +225,29 @@ class TestProcessExamples:
         prediction = await io.examples_handler.load_from_cache(0)
         assert prediction[0] == "Your output: abcdef"
 
+    @pytest.mark.asyncio
+    async def test_caching_with_generators_and_streamed_output(self):
+        file_dir = Path(Path(__file__).parent, "test_files")
+        audio = str(file_dir / "audio_sample.wav")
+
+        def test_generator(x):
+            for y in range(int(x)):
+                yield audio, y * 5
+
+        io = gr.Interface(
+            test_generator,
+            "number",
+            [gr.Audio(streaming=True), "number"],
+            examples=[3],
+            cache_examples=True,
+        )
+        prediction = await io.examples_handler.load_from_cache(0)
+        len_input_audio = len(AudioSegment.from_wav(audio))
+        len_output_audio = len(AudioSegment.from_wav(prediction[0]["name"]))
+        length_ratio = len_output_audio / len_input_audio
+        assert round(length_ratio, 1) == 3.0  # might not be exactly 3x
+        assert float(prediction[1]) == 10.0
+
     @pytest.mark.asyncio
     async def test_caching_with_async_generators(self):
         async def test_generator(x):
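
Why the new test expects those numbers: test_generator(3) yields the full audio_sample.wav clip three times, so the cached, concatenated stream should be roughly three times the input length, and the last value of the second (non-streaming) output is 2 * 5 = 10. A tiny sanity check of that arithmetic:

```python
chunks = [("audio_sample.wav", y * 5) for y in range(3)]
assert len(chunks) == 3        # three copies of the clip get concatenated in the cache
assert chunks[-1][1] == 10     # final yielded number, cached for the Number output
```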