Allow caching examples with streamed output (#5295)

* changes

* changes

* add changeset

* add changeset

* changes

* Update silver-clowns-brush.md

* changes

* changes

* changes

* Update silver-clowns-brush.md

* change

* change

* change

* changes

* changes

* changes

* add changeset

* changes

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>
This commit is contained in:
aliabid94 2023-08-23 11:47:07 -07:00 committed by GitHub
parent a0f22626f2
commit 7b8fa8aa58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 121 additions and 25 deletions

View File

@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---
fix:Allow caching examples with streamed output

View File

@ -49,7 +49,7 @@ class Serializable:
types = api_info.get("serialized_output", [api_info["info"]["type"]] * 2) # type: ignore types = api_info.get("serialized_output", [api_info["info"]["type"]] * 2) # type: ignore
return (types[0], types[1]) return (types[0], types[1])
def serialize(self, x: Any, load_dir: str | Path = ""): def serialize(self, x: Any, load_dir: str | Path = "", allow_links: bool = False):
""" """
Convert data from human-readable format to serialized format for a browser. Convert data from human-readable format to serialized format for a browser.
""" """
@ -167,6 +167,7 @@ class ImgSerializable(Serializable):
self, self,
x: str | None, x: str | None,
load_dir: str | Path = "", load_dir: str | Path = "",
allow_links: bool = False,
) -> str | None: ) -> str | None:
""" """
Convert from human-friendly version of a file (string filepath) to a serialized Convert from human-friendly version of a file (string filepath) to a serialized
@ -257,7 +258,10 @@ class FileSerializable(Serializable):
} }
def _serialize_single( def _serialize_single(
self, x: str | FileData | None, load_dir: str | Path = "" self,
x: str | FileData | None,
load_dir: str | Path = "",
allow_links: bool = False,
) -> FileData | None: ) -> FileData | None:
if x is None or isinstance(x, dict): if x is None or isinstance(x, dict):
return x return x
@ -269,9 +273,11 @@ class FileSerializable(Serializable):
size = Path(filename).stat().st_size size = Path(filename).stat().st_size
return { return {
"name": filename, "name": filename,
"data": utils.encode_url_or_file_to_base64(filename), "data": None
if allow_links
else utils.encode_url_or_file_to_base64(filename),
"orig_name": Path(filename).name, "orig_name": Path(filename).name,
"is_file": False, "is_file": allow_links,
"size": size, "size": size,
} }
@ -328,6 +334,7 @@ class FileSerializable(Serializable):
self, self,
x: str | FileData | None | list[str | FileData | None], x: str | FileData | None | list[str | FileData | None],
load_dir: str | Path = "", load_dir: str | Path = "",
allow_links: bool = False,
) -> FileData | None | list[FileData | None]: ) -> FileData | None | list[FileData | None]:
""" """
Convert from human-friendly version of a file (string filepath) to a Convert from human-friendly version of a file (string filepath) to a
@ -335,13 +342,14 @@ class FileSerializable(Serializable):
Parameters: Parameters:
x: String path to file to serialize x: String path to file to serialize
load_dir: Path to directory containing x load_dir: Path to directory containing x
allow_links: Will allow path returns instead of raw file content
""" """
if x is None or x == "": if x is None or x == "":
return None return None
if isinstance(x, list): if isinstance(x, list):
return [self._serialize_single(f, load_dir=load_dir) for f in x] return [self._serialize_single(f, load_dir, allow_links) for f in x]
else: else:
return self._serialize_single(x, load_dir=load_dir) return self._serialize_single(x, load_dir, allow_links)
def deserialize( def deserialize(
self, self,
@ -390,9 +398,9 @@ class VideoSerializable(FileSerializable):
} }
def serialize( def serialize(
self, x: str | None, load_dir: str | Path = "" self, x: str | None, load_dir: str | Path = "", allow_links: bool = False
) -> tuple[FileData | None, None]: ) -> tuple[FileData | None, None]:
return (super().serialize(x, load_dir), None) # type: ignore return (super().serialize(x, load_dir, allow_links), None) # type: ignore
def deserialize( def deserialize(
self, self,
@ -432,6 +440,7 @@ class JSONSerializable(Serializable):
self, self,
x: str | None, x: str | None,
load_dir: str | Path = "", load_dir: str | Path = "",
allow_links: bool = False,
) -> dict | list | None: ) -> dict | list | None:
""" """
Convert from a a human-friendly version (string path to json file) to a Convert from a a human-friendly version (string path to json file) to a
@ -488,7 +497,7 @@ class GallerySerializable(Serializable):
} }
def serialize( def serialize(
self, x: str | None, load_dir: str | Path = "" self, x: str | None, load_dir: str | Path = "", allow_links: bool = False
) -> list[list[str | None]] | None: ) -> list[list[str | None]] | None:
if x is None or x == "": if x is None or x == "":
return None return None
@ -497,7 +506,7 @@ class GallerySerializable(Serializable):
with captions_file.open("r") as captions_json: with captions_file.open("r") as captions_json:
captions = json.load(captions_json) captions = json.load(captions_json)
for file_name, caption in captions.items(): for file_name, caption in captions.items():
img = FileSerializable().serialize(file_name) img = FileSerializable().serialize(file_name, allow_links=allow_links)
files.append([img, caption]) files.append([img, caption])
return files return files

