gradio/test/test_examples.py
Abubakar Abid 0e168c4dff
Batching (#2218)
* batch

* formatting

* added parameter

* batch

* added docstrings

* correct arguments

* docstring

* adapt process_api for batch

* backend

* __call__

* more regular tests

* formatting

* cleaning up blocks.py

* __call__ works

* api route works

* first attempt at queue

* fixing tests

* fix some tests

* formatting

* removed print

* merge

* queue works!

* removed batch timeout

* removed batch timeout

* updated documentation

* fixing tests

* fixing tests

* fixing queue tests

* fixing queue tests

* formatting

* fix blocks config

* fix tests

* update documentation

* updated tests

* blocks

* blocks

* blocks

* tests

* test fixes

* more tests

* faster

* formatting

* test fixes

* dataclasses

* fix

* revert to fix test

* fix

* fix test

* formatting

* fix tests

* refactoring examples

* formatting

* changelog

* fix examples

* formatting

* fix tests

* formatting

* catch error

* formatting

* fix tests

* fix cancel with batch

* final tests and docs

* test routes

* formatting
2022-10-24 16:32:37 -07:00

279 lines
8.9 KiB
Python

import os
import tempfile
from unittest.mock import patch
import pytest
import gradio as gr
# Set the GRADIO_ANALYTICS_ENABLED environment variable to "False" for this
# test process (runs once, when the module is imported).
# NOTE(review): presumably this switches off Gradio's analytics while the
# tests run -- confirm against gradio, which is imported just above this line.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())
class TestExamples:
    """Tests for ``gr.Examples``: example processing and cached predictions.

    ``gradio.examples.CACHED_FOLDER`` is patched to a fresh temporary
    directory for every test method in this class.
    """

    def test_handle_single_input(self):
        """Flat and nested example lists are both accepted for one input."""
        flat = gr.Examples(["hello", "hi"], gr.Textbox())
        assert flat.processed_examples == [["hello"], ["hi"]]

        nested = gr.Examples([["hello"]], gr.Textbox())
        assert nested.processed_examples == [["hello"]]

        from_image_path = gr.Examples(["test/test_files/bus.png"], gr.Image())
        expected_image_rows = [[gr.media_data.BASE64_IMAGE]]
        assert from_image_path.processed_examples == expected_image_rows

    def test_handle_multiple_inputs(self):
        """One example row feeding a textbox and an image component."""
        multi = gr.Examples(
            [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
        )
        expected_rows = [["hello", gr.media_data.BASE64_IMAGE]]
        assert multi.processed_examples == expected_rows

    def test_handle_directory(self):
        """A directory path given as ``examples`` yields one row per image."""
        from_directory = gr.Examples("test/test_files/images", gr.Image())
        single_row = [gr.media_data.BASE64_IMAGE]
        assert from_directory.processed_examples == [single_row, list(single_row)]

    def test_handle_directory_with_log_file(self):
        """A directory with a log file yields (image, text) rows."""
        components = [gr.Image(label="im"), gr.Text()]
        from_log = gr.Examples("test/test_files/images_log", components)
        expected_rows = [
            [gr.media_data.BASE64_IMAGE, caption] for caption in ("hello", "hi")
        ]
        assert from_log.processed_examples == expected_rows

    @pytest.mark.asyncio
    async def test_no_preprocessing(self):
        """With ``preprocess=False`` the cached output is the base64 image."""
        with gr.Blocks():
            image_input = gr.Image()
            text_output = gr.Textbox()

            cached = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image_input,
                outputs=text_output,
                fn=lambda x: x,
                cache_examples=True,
                preprocess=False,
            )

        # The cache is read back after the Blocks context has been closed.
        result = await cached.load_from_cache(0)
        assert result == [gr.media_data.BASE64_IMAGE]

    @pytest.mark.asyncio
    async def test_no_postprocessing(self):
        """With ``postprocess=False`` the gallery data is the base64 image."""

        def im(x):
            # Ignores its input and always returns a one-image list.
            return [gr.media_data.BASE64_IMAGE]

        with gr.Blocks():
            text_input = gr.Textbox()
            gallery_output = gr.Gallery()

            cached = gr.Examples(
                examples=["hi"],
                inputs=text_input,
                outputs=gallery_output,
                fn=im,
                cache_examples=True,
                postprocess=False,
            )

        # The cache is read back after the Blocks context has been closed.
        result = await cached.load_from_cache(0)
        first_gallery_item = result[0][0]
        assert first_gallery_item["data"] == gr.media_data.BASE64_IMAGE
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())
class TestExamplesDataset:
    """Tests for the ``headers`` of the dataset attached to ``gr.Examples``.

    ``gradio.examples.CACHED_FOLDER`` is patched to a fresh temporary
    directory for every test method in this class.
    """

    def test_no_headers(self):
        """No component has a label, so the header list is empty."""
        unlabeled = [gr.Image(), gr.Text()]
        built = gr.Examples("test/test_files/images_log", unlabeled)
        assert built.dataset.headers == []

    def test_all_headers(self):
        """Every component has a label, so each label becomes a header."""
        labeled = [gr.Image(label="im"), gr.Text(label="your text")]
        built = gr.Examples("test/test_files/images_log", labeled)
        assert built.dataset.headers == ["im", "your text"]

    def test_some_headers(self):
        """An unlabeled component gets an empty-string header."""
        partly_labeled = [gr.Image(label="im"), gr.Text()]
        built = gr.Examples("test/test_files/images_log", partly_labeled)
        assert built.dataset.headers == ["im", ""]
