gradio/test/test_helpers.py
cswamy 77c900311e
Fixes issue 5781: Enables specifying a caching directory for Examples (#6803)
* issue 5781 first commit

* second commit

* unnecessary str removed

* backend formatted

* Update gradio/helpers.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update guides/02_building-interfaces/03_more-on-examples.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* tests added

* add changeset

* format

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
2023-12-18 17:07:38 -08:00

883 lines
28 KiB
Python

import asyncio
import json
import os
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from unittest.mock import patch
import gradio_client as grc
import pytest
from gradio_client import media_data
from gradio_client import utils as client_utils
from pydub import AudioSegment
from starlette.testclient import TestClient
from tqdm import tqdm
import gradio as gr
from gradio import utils
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamples:
def test_handle_single_input(self, patched_cache_folder):
examples = gr.Examples(["hello", "hi"], gr.Textbox())
assert examples.processed_examples == [["hello"], ["hi"]]
examples = gr.Examples([["hello"]], gr.Textbox())
assert examples.processed_examples == [["hello"]]
examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
assert (
client_utils.encode_file_to_base64(
examples.processed_examples[0][0]["path"]
)
== media_data.BASE64_IMAGE
)
def test_handle_multiple_inputs(self, patched_cache_folder):
examples = gr.Examples(
[["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
)
assert examples.processed_examples[0][0] == "hello"
assert (
client_utils.encode_file_to_base64(
examples.processed_examples[0][1]["path"]
)
== media_data.BASE64_IMAGE
)
def test_handle_directory(self, patched_cache_folder):
examples = gr.Examples("test/test_files/images", gr.Image())
assert len(examples.processed_examples) == 2
for row in examples.processed_examples:
for output in row:
assert (
client_utils.encode_file_to_base64(output["path"])
== media_data.BASE64_IMAGE
)
def test_handle_directory_with_log_file(self, patched_cache_folder):
examples = gr.Examples(
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
)
ex = client_utils.traverse(
examples.processed_examples,
lambda s: client_utils.encode_file_to_base64(s["path"]),
lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
)
assert ex == [
[media_data.BASE64_IMAGE, "hello"],
[media_data.BASE64_IMAGE, "hi"],
]
for sample in examples.dataset.samples:
assert os.path.isabs(sample[0])
def test_examples_per_page(self, patched_cache_folder):
examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
assert examples.dataset.get_config()["samples_per_page"] == 2
def test_no_preprocessing(self, patched_cache_folder):
with gr.Blocks():
image = gr.Image()
textbox = gr.Textbox()
examples = gr.Examples(
examples=["test/test_files/bus.png"],
inputs=image,
outputs=textbox,
fn=lambda x: x["path"],
cache_examples=True,
preprocess=False,
)
prediction = examples.load_from_cache(0)
assert (
client_utils.encode_file_to_base64(prediction[0]) == media_data.BASE64_IMAGE
)
def test_no_postprocessing(self, patched_cache_folder):
def im(x):
return [
{
"image": {
"path": "test/test_files/bus.png",
},
"caption": "hi",
}
]
with gr.Blocks():
text = gr.Textbox()
gall = gr.Gallery()
examples = gr.Examples(
examples=["hi"],
inputs=text,
outputs=gall,
fn=im,
cache_examples=True,
postprocess=False,
)
prediction = examples.load_from_cache(0)
file = prediction[0].root[0].image.path
assert client_utils.encode_url_or_file_to_base64(
file
) == client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")
def test_setting_cache_dir_env_variable(monkeypatch):
temp_dir = tempfile.mkdtemp()
monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", temp_dir)
with gr.Blocks():
image = gr.Image()
image2 = gr.Image()
examples = gr.Examples(
examples=["test/test_files/bus.png"],
inputs=image,
outputs=image2,
fn=lambda x: x,
cache_examples=True,
)
prediction = examples.load_from_cache(0)
path_to_cached_file = Path(prediction[0].path)
assert utils.is_in_or_equal(path_to_cached_file, temp_dir)
monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamplesDataset:
def test_no_headers(self, patched_cache_folder):
examples = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
assert examples.dataset.headers == []
def test_all_headers(self, patched_cache_folder):
examples = gr.Examples(
"test/test_files/images_log",
[gr.Image(label="im"), gr.Text(label="your text")],
)
assert examples.dataset.headers == ["im", "your text"]
def test_some_headers(self, patched_cache_folder):
examples = gr.Examples(
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
)
assert examples.dataset.headers == ["im", ""]
def test_example_caching_relaunch(connect):
def combine(a, b):
return a + " " + b
with gr.Blocks() as demo:
txt = gr.Textbox(label="Input")
txt_2 = gr.Textbox(label="Input 2")
txt_3 = gr.Textbox(value="", label="Output")
btn = gr.Button(value="Submit")
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
gr.Examples(
[["hi", "Adam"], ["hello", "Eve"]],
[txt, txt_2],
txt_3,
combine,
cache_examples=True,
api_name="examples",
)
with connect(demo) as client:
assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)
# Let the server shut down
time.sleep(1)
with connect(demo) as client:
assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestProcessExamples:
def test_caching(self, patched_cache_folder):
io = gr.Interface(
lambda x: f"Hello {x}",
"text",
"text",
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == "Hello Dunya"
def test_example_caching_relaunch(self, patched_cache_folder, connect):
def combine(a, b):
return a + " " + b
with gr.Blocks() as demo:
txt = gr.Textbox(label="Input")
txt_2 = gr.Textbox(label="Input 2")
txt_3 = gr.Textbox(value="", label="Output")
btn = gr.Button(value="Submit")
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
gr.Examples(
[["hi", "Adam"], ["hello", "Eve"]],
[txt, txt_2],
txt_3,
combine,
cache_examples=True,
api_name="examples",
)
with connect(demo) as client:
assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)
with connect(demo) as client:
assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)
def test_caching_image(self, patched_cache_folder):
io = gr.Interface(
lambda x: x,
"image",
"image",
examples=[["test/test_files/bus.png"]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert client_utils.encode_url_or_file_to_base64(prediction[0].path).startswith(
""
)
def test_caching_audio(self, patched_cache_folder):
io = gr.Interface(
lambda x: x,
"audio",
"audio",
examples=[["test/test_files/audio_sample.wav"]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
file = prediction[0].path
assert client_utils.encode_url_or_file_to_base64(file).startswith(
"data:audio/wav;base64,UklGRgA/"
)
def test_caching_with_update(self, patched_cache_folder):
io = gr.Interface(
lambda x: gr.update(visible=False),
"text",
"image",
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == {
"visible": False,
"__type__": "update",
}
def test_caching_with_mix_update(self, patched_cache_folder):
io = gr.Interface(
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
"text",
["text", "image"],
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == {
"lines": 4,
"value": "hello",
"__type__": "update",
}
def test_caching_with_dict(self, patched_cache_folder):
text = gr.Textbox()
out = gr.Label()
io = gr.Interface(
lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
"textbox",
[text, out],
examples=["abc"],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == [
{"lines": 4, "__type__": "update", "interactive": False},
gr.Label.data_model(**{"label": "lion", "confidences": None}),
]
def test_caching_with_generators(self, patched_cache_folder):
def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
io = gr.Interface(
test_generator,
"textbox",
"textbox",
examples=["abcdef"],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
def test_caching_with_generators_and_streamed_output(self, patched_cache_folder):
file_dir = Path(Path(__file__).parent, "test_files")
audio = str(file_dir / "audio_sample.wav")
def test_generator(x):
for y in range(int(x)):
yield audio, y * 5
io = gr.Interface(
test_generator,
"number",
[gr.Audio(streaming=True), "number"],
examples=[3],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
len_input_audio = len(AudioSegment.from_wav(audio))
len_output_audio = len(AudioSegment.from_wav(prediction[0].path))
length_ratio = len_output_audio / len_input_audio
assert round(length_ratio, 1) == 3.0 # might not be exactly 3x
assert float(prediction[1]) == 10.0
def test_caching_with_async_generators(self, patched_cache_folder):
async def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
io = gr.Interface(
test_generator,
"textbox",
"textbox",
examples=["abcdef"],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"
def test_raise_helpful_error_message_if_providing_partial_examples(
self, patched_cache_folder, tmp_path
):
def foo(a, b):
return a + b
with pytest.warns(
UserWarning,
match="^Examples are being cached but not all input components have",
):
with pytest.raises(Exception):
gr.Interface(
foo,
inputs=["text", "text"],
outputs=["text"],
examples=[["foo"], ["bar"]],
cache_examples=True,
)
with pytest.warns(
UserWarning,
match="^Examples are being cached but not all input components have",
):
with pytest.raises(Exception):
gr.Interface(
foo,
inputs=["text", "text"],
outputs=["text"],
examples=[["foo", "bar"], ["bar", None]],
cache_examples=True,
)
def foo_no_exception(a, b=2):
return a * b
gr.Interface(
foo_no_exception,
inputs=["text", "number"],
outputs=["text"],
examples=[["foo"], ["bar"]],
cache_examples=True,
)
def many_missing(a, b, c):
return a * b
with pytest.warns(
UserWarning,
match="^Examples are being cached but not all input components have",
):
with pytest.raises(Exception):
gr.Interface(
many_missing,
inputs=["text", "number", "number"],
outputs=["text"],
examples=[["foo", None, None], ["bar", 2, 3]],
cache_examples=True,
)
def test_caching_with_batch(self, patched_cache_folder):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return [trimmed_words]
io = gr.Interface(
trim_words,
["textbox", gr.Number(precision=0)],
["textbox"],
batch=True,
max_batch_size=16,
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel"]
def test_caching_with_batch_multiple_outputs(self, patched_cache_folder):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return trimmed_words, lens
io = gr.Interface(
trim_words,
["textbox", gr.Number(precision=0)],
["textbox", gr.Number(precision=0)],
batch=True,
max_batch_size=16,
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel", "3"]
def test_caching_with_non_io_component(self, patched_cache_folder):
def predict(name):
return name, gr.update(visible=True)
with gr.Blocks():
t1 = gr.Textbox()
with gr.Column(visible=False) as c:
t2 = gr.Textbox()
examples = gr.Examples(
[["John"], ["Mary"]],
fn=predict,
inputs=[t1],
outputs=[t2, c],
cache_examples=True,
)
prediction = examples.load_from_cache(0)
assert prediction == ["John", {"visible": True, "__type__": "update"}]
def test_end_to_end(self, patched_cache_folder):
def concatenate(str1, str2):
return str1 + str2
with gr.Blocks() as demo:
t1 = gr.Textbox()
t2 = gr.Textbox()
t1.submit(concatenate, [t1, t2], t2)
gr.Examples(
[["Hello,", None], ["Michael", None]],
inputs=[t1, t2],
api_name="load_example",
)
app, _, _ = demo.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/load_example/", json={"data": [0]})
assert response.json()["data"] == [
{
"lines": 1,
"max_lines": 20,
"show_label": True,
"container": True,
"min_width": 160,
"autofocus": False,
"autoscroll": True,
"elem_classes": [],
"rtl": False,
"show_copy_button": False,
"__type__": "update",
"visible": True,
"value": "Hello,",
"type": "text",
}
]
response = client.post("/api/load_example/", json={"data": [1]})
assert response.json()["data"] == [
{
"lines": 1,
"max_lines": 20,
"show_label": True,
"container": True,
"min_width": 160,
"autofocus": False,
"autoscroll": True,
"elem_classes": [],
"rtl": False,
"show_copy_button": False,
"__type__": "update",
"visible": True,
"value": "Michael",
"type": "text",
}
]
def test_end_to_end_cache_examples(self, patched_cache_folder):
def concatenate(str1, str2):
return f"{str1} {str2}"
with gr.Blocks() as demo:
t1 = gr.Textbox()
t2 = gr.Textbox()
t1.submit(concatenate, [t1, t2], t2)
gr.Examples(
examples=[["Hello,", "World"], ["Michael", "Jordan"]],
inputs=[t1, t2],
outputs=[t2],
fn=concatenate,
cache_examples=True,
api_name="load_example",
)
app, _, _ = demo.launch(prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/load_example/", json={"data": [0]})
assert response.json()["data"] == ["Hello,", "World", "Hello, World"]
response = client.post("/api/load_example/", json={"data": [1]})
assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
def test_multiple_file_flagging(tmp_path):
with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
io = gr.Interface(
fn=lambda *x: list(x),
inputs=[
gr.Image(type="filepath", label="frame 1"),
gr.Image(type="filepath", label="frame 2"),
],
outputs=[gr.Files()],
examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert len(prediction[0].root) == 2
assert all(isinstance(d, gr.FileData) for d in prediction[0].root)
def test_examples_keep_all_suffixes(tmp_path):
with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
file_1 = tmp_path / "foo.bar.txt"
file_1.write_text("file 1")
file_2 = tmp_path / "file_2"
file_2.mkdir(parents=True)
file_2 = file_2 / "foo.bar.txt"
file_2.write_text("file 2")
io = gr.Interface(
fn=lambda x: x.name,
inputs=gr.File(),
outputs=[gr.File()],
examples=[[str(file_1)], [str(file_2)]],
cache_examples=True,
)
prediction = io.examples_handler.load_from_cache(0)
assert Path(prediction[0].path).read_text() == "file 1"
assert prediction[0].orig_name == "foo.bar.txt"
assert prediction[0].path.endswith("foo.bar.txt")
prediction = io.examples_handler.load_from_cache(1)
assert Path(prediction[0].path).read_text() == "file 2"
assert prediction[0].orig_name == "foo.bar.txt"
assert prediction[0].path.endswith("foo.bar.txt")
def test_make_waveform_with_spaces_in_filename():
with tempfile.TemporaryDirectory() as tmpdirname:
audio = os.path.join(tmpdirname, "test audio.wav")
shutil.copy("test/test_files/audio_sample.wav", audio)
waveform = gr.make_waveform(audio)
assert waveform.endswith(".mp4")
try:
command = [
"ffprobe",
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=width,height",
"-of",
"json",
waveform,
]
result = subprocess.run(command, capture_output=True, text=True, check=True)
output = result.stdout
data = json.loads(output)
width = data["streams"][0]["width"]
height = data["streams"][0]["height"]
assert width == 1000
assert height == 400
except subprocess.CalledProcessError as e:
print("Error retrieving resolution of output waveform video:", e)
def test_make_waveform_raises_if_ffmpeg_fails(tmp_path, monkeypatch):
"""
Test that make_waveform raises an exception if ffmpeg fails,
instead of returning a path to a non-existent or empty file.
"""
audio = tmp_path / "test audio.wav"
shutil.copy("test/test_files/audio_sample.wav", audio)
def _failing_ffmpeg(*args, **kwargs):
raise subprocess.CalledProcessError(1, "ffmpeg")
monkeypatch.setattr(subprocess, "call", _failing_ffmpeg)
with pytest.raises(Exception):
gr.make_waveform(str(audio))
class TestProgressBar:
@pytest.mark.asyncio
async def test_progress_bar(self):
with gr.Blocks() as demo:
name = gr.Textbox()
greeting = gr.Textbox()
button = gr.Button(value="Greet")
def greet(s, prog=gr.Progress()):
prog(0, desc="start")
time.sleep(0.15)
for _ in prog.tqdm(range(4), unit="iter"):
time.sleep(0.15)
time.sleep(0.15)
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
time.sleep(0.15)
return f"Hello, {s}!"
button.click(greet, name, greeting)
demo.queue(max_size=1).launch(prevent_thread_lock=True)
client = grc.Client(demo.local_url)
job = client.submit("Gradio")
status_updates = []
while not job.done():
status = job.status()
update = (
status.progress_data[0].index if status.progress_data else None,
status.progress_data[0].desc if status.progress_data else None,
)
if update != (None, None) and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
(None, "start"),
(0, None),
(1, None),
(2, None),
(3, None),
(4, None),
]
@pytest.mark.asyncio
async def test_progress_bar_track_tqdm(self):
with gr.Blocks() as demo:
name = gr.Textbox()
greeting = gr.Textbox()
button = gr.Button(value="Greet")
def greet(s, prog=gr.Progress(track_tqdm=True)):
prog(0, desc="start")
time.sleep(0.15)
for _ in prog.tqdm(range(4), unit="iter"):
time.sleep(0.15)
time.sleep(0.15)
for _ in tqdm(["a", "b", "c"], desc="alphabet"):
time.sleep(0.15)
return f"Hello, {s}!"
button.click(greet, name, greeting)
demo.queue(max_size=1).launch(prevent_thread_lock=True)
client = grc.Client(demo.local_url)
job = client.submit("Gradio")
status_updates = []
while not job.done():
status = job.status()
update = (
status.progress_data[0].index if status.progress_data else None,
status.progress_data[0].desc if status.progress_data else None,
)
if update != (None, None) and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
(None, "start"),
(0, None),
(1, None),
(2, None),
(3, None),
(4, None),
(0, "alphabet"),
(1, "alphabet"),
(2, "alphabet"),
]
@pytest.mark.asyncio
async def test_progress_bar_track_tqdm_without_iterable(self):
def greet(s, _=gr.Progress(track_tqdm=True)):
with tqdm(total=len(s)) as progress_bar:
for _c in s:
progress_bar.update()
time.sleep(0.15)
return f"Hello, {s}!"
demo = gr.Interface(greet, "text", "text")
demo.queue().launch(prevent_thread_lock=True)
client = grc.Client(demo.local_url)
job = client.submit("Gradio")
status_updates = []
while not job.done():
status = job.status()
update = (
status.progress_data[0].index if status.progress_data else None,
status.progress_data[0].unit if status.progress_data else None,
)
if update != (None, None) and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
(1, "steps"),
(2, "steps"),
(3, "steps"),
(4, "steps"),
(5, "steps"),
(6, "steps"),
]
@pytest.mark.asyncio
async def test_info_and_warning_alerts(self):
def greet(s):
for _c in s:
gr.Info(f"Letter {_c}")
time.sleep(0.15)
if len(s) < 5:
gr.Warning("Too short!")
time.sleep(0.15)
return f"Hello, {s}!"
demo = gr.Interface(greet, "text", "text")
demo.queue().launch(prevent_thread_lock=True)
client = grc.Client(demo.local_url)
job = client.submit("Jon")
status_updates = []
while not job.done():
status = job.status()
update = status.log
if update is not None and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
assert status_updates == [
("Letter J", "info"),
("Letter o", "info"),
("Letter n", "info"),
("Too short!", "warning"),
]
@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
async def greet_async(name):
await asyncio.sleep(2)
gr.Info(f"Hello {name}")
await asyncio.sleep(1)
return name
def greet_sync(name):
time.sleep(2)
gr.Info(f"Hello {name}")
time.sleep(1)
return name
demo = gr.Interface(
greet_async if async_handler else greet_sync,
"text",
"text",
concurrency_limit=2,
)
demo.launch(prevent_thread_lock=True)
async def session_interaction(name, delay=0):
client = grc.Client(demo.local_url)
job = client.submit(name)
status_updates = []
while not job.done():
status = job.status()
update = status.log
if update is not None and (
len(status_updates) == 0 or status_updates[-1] != update
):
status_updates.append(update)
time.sleep(0.05)
return status_updates[-1][0] if status_updates else None
alice_logs, bob_logs = await asyncio.gather(
session_interaction("Alice"),
session_interaction("Bob", delay=1),
)
assert alice_logs == "Hello Alice"
assert bob_logs == "Hello Bob"