"""Tests for gradio helper utilities.

Covers ``gr.Examples`` (processing, dataset headers, caching), ``gr.make_waveform``,
progress tracking via ``gr.Progress`` / ``tqdm``, and ``gr.Info`` / ``gr.Warning``
alerts.
"""

import asyncio
import json
import os
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from unittest.mock import patch

import gradio_client as grc
import pytest
from gradio_client import media_data
from gradio_client import utils as client_utils
from pydub import AudioSegment
from starlette.testclient import TestClient
from tqdm import tqdm

import gradio as gr
from gradio import utils


# NOTE: ``patch`` used as a class decorator wraps every ``test_*`` method and
# passes the created mock as an extra positional argument, which is why each
# method below accepts ``patched_cache_folder``.
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamples:
    """Tests for how ``gr.Examples`` processes and loads example inputs."""

    def test_handle_single_input(self, patched_cache_folder):
        """Flat and nested example lists are normalized to a list of rows."""
        examples = gr.Examples(["hello", "hi"], gr.Textbox())
        assert examples.processed_examples == [["hello"], ["hi"]]

        examples = gr.Examples([["hello"]], gr.Textbox())
        assert examples.processed_examples == [["hello"]]

        examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
        assert (
            client_utils.encode_file_to_base64(
                examples.processed_examples[0][0]["path"]
            )
            == media_data.BASE64_IMAGE
        )

    def test_handle_multiple_inputs(self, patched_cache_folder):
        """A row with a text and an image value is processed per component."""
        examples = gr.Examples(
            [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
        )
        assert examples.processed_examples[0][0] == "hello"
        assert (
            client_utils.encode_file_to_base64(
                examples.processed_examples[0][1]["path"]
            )
            == media_data.BASE64_IMAGE
        )

    def test_handle_directory(self, patched_cache_folder):
        """Passing a directory path loads the images it contains as examples."""
        examples = gr.Examples("test/test_files/images", gr.Image())
        assert len(examples.processed_examples) == 2
        for row in examples.processed_examples:
            for output in row:
                assert (
                    client_utils.encode_file_to_base64(output["path"])
                    == media_data.BASE64_IMAGE
                )

    def test_handle_directory_with_log_file(self, patched_cache_folder):
        """A directory with a log file yields (image, text) rows with absolute paths."""
        examples = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )
        # Replace every file dict whose path exists with its base64 encoding so
        # the result can be compared against the reference image.
        ex = client_utils.traverse(
            examples.processed_examples,
            lambda s: client_utils.encode_file_to_base64(s["path"]),
            lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
        )
        assert ex == [
            [media_data.BASE64_IMAGE, "hello"],
            [media_data.BASE64_IMAGE, "hi"],
        ]
        for sample in examples.dataset.samples:
            assert os.path.isabs(sample[0])

    def test_examples_per_page(self, patched_cache_folder):
        """``examples_per_page`` is forwarded to the dataset config."""
        examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
        assert examples.dataset.get_config()["samples_per_page"] == 2

    def test_no_preprocessing(self, patched_cache_folder):
        """With ``preprocess=False`` the function receives the raw file dict."""
        with gr.Blocks():
            image = gr.Image()
            textbox = gr.Textbox()

            examples = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image,
                outputs=textbox,
                fn=lambda x: x["path"],
                cache_examples=True,
                preprocess=False,
            )

        prediction = examples.load_from_cache(0)
        assert (
            client_utils.encode_file_to_base64(prediction[0]) == media_data.BASE64_IMAGE
        )

    def test_no_postprocessing(self, patched_cache_folder):
        """With ``postprocess=False`` the function output is cached as given."""

        def im(x):
            return [
                {
                    "image": {
                        "path": "test/test_files/bus.png",
                    },
                    "caption": "hi",
                }
            ]

        with gr.Blocks():
            text = gr.Textbox()
            gall = gr.Gallery()

            examples = gr.Examples(
                examples=["hi"],
                inputs=text,
                outputs=gall,
                fn=im,
                cache_examples=True,
                postprocess=False,
            )

        prediction = examples.load_from_cache(0)
        file = prediction[0].root[0].image.path
        assert client_utils.encode_url_or_file_to_base64(
            file
        ) == client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")