Binary file not shown.

View File

@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "from time import sleep\n", "\n", "with gr.Blocks() as demo:\n", " input_audio = gr.Audio(label=\"Input Audio\", type=\"filepath\", format=\"mp3\")\n", " with gr.Row():\n", " with gr.Column():\n", " stream_as_file_btn = gr.Button(\"Stream as File\")\n", " format = gr.Radio([\"wav\", \"mp3\"], value=\"wav\", label=\"Format\")\n", " stream_as_file_output = gr.Audio(streaming=True)\n", "\n", " def stream_file(audio_file, format):\n", " audio = AudioSegment.from_file(audio_file)\n", " i = 0\n", " chunk_size = 3000 \n", " while chunk_size*i < len(audio):\n", " chunk = audio[chunk_size*i:chunk_size*(i+1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.{format}\"\n", " chunk.export(file, format=format)\n", " yield file\n", " sleep(1)\n", " \n", " stream_as_file_btn.click(stream_file, [input_audio, format], stream_as_file_output)\n", "\n", " with gr.Column():\n", " stream_as_bytes_btn = gr.Button(\"Stream as Bytes\")\n", " stream_as_bytes_output = gr.Audio(format=\"bytes\", streaming=True)\n", "\n", " def stream_bytes(audio_file):\n", " chunk_size = 20_000\n", " with open(audio_file, \"rb\") as f:\n", " while True:\n", " chunk = f.read(chunk_size)\n", " if chunk:\n", " yield chunk\n", " sleep(1)\n", " else:\n", " break\n", " \n", " stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} {"cells": [{"cell_type": 
"markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stream_audio_out"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('audio')\n", "!wget -q -O audio/cantina.wav https://github.com/gradio-app/gradio/raw/main/demo/stream_audio_out/audio/cantina.wav"]}, {"cell_type": "code", "execution_count": null, "id": 44380577570523278879349135829904343037, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from pydub import AudioSegment\n", "from time import sleep\n", "\n", "with gr.Blocks() as demo:\n", " input_audio = gr.Audio(label=\"Input Audio\", type=\"filepath\", format=\"mp3\")\n", " with gr.Row():\n", " with gr.Column():\n", " stream_as_file_btn = gr.Button(\"Stream as File\")\n", " format = gr.Radio([\"wav\", \"mp3\"], value=\"wav\", label=\"Format\")\n", " stream_as_file_output = gr.Audio(streaming=True)\n", "\n", " def stream_file(audio_file, format):\n", " audio = AudioSegment.from_file(audio_file)\n", " i = 0\n", " chunk_size = 1000\n", " while chunk_size * i < len(audio):\n", " chunk = audio[chunk_size * i : chunk_size * (i + 1)]\n", " i += 1\n", " if chunk:\n", " file = f\"/tmp/{i}.{format}\"\n", " chunk.export(file, format=format)\n", " yield file\n", " sleep(0.5)\n", "\n", " stream_as_file_btn.click(\n", " stream_file, [input_audio, format], stream_as_file_output\n", " )\n", "\n", " gr.Examples(\n", " [[\"audio/cantina.wav\", \"wav\"], [\"audio/cantina.wav\", \"mp3\"]],\n", " [input_audio, format],\n", " fn=stream_file,\n", " outputs=stream_as_file_output,\n", " cache_examples=True,\n", " )\n", "\n", " with gr.Column():\n", " stream_as_bytes_btn = gr.Button(\"Stream as 
Bytes\")\n", " stream_as_bytes_output = gr.Audio(format=\"bytes\", streaming=True)\n", "\n", " def stream_bytes(audio_file):\n", " chunk_size = 20_000\n", " with open(audio_file, \"rb\") as f:\n", " while True:\n", " chunk = f.read(chunk_size)\n", " if chunk:\n", " yield chunk\n", " sleep(1)\n", " else:\n", " break\n", " stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -13,17 +13,27 @@ with gr.Blocks() as demo:
def stream_file(audio_file, format): def stream_file(audio_file, format):
audio = AudioSegment.from_file(audio_file) audio = AudioSegment.from_file(audio_file)
i = 0 i = 0
chunk_size = 3000 chunk_size = 1000
while chunk_size*i < len(audio): while chunk_size * i < len(audio):
chunk = audio[chunk_size*i:chunk_size*(i+1)] chunk = audio[chunk_size * i : chunk_size * (i + 1)]
i += 1 i += 1
if chunk: if chunk:
file = f"/tmp/{i}.{format}" file = f"/tmp/{i}.{format}"
chunk.export(file, format=format) chunk.export(file, format=format)
yield file yield file
sleep(1) sleep(0.5)
stream_as_file_btn.click(stream_file, [input_audio, format], stream_as_file_output) stream_as_file_btn.click(
stream_file, [input_audio, format], stream_as_file_output
)
gr.Examples(
[["audio/cantina.wav", "wav"], ["audio/cantina.wav", "mp3"]],
[input_audio, format],
fn=stream_file,
outputs=stream_as_file_output,
cache_examples=True,
)
with gr.Column(): with gr.Column():
stream_as_bytes_btn = gr.Button("Stream as Bytes") stream_as_bytes_btn = gr.Button("Stream as Bytes")
@ -39,7 +49,6 @@ with gr.Blocks() as demo:
sleep(1) sleep(1)
else: else:
break break
stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output) stream_as_bytes_btn.click(stream_bytes, input_audio, stream_as_bytes_output)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -291,12 +291,14 @@ class Examples:
print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'") print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
cache_logger = CSVLogger() cache_logger = CSVLogger()
generated_values = []
if inspect.isgeneratorfunction(self.fn): if inspect.isgeneratorfunction(self.fn):
def get_final_item(*args): # type: ignore def get_final_item(*args): # type: ignore
x = None x = None
generated_values.clear()
for x in self.fn(*args): # noqa: B007 # type: ignore for x in self.fn(*args): # noqa: B007 # type: ignore
pass generated_values.append(x)
return x return x
fn = get_final_item fn = get_final_item
@ -304,13 +306,15 @@ class Examples:
async def get_final_item(*args): async def get_final_item(*args):
x = None x = None
generated_values.clear()
async for x in self.fn(*args): # noqa: B007 # type: ignore async for x in self.fn(*args): # noqa: B007 # type: ignore
pass generated_values.append(x)
return x return x
fn = get_final_item fn = get_final_item
else: else:
fn = self.fn fn = self.fn
# create a fake dependency to process the examples and get the predictions # create a fake dependency to process the examples and get the predictions
dependency, fn_index = Context.root_block.set_event_trigger( dependency, fn_index = Context.root_block.set_event_trigger(
event_name="fake_event", event_name="fake_event",
@ -337,6 +341,11 @@ class Examples:
state={}, state={},
) )
output = prediction["data"] output = prediction["data"]
if len(generated_values):
output = merge_generated_values_into_output(
self.outputs, generated_values, output
)
if self.batch: if self.batch:
output = [value[0] for value in output] output = [value[0] for value in output]
cache_logger.flag(output) cache_logger.flag(output)
@ -395,13 +404,48 @@ class Examples:
except (ValueError, TypeError, SyntaxError, AssertionError): except (ValueError, TypeError, SyntaxError, AssertionError):
output.append( output.append(
component.serialize( component.serialize(
value_to_use, value_to_use, self.cached_folder, allow_links=True
self.cached_folder,
) )
) )
return output return output
def merge_generated_values_into_output(
    components: list[IOComponent], generated_values: list, output: list
) -> list:
    """
    Replace the entry of every streaming output component in `output` with a
    reference to one file that holds all of the chunks the generator yielded.

    Parameters:
        components: The output components, in the same order as `output`.
        generated_values: Every value yielded by the generator for one example. When there is more than one output component, each yielded value is indexed by the component's position.
        output: The output list to patch; it is modified in place.
    Returns:
        The same `output` list. For each component that is a `StreamableOutput` with `streaming` set, its entry becomes a dict with keys "name" (path of the merged file), "is_file" (True) and "data" (None). Other entries are left untouched.
    """
    from gradio.events import StreamableOutput

    for output_index, output_component in enumerate(components):
        if (
            isinstance(output_component, StreamableOutput)
            and output_component.streaming
        ):
            binary_chunks = []
            for i, chunk in enumerate(generated_values):
                # With several outputs each yielded value holds one item per
                # component; pick the one that belongs to this component.
                if len(components) > 1:
                    chunk = chunk[output_index]
                processed_chunk = output_component.postprocess(chunk)
                # Only the first element of the `stream_output` result is kept;
                # the third argument flags whether this is the first chunk.
                binary_chunks.append(
                    output_component.stream_output(processed_chunk, "", i == 0)[0]
                )
            binary_data = b"".join(binary_chunks)

            tempdir = os.environ.get("GRADIO_TEMP_DIR") or str(
                Path(tempfile.gettempdir()) / "gradio"
            )
            os.makedirs(tempdir, exist_ok=True)
            # Write through the temporary file's own handle and close it via
            # the context manager, so the descriptor is not leaked.
            # delete=False keeps the file on disk for the cached example.
            with tempfile.NamedTemporaryFile(dir=tempdir, delete=False) as temp_file:
                temp_file.write(binary_data)

            output[output_index] = {
                "name": temp_file.name,
                "is_file": True,
                "data": None,
            }

    return output