@patch("gradio.examples.CACHED_FOLDER", tempfile.mkdtemp())
class TestProcessExamples:
    """Tests for example caching through ``gr.Interface(cache_examples=True)``.

    ``gradio.examples.CACHED_FOLDER`` is patched to a fresh temporary
    directory for every test method in this class. Cached predictions are
    read back with ``io.examples_handler.load_from_cache(index)``.
    """

    @pytest.mark.asyncio
    async def test_caching(self):
        """The cached prediction for example 1 is the greeting for "Dunya"."""
        io = gr.Interface(
            lambda x: "Hello " + x,
            "text",
            "text",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(1)
        assert prediction[0] == "Hello Dunya"

    @pytest.mark.asyncio
    async def test_caching_image(self):
        """A cached image prediction is returned as a PNG data URI."""
        io = gr.Interface(
            lambda x: x,
            "image",
            "image",
            examples=[["test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        # Checking startswith("") would pass for any string. Pin the data-URI
        # header plus the base64 form of the PNG signature and the leading
        # zero bytes of the IHDR length, which every PNG file begins with.
        assert prediction[0].startswith("data:image/png;base64,iVBORw0KGgoAAA")

    @pytest.mark.asyncio
    async def test_caching_audio(self):
        """A cached audio prediction carries a WAV data URI under "data"."""
        io = gr.Interface(
            lambda x: x,
            "audio",
            "audio",
            examples=[["test/test_files/audio_sample.wav"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        assert prediction[0]["data"].startswith("data:audio/wav;base64,UklGRgA/")

    @pytest.mark.asyncio
    async def test_caching_with_update(self):
        """A function returning ``gr.update`` is cached as an update dict."""
        io = gr.Interface(
            lambda x: gr.update(visible=False),
            "text",
            "image",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(1)
        assert prediction[0] == {"visible": False, "__type__": "update"}

    @pytest.mark.asyncio
    async def test_caching_with_mix_update(self):
        """An update and a plain value can be mixed across two outputs."""
        io = gr.Interface(
            lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
            "text",
            ["text", "image"],
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(1)
        assert prediction[0] == {"lines": 4, "value": "hello", "__type__": "update"}

    @pytest.mark.asyncio
    async def test_caching_with_dict(self):
        """A function may return a dict keyed by output component."""
        text = gr.Textbox()
        out = gr.Label()

        io = gr.Interface(
            lambda _: {text: gr.update(lines=4), out: "lion"},
            "textbox",
            [text, out],
            examples=["abc"],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        assert prediction == [{"lines": 4, "__type__": "update"}, {"label": "lion"}]

    def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
        """Partial examples warn, and raise when the function needs the value.

        Each failing case expects both a ``UserWarning`` whose message starts
        with "Examples are being cached but not all input components have"
        and an exception from building the interface. The ``tmp_path``
        fixture is requested but not used in this body.
        """

        def foo(a, b):
            return a + b

        # Example rows shorter than the number of inputs.
        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo"], ["bar"]],
                    cache_examples=True,
                )

        # Full-length rows in which one value is None.
        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo", "bar"], ["bar", None]],
                    cache_examples=True,
                )

        def foo_no_exception(a, b=2):
            return a * b

        # The missing second input has a default in the function signature;
        # building the interface is expected to succeed here.
        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        # Several values are missing (None) in the first example row.
        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    many_missing,
                    inputs=["text", "number", "number"],
                    outputs=["text"],
                    examples=[["foo", None, None], ["bar", 2, 3]],
                    cache_examples=True,
                )

    @pytest.mark.asyncio
    async def test_caching_with_batch(self):
        """Caching works for a batched function with a single output."""

        def trim_words(words, lens):
            # Inputs are parallel lists; the return value is a list holding
            # one list of outputs.
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        assert prediction == ["hel"]

    @pytest.mark.asyncio
    async def test_caching_with_batch_multiple_outputs(self):
        """Caching works for a batched function with two outputs."""

        def trim_words(words, lens):
            # Returns one list per output component.
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        # The cached number is read back as the string "3".
        assert prediction == ["hel", "3"]