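"""Tests for gradio.helpers: gr.Examples processing and caching, progress
updates, and info/warning alerts."""
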
import asyncio
import os
import tempfile
import time
from pathlib import Path
from unittest.mock import patch

import gradio_client as grc
import pytest
from gradio_client import media_data
from gradio_client import utils as client_utils
from pydub import AudioSegment
from starlette.testclient import TestClient
from tqdm import tqdm

import gradio as gr
from gradio import helpers, utils
from gradio.route_utils import API_PREFIX
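

# Each test in this class runs with the examples cache redirected to a
# temporary directory, keeping cached files out of the real Gradio cache.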
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
|
|
class TestExamples:
|
|
def test_handle_single_input(self, patched_cache_folder):
|
|
examples = gr.Examples(["hello", "hi"], gr.Textbox())
|
|
assert examples.non_none_processed_examples.as_list() == [["hello"], ["hi"]]
|
|
|
|
examples = gr.Examples([["hello"]], gr.Textbox())
|
|
assert examples.non_none_processed_examples.as_list() == [["hello"]]
|
|
|
|
examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
|
|
assert (
|
|
client_utils.encode_file_to_base64(
|
|
examples.non_none_processed_examples.as_list()[0][0]["path"]
|
|
)
|
|
== media_data.BASE64_IMAGE
|
|
)
|
|
|
|
def test_handle_multiple_inputs(self, patched_cache_folder):
|
|
examples = gr.Examples(
|
|
[["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
|
|
)
|
|
assert examples.non_none_processed_examples.as_list()[0][0] == "hello"
|
|
assert (
|
|
client_utils.encode_file_to_base64(
|
|
examples.non_none_processed_examples.as_list()[0][1]["path"]
|
|
)
|
|
== media_data.BASE64_IMAGE
|
|
)
|
|
|
|
def test_handle_directory(self, patched_cache_folder):
|
|
examples = gr.Examples("test/test_files/images", gr.Image())
|
|
assert len(examples.non_none_processed_examples.as_list()) == 2
|
|
for row in examples.non_none_processed_examples.as_list():
|
|
for output in row:
|
|
assert (
|
|
client_utils.encode_file_to_base64(output["path"])
|
|
== media_data.BASE64_IMAGE
|
|
)
|
|
|
|
def test_handle_directory_with_log_file(self, patched_cache_folder):
|
|
examples = gr.Examples(
|
|
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
|
|
)
|
|
ex = client_utils.traverse(
|
|
examples.non_none_processed_examples.as_list(),
|
|
lambda s: client_utils.encode_file_to_base64(s["path"]),
|
|
lambda x: isinstance(x, dict) and Path(x["path"]).exists(),
|
|
)
|
|
assert ex == [
|
|
[media_data.BASE64_IMAGE, "hello"],
|
|
[media_data.BASE64_IMAGE, "hi"],
|
|
]
|
|
for sample in examples.dataset.samples:
|
|
assert os.path.isabs(sample[0]["path"])
|
|
|
|
def test_examples_per_page(self, patched_cache_folder):
|
|
examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
|
|
assert examples.dataset.get_config()["samples_per_page"] == 2
|
|
|
|
def test_no_preprocessing(self, patched_cache_folder, connect):
|
|
with gr.Blocks() as demo:
|
|
image = gr.Image()
|
|
textbox = gr.Textbox()
|
|
|
|
examples = gr.Examples(
|
|
examples=["test/test_files/bus.png"],
|
|
inputs=image,
|
|
outputs=textbox,
|
|
fn=lambda x: x["path"],
|
|
cache_examples=True,
|
|
preprocess=False,
|
|
)
|
|
|
|
with connect(demo):
|
|
prediction = examples.load_from_cache(0)
|
|
assert (
|
|
client_utils.encode_file_to_base64(prediction[0]) == media_data.BASE64_IMAGE
|
|
)
|
|
|
|
def test_no_postprocessing(self, patched_cache_folder, connect):
|
|
def im(x):
|
|
return [
|
|
{
|
|
"image": {
|
|
"path": "test/test_files/bus.png",
|
|
},
|
|
"caption": "hi",
|
|
}
|
|
]
|
|
|
|
with gr.Blocks() as demo:
|
|
text = gr.Textbox()
|
|
gall = gr.Gallery()
|
|
|
|
examples = gr.Examples(
|
|
examples=["hi"],
|
|
inputs=text,
|
|
outputs=gall,
|
|
fn=im,
|
|
cache_examples=True,
|
|
postprocess=False,
|
|
)
|
|
|
|
with connect(demo):
|
|
prediction = examples.load_from_cache(0)
|
|
file = prediction[0].root[0].image.path
|
|
assert client_utils.encode_url_or_file_to_base64(
|
|
file
|
|
) == client_utils.encode_url_or_file_to_base64("test/test_files/bus.png")
|
|
|
|
|
|
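

# GRADIO_EXAMPLES_CACHE overrides the default cache location; cached
# predictions should land inside the configured directory.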
def test_setting_cache_dir_env_variable(monkeypatch, connect):
    temp_dir = tempfile.mkdtemp()
    monkeypatch.setenv("GRADIO_EXAMPLES_CACHE", temp_dir)
    with gr.Blocks() as demo:
        image = gr.Image()
        image2 = gr.Image()

        examples = gr.Examples(
            examples=["test/test_files/bus.png"],
            inputs=image,
            outputs=image2,
            fn=lambda x: x,
            cache_examples=True,
        )

    with connect(demo):
        prediction = examples.load_from_cache(0)
        path_to_cached_file = Path(prediction[0].path)
        assert utils.is_in_or_equal(path_to_cached_file, temp_dir)
    monkeypatch.delenv("GRADIO_EXAMPLES_CACHE", raising=False)
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
|
|
class TestExamplesDataset:
|
|
def test_no_headers(self, patched_cache_folder):
|
|
examples = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
|
|
assert examples.dataset.headers == []
|
|
|
|
def test_all_headers(self, patched_cache_folder):
|
|
examples = gr.Examples(
|
|
"test/test_files/images_log",
|
|
[gr.Image(label="im"), gr.Text(label="your text")],
|
|
)
|
|
assert examples.dataset.headers == ["im", "your text"]
|
|
|
|
def test_some_headers(self, patched_cache_folder):
|
|
examples = gr.Examples(
|
|
"test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
|
|
)
|
|
assert examples.dataset.headers == ["im", ""]
|
|
|
|
def test_example_labels(self, patched_cache_folder):
|
|
examples = gr.Examples(
|
|
examples=[
|
|
[5, "add", 3],
|
|
[4, "divide", 2],
|
|
[-4, "multiply", 2.5],
|
|
[0, "subtract", 1.2],
|
|
],
|
|
inputs=[
|
|
gr.Number(),
|
|
gr.Radio(["add", "divide", "multiply", "subtract"]),
|
|
gr.Number(),
|
|
],
|
|
example_labels=["add", "divide", "multiply", "subtract"],
|
|
)
|
|
assert examples.dataset.sample_labels == [
|
|
"add",
|
|
"divide",
|
|
"multiply",
|
|
"subtract",
|
|
]
|
|
|
|
|
|
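

# Cached example predictions are persisted on disk, so a second launch of the
# same demo should serve them without re-running the function.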
def test_example_caching_relaunch(connect):
    def combine(a, b):
        return a + " " + b

    with gr.Blocks() as demo:
        txt = gr.Textbox(label="Input")
        txt_2 = gr.Textbox(label="Input 2")
        txt_3 = gr.Textbox(value="", label="Output")
        btn = gr.Button(value="Submit")
        btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
        gr.Examples(
            [["hi", "Adam"], ["hello", "Eve"]],
            [txt, txt_2],
            txt_3,
            combine,
            cache_examples=True,
            api_name="examples",
        )

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == (
            "hello",
            "Eve",
            "hello Eve",
        )

    # Let the server shut down
    time.sleep(1)

    with connect(demo) as client:
        assert client.predict(1, api_name="/examples") == (
            "hello",
            "Eve",
            "hello Eve",
        )
@patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp()))
|
|
class TestProcessExamples:
|
|
def test_caching(self, patched_cache_folder, connect):
|
|
io = gr.Interface(
|
|
lambda x: f"Hello {x}",
|
|
"text",
|
|
"text",
|
|
examples=[["World"], ["Dunya"], ["Monde"]],
|
|
cache_examples=True,
|
|
)
|
|
with connect(io):
|
|
prediction = io.examples_handler.load_from_cache(1)
|
|
assert prediction[0] == "Hello Dunya"
|
|
|
|
def test_example_caching_relaunch(self, patched_cache_folder, connect):
|
|
def combine(a, b):
|
|
return a + " " + b
|
|
|
|
with gr.Blocks() as demo:
|
|
txt = gr.Textbox(label="Input")
|
|
txt_2 = gr.Textbox(label="Input 2")
|
|
txt_3 = gr.Textbox(value="", label="Output")
|
|
btn = gr.Button(value="Submit")
|
|
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
|
|
gr.Examples(
|
|
[["hi", "Adam"], ["hello", "Eve"]],
|
|
[txt, txt_2],
|
|
txt_3,
|
|
combine,
|
|
cache_examples=True,
|
|
api_name="examples",
|
|
)
|
|
|
|
with connect(demo) as client:
|
|
assert client.predict(1, api_name="/examples") == (
|
|
"hello",
|
|
"Eve",
|
|
"hello Eve",
|
|
)
|
|
|
|
with connect(demo) as client:
|
|
assert client.predict(1, api_name="/examples") == (
|
|
"hello",
|
|
"Eve",
|
|
"hello Eve",
|
|
)
|
|
|
|
def test_caching_image(self, patched_cache_folder, connect):
|
|
io = gr.Interface(
|
|
lambda x: x,
|
|
"image",
|
|
"image",
|
|
examples=[["test/test_files/bus.png"]],
|
|
cache_examples=True,
|
|
)
|
|
with connect(io):
|
|
prediction = io.examples_handler.load_from_cache(0)
|
|
assert prediction[0].path.endswith(".webp")
|
|
|
|
def test_caching_audio(self, patched_cache_folder, connect):
|
|
io = gr.Interface(
|
|
lambda x: x,
|
|
"audio",
|
|
"audio",
|
|
examples=[["test/test_files/audio_sample.wav"]],
|
|
cache_examples=True,
|
|
)
|
|
with connect(io):
|
|
prediction = io.examples_handler.load_from_cache(0)
|
|
file = prediction[0].path
|
|
assert client_utils.encode_url_or_file_to_base64(file).startswith(
|
|
"data:audio/wav;base64,UklGRgA/"
|
|
)
|
|
|
|
def test_caching_with_update(self, patched_cache_folder, connect):
|
|
io = gr.Interface(
|
|
lambda x: gr.update(visible=False),
|
|
"text",
|
|
"image",
|
|
examples=[["World"], ["Dunya"], ["Monde"]],
|
|
cache_examples=True,
|
|
)
|
|
with connect(io):
|
|
prediction = io.examples_handler.load_from_cache(1)
|
|
assert prediction[0] == {
|
|
"visible": False,
|
|
"__type__": "update",
|
|
}
|
|
|
|
def test_caching_with_mix_update(self, patched_cache_folder, connect):
|
|
io = gr.Interface(
|
|
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
|
|
"text",
|
|
["text", "image"],
|
|
examples=[["World"], ["Dunya"], ["Monde"]],
|
|
cache_examples=True,
|
|
)
|
|
with connect(io):
|
|
prediction = io.examples_handler.load_from_cache(1)
|
|
assert prediction[0] == {
|
|
"lines": 4,
|
|
"value": "hello",
|
|
"__type__": "update",
|
|
}
|
|
|
|
def test_caching_with_dict(self, patched_cache_folder, connect):
|
|
text = gr.Textbox()
|
|
out = gr.Label()
|
|
|
|
io = gr.Interface(
|
|
lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
|
|
"textbox",
|
|
[text, out],
|
|
examples=["abc"],
|
|
cache_examples=True,
|
|
)
|
|
with connect(io):
|
|
prediction = io.examples_handler.load_from_cache(0)
|
|
assert prediction == [
|
|
{"lines": 4, "__type__": "update", "interactive": False},
|
|
gr.Label.data_model(**{"label": "lion", "confidences": None}),
|
|
]
|
|
|
|
def test_caching_with_generators(self, patched_cache_folder, connect):
|
|
def test_generator(x):
|
|
for y in range(len(x)):
|
|
yield "Your output: " + x[: y + 1]
|
|
|
|
io = gr.Interface(
|
|
test_generator,
|
|
"textbox",
|
|
"textbox",
|
|
examples=["abcdef"],
|
|
cache_examples=True,
|
|
)
|
|
with connect(io):
|
|
prediction = io.examples_handler.load_from_cache(0)
|
|
assert prediction[0] == "Your output: abcdef"
|
|
|
|
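
    # With a streaming Audio output, caching a generator concatenates the
    # yielded chunks, so the cached file ends up roughly 3x the source clip
    # here (one full clip per yield).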
    def test_caching_with_generators_and_streamed_output(
        self, patched_cache_folder, connect
    ):
        file_dir = Path(Path(__file__).parent, "test_files")
        audio = str(file_dir / "audio_sample.wav")

        def test_generator(x):
            for y in range(int(x)):
                yield audio, y * 5

        io = gr.Interface(
            test_generator,
            "number",
            [gr.Audio(streaming=True), "number"],
            examples=[3],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
            len_input_audio = len(AudioSegment.from_file(audio))
            len_output_audio = len(AudioSegment.from_file(prediction[0].path))
            length_ratio = len_output_audio / len_input_audio
            assert 3 <= round(length_ratio, 1) < 4  # might not be exactly 3x
            assert float(prediction[1]) == 10.0

    def test_caching_with_async_generators(self, patched_cache_folder, connect):
        async def test_generator(x):
            for y in range(len(x)):
                yield "Your output: " + x[: y + 1]

        io = gr.Interface(
            test_generator,
            "textbox",
            "textbox",
            examples=["abcdef"],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
            assert prediction[0] == "Your output: abcdef"
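
    # Examples that leave some inputs blank can only be cached when the
    # function has defaults for the missing arguments; otherwise caching
    # should warn and then surface the underlying error.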
    @pytest.mark.asyncio
    async def test_raise_helpful_error_message_if_providing_partial_examples(
        self, patched_cache_folder, tmp_path
    ):
        def foo(a, b):
            return a + b

        with pytest.warns(
            UserWarning,
            match="^Examples will be cached but not all input components have",
        ):
            with pytest.raises(Exception):
                io = gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo"], ["bar"]],
                    cache_examples=True,
                )
                await io.examples_handler._start_caching()

        with pytest.warns(
            UserWarning,
            match="^Examples will be cached but not all input components have",
        ):
            with pytest.raises(Exception):
                io = gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo", "bar"], ["bar", None]],
                    cache_examples=True,
                )
                await io.examples_handler._start_caching()

        def foo_no_exception(a, b=2):
            return a * b

        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        with pytest.warns(
            UserWarning,
            match="^Examples will be cached but not all input components have",
        ):
            with pytest.raises(Exception):
                io = gr.Interface(
                    many_missing,
                    inputs=["text", "number", "number"],
                    outputs=["text"],
                    examples=[["foo", None, None], ["bar", 2, 3]],
                    cache_examples=True,
                )
                await io.examples_handler._start_caching()
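
    # batch=True functions take and return lists, so caching a single example
    # round-trips through a batch of size one.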
    def test_caching_with_batch(self, patched_cache_folder, connect):
        def trim_words(words, lens):
            trimmed_words = [
                word[:length] for word, length in zip(words, lens, strict=False)
            ]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
            assert prediction == ["hel"]

    def test_caching_with_batch_multiple_outputs(self, patched_cache_folder, connect):
        def trim_words(words, lens):
            trimmed_words = [
                word[:length] for word, length in zip(words, lens, strict=False)
            ]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
            assert prediction == ["hel", "3"]

    def test_caching_with_non_io_component(self, patched_cache_folder, connect):
        def predict(name):
            return name, gr.update(visible=True)

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            with gr.Column(visible=False) as c:
                t2 = gr.Textbox()

            examples = gr.Examples(
                [["John"], ["Mary"]],
                fn=predict,
                inputs=[t1],
                outputs=[t2, c],  # type: ignore
                cache_examples=True,
            )

        with connect(demo):
            prediction = examples.load_from_cache(0)
            assert prediction == ["John", {"visible": True, "__type__": "update"}]
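
    # The end-to-end tests below hit the HTTP endpoint that example selection
    # triggers, instead of reading the cache directly.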
    def test_end_to_end(self, patched_cache_folder):
        def concatenate(str1, str2):
            return str1 + str2

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                [["Hello,", None], ["Michael", None]],
                inputs=[t1, t2],
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == [
            {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": "Hello,",
                "type": "text",
                "stop_btn": False,
                "submit_btn": False,
            }
        ]

        response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == [
            {
                "lines": 1,
                "max_lines": 20,
                "show_label": True,
                "container": True,
                "min_width": 160,
                "autofocus": False,
                "autoscroll": True,
                "elem_classes": [],
                "rtl": False,
                "show_copy_button": False,
                "__type__": "update",
                "visible": True,
                "value": "Michael",
                "type": "text",
                "stop_btn": False,
                "submit_btn": False,
            }
        ]

    def test_end_to_end_cache_examples(self, patched_cache_folder):
        def concatenate(str1, str2):
            return f"{str1} {str2}"

        with gr.Blocks() as demo:
            t1 = gr.Textbox()
            t2 = gr.Textbox()
            t1.submit(concatenate, [t1, t2], t2)

            gr.Examples(
                examples=[["Hello,", "World"], ["Michael", "Jordan"]],
                inputs=[t1, t2],
                outputs=[t2],
                fn=concatenate,
                cache_examples=True,
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == ["Hello,", "World", "Hello, World"]

        response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]

    def test_end_to_end_lazy_cache_examples(self, patched_cache_folder):
        def image_identity(image, string):
            return image

        with gr.Blocks() as demo:
            i1 = gr.Image()
            t = gr.Textbox()
            i2 = gr.Image()

            gr.Examples(
                examples=[
                    ["test/test_files/cheetah1.jpg", "cheetah"],
                    ["test/test_files/bus.png", "bus"],
                ],
                inputs=[i1, t],
                outputs=[i2],
                fn=image_identity,
                cache_examples=True,
                cache_mode="lazy",
                api_name="load_example",
            )

        app, _, _ = demo.launch(prevent_thread_lock=True)
        client = TestClient(app)

        response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [0]})
        data = response.json()["data"]
        assert data[0]["value"]["path"].endswith("cheetah1.jpg")
        assert data[1]["value"] == "cheetah"

        response = client.post(f"{API_PREFIX}/api/load_example/", json={"data": [1]})
        data = response.json()["data"]
        assert data[0]["value"]["path"].endswith("bus.png")
        assert data[1]["value"] == "bus"


def test_multiple_file_flagging(tmp_path, connect):
    with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
        io = gr.Interface(
            fn=lambda *x: list(x),
            inputs=[
                gr.Image(type="filepath", label="frame 1"),
                gr.Image(type="filepath", label="frame 2"),
            ],
            outputs=[gr.Files()],
            examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)

            assert len(prediction[0].root) == 2
            assert all(isinstance(d, gr.FileData) for d in prediction[0].root)


def test_examples_keep_all_suffixes(tmp_path, connect):
    with patch("gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())):
        file_1 = tmp_path / "foo.bar.txt"
        file_1.write_text("file 1")
        file_2 = tmp_path / "file_2"
        file_2.mkdir(parents=True)
        file_2 = file_2 / "foo.bar.txt"
        file_2.write_text("file 2")
        io = gr.Interface(
            fn=lambda x: x.name,
            inputs=gr.File(),
            outputs=[gr.File()],
            examples=[[str(file_1)], [str(file_2)]],
            cache_examples=True,
        )
        with connect(io):
            prediction = io.examples_handler.load_from_cache(0)
            assert Path(prediction[0].path).read_text() == "file 1"
            assert prediction[0].orig_name == "foo.bar.txt"
            assert prediction[0].path.endswith("foo.bar.txt")
            prediction = io.examples_handler.load_from_cache(1)
            assert Path(prediction[0].path).read_text() == "file 2"
            assert prediction[0].orig_name == "foo.bar.txt"
            assert prediction[0].path.endswith("foo.bar.txt")
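

# The progress tests poll job.status() from a client while the function runs
# and record each distinct progress update the server reports.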
class TestProgressBar:
    @pytest.mark.asyncio
    async def test_progress_bar(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)
        assert demo.local_url

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].desc if status.progress_data else None,
            )
            if update != (None, None) and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
        ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)
        assert demo.local_url

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].desc if status.progress_data else None,
            )
            if update != (None, None) and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            (None, "start"),
            (0, None),
            (1, None),
            (2, None),
            (3, None),
            (4, None),
            (0, "alphabet"),
            (1, "alphabet"),
            (2, "alphabet"),
        ]

    @pytest.mark.asyncio
    @pytest.mark.flaky(reruns=5)
    async def test_progress_bar_track_tqdm_without_iterable(self):
        def greet(s, _=gr.Progress(track_tqdm=True)):
            with tqdm(total=len(s)) as progress_bar:
                for _c in s:
                    progress_bar.update()
                    time.sleep(0.1)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)
        assert demo.local_url

        client = grc.Client(demo.local_url)
        job = client.submit("Gradio")

        status_updates = []
        while not job.done():
            status = job.status()
            update = (
                status.progress_data[0].index if status.progress_data else None,
                status.progress_data[0].unit if status.progress_data else None,
            )
            if update != (None, None) and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates[-1] == (6, "steps")

    @pytest.mark.asyncio
    async def test_info_and_warning_alerts(self):
        def greet(s):
            for _c in s:
                gr.Info(f"Letter {_c}")
                time.sleep(0.15)
            if len(s) < 5:
                gr.Warning("Too short!")
                time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)
        assert demo.local_url

        client = grc.Client(demo.local_url)
        job = client.submit("Jon")

        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)

        assert status_updates == [
            ("Letter J", "info"),
            ("Letter o", "info"),
            ("Letter n", "info"),
            ("Too short!", "warning"),
        ]
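

# Two sessions run concurrently against the same app; each client should see
# only the gr.Info message produced by its own request.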
@pytest.mark.asyncio
@pytest.mark.parametrize("async_handler", [True, False])
async def test_info_isolation(async_handler: bool):
    async def greet_async(name):
        await asyncio.sleep(2)
        gr.Info(f"Hello {name}")
        await asyncio.sleep(1)
        return name

    def greet_sync(name):
        time.sleep(2)
        gr.Info(f"Hello {name}")
        time.sleep(1)
        return name

    demo = gr.Interface(
        greet_async if async_handler else greet_sync,
        "text",
        "text",
        concurrency_limit=2,
    )
    demo.launch(prevent_thread_lock=True)

    async def session_interaction(name, delay=0):
        assert demo.local_url
        client = grc.Client(demo.local_url)
        job = client.submit(name)

        status_updates = []
        while not job.done():
            status = job.status()
            update = status.log
            if update is not None and (
                len(status_updates) == 0 or status_updates[-1] != update
            ):
                status_updates.append(update)
            time.sleep(0.05)
        return status_updates[-1][0] if status_updates else None

    alice_logs, bob_logs = await asyncio.gather(
        session_interaction("Alice"),
        session_interaction("Bob", delay=1),
    )

    assert alice_logs == "Hello Alice"
    assert bob_logs == "Hello Bob"
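

# Event data arriving from the client must not be deserialized into FileData;
# special_args is expected to reject such payloads with gr.Error.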
def test_check_event_data_in_cache():
    def get_select_index(evt: gr.SelectData):
        return evt.index

    with pytest.raises(gr.Error):
        helpers.special_args(
            get_select_index,
            inputs=[],
            event_data=helpers.EventData(
                None,
                {
                    "index": {"path": "foo", "meta": {"_type": "gradio.FileData"}},
                    "value": "whatever",
                },
            ),
        )
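

# With cache_examples=False, selecting an example invokes the example-load
# function directly; the fn_index and trigger_id below are hard-coded to that
# event for this particular app layout.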
def test_examples_no_cache_optional_inputs():
    def foo(a, b, c, d):
        return {"a": a, "b": b, "c": c, "d": d}

    io = gr.Interface(
        foo,
        ["text", "text", "text", "text"],
        "json",
        cache_examples=False,
        examples=[["a", "b", None, "d"], ["a", "b", None, "de"]],
    )

    try:
        app, _, _ = io.launch(prevent_thread_lock=True)

        client = TestClient(app)
        with client as c:
            for i in range(2):
                response = c.post(
                    f"{API_PREFIX}/run/predict/",
                    json={
                        "data": [i],
                        "fn_index": 6,
                        "trigger_id": 19,
                        "session_hash": "test",
                    },
                )
                assert response.status_code == 200
    finally:
        io.close()