def test_setting_cache_dir_env_variable(monkeypatch):
    """Cached example files land inside the ``GRADIO_EXAMPLES_CACHE`` directory."""
    temp_dir = tempfile.mkdtemp()
    monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", temp_dir)
    with gr.Blocks():
        image = gr.Image()
        image2 = gr.Image()

        examples = gr.Examples(
            examples=["test/test_files/bus.png"],
            inputs=image,
            outputs=image2,
            fn=lambda x: x,
            cache_examples=True,
        )
    prediction = examples.load_from_cache(0)
    path_to_cached_file = Path(prediction[0].path)
    assert utils.is_in_or_equal(path_to_cached_file, temp_dir)
    monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)


@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestExamplesDataset:
    """Tests for the headers of the dataset created by ``gr.Examples``."""

    def test_no_headers(self, patched_cache_folder):
        """No component labels results in an empty header list."""
        examples = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
        assert examples.dataset.headers == []

    def test_all_headers(self, patched_cache_folder):
        """Every component label becomes a header."""
        examples = gr.Examples(
            "test/test_files/images_log",
            [gr.Image(label="im"), gr.Text(label="your text")],
        )
        assert examples.dataset.headers == ["im", "your text"]

    def test_some_headers(self, patched_cache_folder):
        """Unlabelled components get an empty-string header."""
        examples = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )
        assert examples.dataset.headers == ["im", ""]


def test_example_caching_relaunch(connect):
    """Cached examples keep working when the same demo is launched a second time."""

    def combine(a, b):
        return a + " " + b

    with gr.Blocks() as demo:
        txt = gr.Textbox(label="Input")
        txt_2 = gr.Textbox(label="Input 2")
        txt_3 = gr.Textbox(value="", label="Output")
        btn = gr.Button(value="Submit")
        btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
        gr.Examples(
            [["hi", "Adam"], ["hello", "Eve"]],
            [txt, txt_2],
            txt_3,
            combine,
            cache_examples=True,
            api_name="examples",
        )

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == (
            "hello",
            "Eve",
            "hello Eve",
        )

    # Let the server shut down
    time.sleep(1)

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == (
            "hello",
            "Eve",
            "hello Eve",
        )