class TrackedIterable: class TrackedIterable:
def __init__( def __init__(
self, self,

View File

@ -9,7 +9,8 @@ from unittest.mock import patch
import pytest import pytest
import websockets import websockets
from gradio_client import media_data from gradio_client import media_data, utils
from pydub import AudioSegment
from starlette.testclient import TestClient from starlette.testclient import TestClient
from tqdm import tqdm from tqdm import tqdm
@ -93,7 +94,8 @@ class TestExamples:
) )
prediction = await examples.load_from_cache(0) prediction = await examples.load_from_cache(0)
assert prediction[0][0][0]["data"] == media_data.BASE64_IMAGE file = prediction[0][0][0]["name"]
assert utils.encode_url_or_file_to_base64(file) == media_data.BASE64_IMAGE
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp()) @patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
@ -152,7 +154,10 @@ class TestProcessExamples:
cache_examples=True, cache_examples=True,
) )
prediction = await io.examples_handler.load_from_cache(0) prediction = await io.examples_handler.load_from_cache(0)
assert prediction[0]["data"].startswith("data:audio/wav;base64,UklGRgA/") file = prediction[0]["name"]
assert utils.encode_url_or_file_to_base64(file).startswith(
"data:audio/wav;base64,UklGRgA/"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_caching_with_update(self): async def test_caching_with_update(self):
@ -220,6 +225,29 @@ class TestProcessExamples:
prediction = await io.examples_handler.load_from_cache(0) prediction = await io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef" assert prediction[0] == "Your output: abcdef"
@pytest.mark.asyncio
async def test_caching_with_generators_and_streamed_output(self):
    """
    Cache an example whose function is a generator feeding a streaming
    Audio output, then check the cached audio is about three times as long
    as the input clip and the number output is the last yielded value.
    """
    sample_path = str(Path(__file__).parent / "test_files" / "audio_sample.wav")

    def test_generator(x):
        # Yield the same clip `x` times alongside a running number.
        for y in range(int(x)):
            yield sample_path, y * 5

    interface = gr.Interface(
        test_generator,
        "number",
        [gr.Audio(streaming=True), "number"],
        examples=[3],
        cache_examples=True,
    )
    cached = await interface.examples_handler.load_from_cache(0)

    source_length = len(AudioSegment.from_wav(sample_path))
    merged_length = len(AudioSegment.from_wav(cached[0]["name"]))
    assert round(merged_length / source_length, 1) == 3.0  # might not be exactly 3x
    assert float(cached[1]) == 10.0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_caching_with_async_generators(self): async def test_caching_with_async_generators(self):
async def test_generator(x): async def test_generator(x):