@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
class TestProcessExamples:
    """Tests for caching example outputs and loading them back."""

    def test_caching(self, patched_cache_folder):
        """Cached text output can be loaded by example index."""
        io = gr.Interface(
            lambda x: f"Hello {x}",
            "text",
            "text",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == "Hello Dunya"

    def test_example_caching_relaunch(self, patched_cache_folder, connect):
        """Cached examples are served identically across two launches."""

        def combine(a, b):
            return a + " " + b

        with gr.Blocks() as demo:
            txt = gr.Textbox(label="Input")
            txt_2 = gr.Textbox(label="Input 2")
            txt_3 = gr.Textbox(value="", label="Output")
            btn = gr.Button(value="Submit")
            btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
            gr.Examples(
                [["hi", "Adam"], ["hello", "Eve"]],
                [txt, txt_2],
                txt_3,
                combine,
                cache_examples=True,
                api_name="examples",
            )

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

        with connect(demo) as client:
            assert client.predict(1, api_name="/examples") == (
                "hello",
                "Eve",
                "hello Eve",
            )

    def test_caching_image(self, patched_cache_folder):
        """A cached image output is stored as a PNG file."""
        io = gr.Interface(
            lambda x: x,
            "image",
            "image",
            examples=[["test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert client_utils.encode_url_or_file_to_base64(prediction[0].path).startswith(
            "data:image/png;base64,iVBORw0KGgoAAA"
        )

    def test_caching_audio(self, patched_cache_folder):
        """A cached audio output is stored as a WAV file."""
        io = gr.Interface(
            lambda x: x,
            "audio",
            "audio",
            examples=[["test/test_files/audio_sample.wav"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        file = prediction[0].path
        assert client_utils.encode_url_or_file_to_base64(file).startswith(
            "data:audio/wav;base64,UklGRgA/"
        )

    def test_caching_with_update(self, patched_cache_folder):
        """A ``gr.update`` output is loaded back from the cache as an update dict."""
        io = gr.Interface(
            lambda x: gr.update(visible=False),
            "text",
            "image",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "visible": False,
            "__type__": "update",
        }

    def test_caching_with_mix_update(self, patched_cache_folder):
        """An update dict and a plain value can be cached side by side."""
        io = gr.Interface(
            lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
            "text",
            ["text", "image"],
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "lines": 4,
            "value": "hello",
            "__type__": "update",
        }

    def test_caching_with_dict(self, patched_cache_folder):
        """A function returning a component-keyed dict is cached in output order."""
        text = gr.Textbox()
        out = gr.Label()

        io = gr.Interface(
            lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
            "textbox",
            [text, out],
            examples=["abc"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == [
            {"lines": 4, "__type__": "update", "interactive": False},
            gr.Label.data_model(**{"label": "lion", "confidences": None}),
        ]

    def test_caching_with_generators(self, patched_cache_folder):
        """For a generator function the last yielded value is what gets cached."""

        def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"

    def test_caching_with_generators_and_streamed_output(self, patched_cache_folder):
        """Streamed audio chunks are combined; the number output keeps the last value."""
        file_dir = Path(Path(__file__).parent, "test_files")
        audio = str(file_dir / "audio_sample.wav")

        def test_generator(x):
            for y in range(int(x)):
                yield audio, y * 5

        io = gr.Interface(
            test_generator,
            "number",
            [gr.Audio(streaming=True), "number"],
            examples=[3],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        len_input_audio = len(AudioSegment.from_wav(audio))
        len_output_audio = len(AudioSegment.from_wav(prediction[0].path))
        length_ratio = len_output_audio / len_input_audio
        assert round(length_ratio, 1) == 3.0  # might not be exactly 3x
        assert float(prediction[1]) == 10.0

    def test_caching_with_async_generators(self, patched_cache_folder):
        """Async generator functions are cached like sync generators."""

        async def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction[0] == "Your output: abcdef"

    def test_raise_helpful_error_message_if_providing_partial_examples(
        self, patched_cache_folder, tmp_path
    ):
        """Caching partial examples warns, and fails when the function needs the value."""

        def foo(a, b):
            return a + b

        # Rows shorter than the number of inputs.
        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo"], ["bar"]],
                    cache_examples=True,
                )

        # Full-length rows where one value is None.
        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo", "bar"], ["bar", None]],
                    cache_examples=True,
                )

        def foo_no_exception(a, b=2):
            return a * b

        # The missing argument has a default here, so caching must succeed.
        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    many_missing,
                    inputs=["text", "number", "number"],
                    outputs=["text"],
                    examples=[["foo", None, None], ["bar", 2, 3]],
                    cache_examples=True,
                )

    def test_caching_with_batch(self, patched_cache_folder):
        """Batched functions are cached one example at a time."""

        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel"]

    def test_caching_with_batch_multiple_outputs(self, patched_cache_folder):
        """Batched functions with several outputs cache one value per output."""

        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert prediction == ["hel", "3"]

    def test_caching_with_non_io_component(self, patched_cache_folder):
        """An update targeting a layout component (Column) can be cached."""

        def predict(name):
            return name, gr.update(visible=True)

        with gr.Blocks():
            t1 = gr.Textbox()
            with gr.Column(visible=False) as c:
                t2 = gr.Textbox()

            examples = gr.Examples(
                [["John"], ["Mary"]],
                fn=predict,
                inputs=[t1],
                outputs=[t2, c],
                cache_examples=True,
            )

        prediction = examples.load_from_cache(0)
        assert prediction == ["John", {"visible": True, "__type__": "update"}]

    def test_end_to_end(self, patched_cache_folder):
        """Loading a non-cached example through the API returns a Textbox update."""

        def concatenate(str1, str2):
            return str1 + str2

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                [["Hello,", None], ["Michael", None]],
                inputs=[t1, t2],
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == [
            {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": "Hello,",
                "type": "text",
            }
        ]

        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == [
            {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": "Michael",
                "type": "text",
            }
        ]

    def test_end_to_end_cache_examples(self, patched_cache_folder):
        """Loading a cached example through the API returns inputs plus the output."""

        def concatenate(str1, str2):
            return f"{str1} {str2}"

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                examples=[["Hello,", "World"], ["Michael", "Jordan"]],
                inputs=[t1, t2],
                outputs=[t2],
                fn=concatenate,
                cache_examples=True,
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == ["Hello,", "World", "Hello, World"]

        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]


def test_multiple_file_flagging(tmp_path):
    """A ``gr.Files`` output caches every returned file as ``gr.FileData``."""
    with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
        io = gr.Interface(
            fn=lambda *x: list(x),
            inputs=[
                gr.Image(type="filepath", label="frame 1"),
                gr.Image(type="filepath", label="frame 2"),
            ],
            outputs=[gr.Files()],
            examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)

        assert len(prediction[0].root) == 2
        assert all(isinstance(d, gr.FileData) for d in prediction[0].root)


def test_examples_keep_all_suffixes(tmp_path):
    """Same-named files from different folders stay distinct and keep their full name."""
    with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
        file_1 = tmp_path / "foo.bar.txt"
        file_1.write_text("file 1")
        file_2 = tmp_path / "file_2"
        file_2.mkdir(parents=True)
        file_2 = file_2 / "foo.bar.txt"
        file_2.write_text("file 2")
        io = gr.Interface(
            fn=lambda x: x.name,
            inputs=gr.File(),
            outputs=[gr.File()],
            examples=[[str(file_1)], [str(file_2)]],
            cache_examples=True,
        )
        prediction = io.examples_handler.load_from_cache(0)
        assert Path(prediction[0].path).read_text() == "file 1"
        assert prediction[0].orig_name == "foo.bar.txt"
        assert prediction[0].path.endswith("foo.bar.txt")

        prediction = io.examples_handler.load_from_cache(1)
        assert Path(prediction[0].path).read_text() == "file 2"
        assert prediction[0].orig_name == "foo.bar.txt"
        assert prediction[0].path.endswith("foo.bar.txt")


def test_make_waveform_with_spaces_in_filename():
    """``make_waveform`` handles an audio path containing a space.

    The output resolution is checked with ``ffprobe``; if the ``ffprobe`` command
    itself exits with an error, the failure is only printed (best-effort check).
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        audio = os.path.join(tmpdirname, "test audio.wav")
        shutil.copy("test/test_files/audio_sample.wav", audio)
        waveform = gr.make_waveform(audio)
        assert waveform.endswith(".mp4")

        command = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=width,height",
            "-of",
            "json",
            waveform,
        ]
        # Keep the try body limited to the only call that can raise
        # CalledProcessError; the assertions run only when ffprobe succeeded.
        try:
            result = subprocess.run(command, capture_output=True, text=True, check=True)
        except subprocess.CalledProcessError as e:
            print("Error retrieving resolution of output waveform video:", e)
        else:
            data = json.loads(result.stdout)
            stream = data["streams"][0]
            assert stream["width"] == 1000
            assert stream["height"] == 400


def test_make_waveform_raises_if_ffmpeg_fails(tmp_path, monkeypatch):
    """
    Test that make_waveform raises an exception if ffmpeg fails,
    instead of returning a path to a non-existent or empty file.
    """
    audio = tmp_path / "test audio.wav"
    shutil.copy("test/test_files/audio_sample.wav", audio)

    def _failing_ffmpeg(*args, **kwargs):
        raise subprocess.CalledProcessError(1, "ffmpeg")

    monkeypatch.setattr(subprocess, "call", _failing_ffmpeg)
    with pytest.raises(Exception):
        gr.make_waveform(str(audio))


class TestProgressBar:
    """Tests for progress and alert updates reported through the client job status."""

    @pytest.mark.asyncio
    async def test_progress_bar(self):
        """Only ``prog`` / ``prog.tqdm`` updates are reported without ``track_tqdm``."""
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        # Poll the job and record each distinct consecutive (index, desc) pair.
        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].desc if status.progress_data else None,
            )
            if update != (None, None) and (
                not status_updates or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        """With ``track_tqdm=True`` plain ``tqdm`` loops are reported as well."""
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        # Poll the job and record each distinct consecutive (index, desc) pair.
        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].desc if status.progress_data else None,
            )
            if update != (None, None) and (
                not status_updates or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
            (0, "alphabet"),
            (1, "alphabet"),
            (2, "alphabet"),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm_without_iterable(self):
        """A manually updated ``tqdm(total=...)`` bar is tracked step by step."""

        def greet(s, _=gr.Progress(track_tqdm=True)):
            with tqdm(total=len(s)) as progress_bar:
                for _c in s:
                    progress_bar.update()
                    time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        # Poll the job and record each distinct consecutive (index, unit) pair.
        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].unit if status.progress_data else None,
            )
            if update != (None, None) and (
                not status_updates or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (1, "steps"),
            (2, "steps"),
            (3, "steps"),
            (4, "steps"),
            (5, "steps"),
            (6, "steps"),
        ]

    @pytest.mark.asyncio
    async def test_info_and_warning_alerts(self):
        """``gr.Info`` and ``gr.Warning`` calls show up in order in the job log."""

        def greet(s):
            for _c in s:
                gr.Info(f"Letter {_c}")
                time.sleep(0.15)
            if len(s) < 5:
                gr.Warning("Too short!")
                time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        client = grc.Client(demo.local_url)
        job = client.submit("Jon")

        # Poll the job and record each distinct consecutive log entry.
        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                not status_updates or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            ("Letter J", "info"),
            ("Letter o", "info"),
            ("Letter n", "info"),
            ("Too short!", "warning"),
        ]


@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
    """Each client session only receives the ``gr.Info`` message of its own request.

    Two sessions are run concurrently (the second one staggered by ``delay``)
    against a demo that allows two concurrent workers, and each must end up with
    only its own greeting in the log.
    """

    async def greet_async(name):
        await asyncio.sleep(2)
        gr.Info(f"Hello {name}")
        await asyncio.sleep(1)
        return name

    def greet_sync(name):
        time.sleep(2)
        gr.Info(f"Hello {name}")
        time.sleep(1)
        return name

    demo = gr.Interface(
        greet_async if async_handler else greet_sync,
        "text",
        "text",
        concurrency_limit=2,
    )
    demo.launch(prevent_thread_lock=True)

    async def session_interaction(name, delay=0):
        """Submit ``name`` after ``delay`` seconds and return its last log message."""
        # Stagger the start of this session relative to the others.
        await asyncio.sleep(delay)

        client = grc.Client(demo.local_url)
        job = client.submit(name)

        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                not status_updates or status_updates[-1] != update
            ):
                status_updates.append(update)
            # Must be an awaited sleep: a blocking time.sleep() never yields to
            # the event loop, so asyncio.gather() would run the sessions one
            # after the other and the requests would never overlap.
            await asyncio.sleep(0.05)
        return status_updates[-1][0] if status_updates else None

    alice_logs, bob_logs = await asyncio.gather(
        session_interaction("Alice"),
        session_interaction("Bob", delay=1),
    )

    assert alice_logs == "Hello Alice"
    assert bob_logs == "Hello Bob"