# gradio/test/test_helpers.py
import json
import os
import shutil
import subprocess
import tempfile
import time
from pathlib import Path
from unittest.mock import patch
import pytest
import websockets
from gradio_client import media_data
from starlette.testclient import TestClient
from tqdm import tqdm
import gradio as gr
# Disable gradio telemetry so the test suite makes no analytics requests.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
class TestExamples:
    """Tests for gr.Examples: input normalization, directory loading, and cached predictions."""

    def test_handle_single_input(self):
        """A flat example list is wrapped into one-element rows; image paths become base64 data."""
        examples = gr.Examples(["hello", "hi"], gr.Textbox())
        assert examples.processed_examples == [["hello"], ["hi"]]
        examples = gr.Examples([["hello"]], gr.Textbox())
        assert examples.processed_examples == [["hello"]]
        examples = gr.Examples(["test/test_files/bus.png"], gr.Image())
        assert examples.processed_examples == [[media_data.BASE64_IMAGE]]

    def test_handle_multiple_inputs(self):
        """Rows mixing text and file values are processed per component."""
        examples = gr.Examples(
            [["hello", "test/test_files/bus.png"]], [gr.Textbox(), gr.Image()]
        )
        assert examples.processed_examples == [["hello", media_data.BASE64_IMAGE]]

    def test_handle_directory(self):
        """A directory of images yields one single-element row per file."""
        examples = gr.Examples("test/test_files/images", gr.Image())
        assert examples.processed_examples == [
            [media_data.BASE64_IMAGE],
            [media_data.BASE64_IMAGE],
        ]

    def test_handle_directory_with_log_file(self):
        """A log file in the directory supplies the extra text column; sample paths are absolute."""
        examples = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )
        assert examples.processed_examples == [
            [media_data.BASE64_IMAGE, "hello"],
            [media_data.BASE64_IMAGE, "hi"],
        ]
        for sample in examples.dataset.samples:
            assert os.path.isabs(sample[0])

    def test_examples_per_page(self):
        """examples_per_page is forwarded into the dataset config."""
        examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
        assert examples.dataset.get_config()["samples_per_page"] == 2

    @pytest.mark.asyncio
    async def test_no_preprocessing(self):
        """With preprocess=False the cached value is the raw base64 input, unprocessed."""
        with gr.Blocks():
            image = gr.Image()
            textbox = gr.Textbox()
            examples = gr.Examples(
                examples=["test/test_files/bus.png"],
                inputs=image,
                outputs=textbox,
                fn=lambda x: x,
                cache_examples=True,
                preprocess=False,
            )

        prediction = await examples.load_from_cache(0)
        assert prediction == [media_data.BASE64_IMAGE]

    @pytest.mark.asyncio
    async def test_no_postprocessing(self):
        """With postprocess=False the cached gallery entry keeps the raw base64 payload."""

        def im(x):
            # Returns a one-image "gallery" as raw base64 data.
            return [media_data.BASE64_IMAGE]

        with gr.Blocks():
            text = gr.Textbox()
            gall = gr.Gallery()
            examples = gr.Examples(
                examples=["hi"],
                inputs=text,
                outputs=gall,
                fn=im,
                cache_examples=True,
                postprocess=False,
            )

        prediction = await examples.load_from_cache(0)
        assert prediction[0][0][0]["data"] == media_data.BASE64_IMAGE
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
class TestExamplesDataset:
    """Dataset headers are derived from the labels of the input components."""

    def test_no_headers(self):
        """Components without labels produce an empty header list."""
        ex = gr.Examples("test/test_files/images_log", [gr.Image(), gr.Text()])
        assert ex.dataset.headers == []

    def test_all_headers(self):
        """Every labeled component contributes its label as a header."""
        ex = gr.Examples(
            "test/test_files/images_log",
            [gr.Image(label="im"), gr.Text(label="your text")],
        )
        assert ex.dataset.headers == ["im", "your text"]

    def test_some_headers(self):
        """An unlabeled component falls back to an empty-string header."""
        ex = gr.Examples(
            "test/test_files/images_log", [gr.Image(label="im"), gr.Text()]
        )
        assert ex.dataset.headers == ["im", ""]
@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
class TestProcessExamples:
    """Tests that cached example predictions are computed once and reloaded correctly."""

    @pytest.mark.asyncio
    async def test_caching(self):
        """A cached text prediction is reloaded by example index."""
        io = gr.Interface(
            lambda x: f"Hello {x}",
            "text",
            "text",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(1)
        assert prediction[0] == "Hello Dunya"

    @pytest.mark.asyncio
    async def test_caching_image(self):
        """Cached image outputs are stored as base64 data URIs."""
        io = gr.Interface(
            lambda x: x,
            "image",
            "image",
            examples=[["test/test_files/bus.png"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        assert prediction[0].startswith("data:image/png;base64,iVBORw0KGgoAAA")

    @pytest.mark.asyncio
    async def test_caching_audio(self):
        """Cached audio outputs keep their base64-encoded data payload."""
        io = gr.Interface(
            lambda x: x,
            "audio",
            "audio",
            examples=[["test/test_files/audio_sample.wav"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        assert prediction[0]["data"].startswith("data:audio/wav;base64,UklGRgA/")

    @pytest.mark.asyncio
    async def test_caching_with_update(self):
        """gr.update() return values are cached as update dictionaries."""
        io = gr.Interface(
            lambda x: gr.update(visible=False),
            "text",
            "image",
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "visible": False,
            "__type__": "update",
        }

    @pytest.mark.asyncio
    async def test_caching_with_mix_update(self):
        """A mix of gr.update() and plain file outputs caches each appropriately."""
        io = gr.Interface(
            lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
            "text",
            ["text", "image"],
            examples=[["World"], ["Dunya"], ["Monde"]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(1)
        assert prediction[0] == {
            "lines": 4,
            "value": "hello",
            "__type__": "update",
        }

    @pytest.mark.asyncio
    async def test_caching_with_dict(self):
        """Dict-style returns keyed by component are cached per component."""
        text = gr.Textbox()
        out = gr.Label()
        io = gr.Interface(
            lambda _: {text: gr.update(lines=4, interactive=False), out: "lion"},
            "textbox",
            [text, out],
            examples=["abc"],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        # Caching must not leave a leftover "fake_event" dependency in the config.
        assert not any(d["trigger"] == "fake_event" for d in io.config["dependencies"])
        assert prediction == [
            {"lines": 4, "__type__": "update", "mode": "static"},
            {"label": "lion"},
        ]

    def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
        """Caching with missing example values warns and raises, unless a default covers the gap."""

        def foo(a, b):
            return a + b

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo"], ["bar"]],
                    cache_examples=True,
                )

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    foo,
                    inputs=["text", "text"],
                    outputs=["text"],
                    examples=[["foo", "bar"], ["bar", None]],
                    cache_examples=True,
                )

        def foo_no_exception(a, b=2):
            return a * b

        # A default value for the missing input means caching succeeds silently.
        gr.Interface(
            foo_no_exception,
            inputs=["text", "number"],
            outputs=["text"],
            examples=[["foo"], ["bar"]],
            cache_examples=True,
        )

        def many_missing(a, b, c):
            return a * b

        with pytest.warns(
            UserWarning,
            match="^Examples are being cached but not all input components have",
        ):
            with pytest.raises(Exception):
                gr.Interface(
                    many_missing,
                    inputs=["text", "number", "number"],
                    outputs=["text"],
                    examples=[["foo", None, None], ["bar", 2, 3]],
                    cache_examples=True,
                )

    @pytest.mark.asyncio
    async def test_caching_with_batch(self):
        """Batched functions cache the un-batched, per-example result."""

        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return [trimmed_words]

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox"],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        assert prediction == ["hel"]

    @pytest.mark.asyncio
    async def test_caching_with_batch_multiple_outputs(self):
        """Batched functions with several outputs cache one value per output."""

        def trim_words(words, lens):
            trimmed_words = [word[:length] for word, length in zip(words, lens)]
            return trimmed_words, lens

        io = gr.Interface(
            trim_words,
            ["textbox", gr.Number(precision=0)],
            ["textbox", gr.Number(precision=0)],
            batch=True,
            max_batch_size=16,
            examples=[["hello", 3], ["hi", 4]],
            cache_examples=True,
        )
        prediction = await io.examples_handler.load_from_cache(0)
        assert prediction == ["hel", "3"]

    @pytest.mark.asyncio
    async def test_caching_with_non_io_component(self):
        """Non-IO outputs (e.g. a Column) cache their update dictionaries."""

        def predict(name):
            return name, gr.update(visible=True)

        with gr.Blocks():
            t1 = gr.Textbox()
            with gr.Column(visible=False) as c:
                t2 = gr.Textbox()
            examples = gr.Examples(
                [["John"], ["Mary"]],
                fn=predict,
                inputs=[t1],
                outputs=[t2, c],
                cache_examples=True,
            )

        prediction = await examples.load_from_cache(0)
        assert prediction == ["John", {"visible": True, "__type__": "update"}]

    def test_end_to_end(self):
        """Loading an uncached example over HTTP returns just the input values."""

        def concatenate(str1, str2):
            return str1 + str2

        io = gr.Interface(
            concatenate,
            inputs=[gr.Textbox(), gr.Textbox()],
            outputs=gr.Textbox(),
            examples=[["Hello,", None], ["Michael", None]],
        )
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == ["Hello,"]
        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == ["Michael"]

    def test_end_to_end_cache_examples(self):
        """Loading a cached example over HTTP returns inputs plus the cached output."""

        def concatenate(str1, str2):
            return f"{str1} {str2}"

        io = gr.Interface(
            concatenate,
            inputs=[gr.Textbox(), gr.Textbox()],
            outputs=gr.Textbox(),
            examples=[["Hello,", "World"], ["Michael", "Jordan"]],
            cache_examples=True,
        )
        app, _, _ = io.launch(prevent_thread_lock=True)
        client = TestClient(app)
        response = client.post("/api/load_example/", json={"data": [0]})
        assert response.json()["data"] == ["Hello,", "World", "Hello, World"]
        response = client.post("/api/load_example/", json={"data": [1]})
        assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]
@pytest.mark.asyncio
async def test_multiple_file_flagging(tmp_path):
    """Caching an example with several file inputs keeps every file in the cached output."""
    with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
        interface = gr.Interface(
            fn=lambda *x: list(x),
            inputs=[
                gr.Image(source="upload", type="filepath", label="frame 1"),
                gr.Image(source="upload", type="filepath", label="frame 2"),
            ],
            outputs=[gr.Files()],
            examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
            cache_examples=True,
        )
        cached = await interface.examples_handler.load_from_cache(0)
        # Both input files survive the round-trip, each serialized as a dict.
        assert len(cached[0]) == 2
        assert all(isinstance(item, dict) for item in cached[0])
@pytest.mark.asyncio
async def test_examples_keep_all_suffixes(tmp_path):
    """Cached example files keep their full multi-part suffix (e.g. ``.bar.txt``)."""
    with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
        # Two distinct files sharing the same multi-suffix name, in different folders.
        first_file = tmp_path / "foo.bar.txt"
        first_file.write_text("file 1")
        nested_dir = tmp_path / "file_2"
        nested_dir.mkdir(parents=True)
        second_file = nested_dir / "foo.bar.txt"
        second_file.write_text("file 2")
        interface = gr.Interface(
            fn=lambda x: x.name,
            inputs=gr.File(),
            outputs=[gr.File()],
            examples=[[str(first_file)], [str(second_file)]],
            cache_examples=True,
        )
        for index, contents in ((0, "file 1"), (1, "file 2")):
            cached = await interface.examples_handler.load_from_cache(index)
            assert Path(cached[0]["name"]).read_text() == contents
            assert cached[0]["orig_name"] == "foo.bar.txt"
def test_make_waveform_with_spaces_in_filename():
    """make_waveform handles audio paths containing spaces and produces an .mp4."""
    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, "test audio.wav")
        shutil.copy("test/test_files/audio_sample.wav", audio_path)
        video_path = gr.make_waveform(audio_path)
        assert video_path.endswith(".mp4")
def test_make_waveform_raises_if_ffmpeg_fails(tmp_path, monkeypatch):
    """
    Test that make_waveform raises an exception if ffmpeg fails,
    instead of returning a path to a non-existent or empty file.
    """
    audio = tmp_path / "test audio.wav"
    shutil.copy("test/test_files/audio_sample.wav", audio)

    def fake_ffmpeg_call(*_args, **_kwargs):
        # Simulate ffmpeg exiting with a non-zero status.
        raise subprocess.CalledProcessError(1, "ffmpeg")

    monkeypatch.setattr(subprocess, "call", fake_ffmpeg_call)
    with pytest.raises(Exception):
        gr.make_waveform(str(audio))
class TestProgressBar:
    """Tests that gr.Progress updates are streamed to clients over the queue websocket."""

    @pytest.mark.asyncio
    async def test_progress_bar(self):
        """Explicit prog() calls and prog.tqdm iterations emit progress messages;
        a plain tqdm loop is NOT reported when track_tqdm is off."""
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress()):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                # Plain tqdm: untracked because track_tqdm defaults to False.
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        async with websockets.connect(
            f"{demo.local_url.replace('http', 'ws')}queue/join"
        ) as ws:
            completed = False
            progress_updates = []
            # Drive the queue protocol by hand: answer send_data/send_hash
            # prompts and collect every "progress" message until completion.
            while not completed:
                msg = json.loads(await ws.recv())
                if msg["msg"] == "send_data":
                    await ws.send(json.dumps({"data": [0], "fn_index": 0}))
                if msg["msg"] == "send_hash":
                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
                if msg["msg"] == "progress":
                    progress_updates.append(msg["progress_data"])
                if msg["msg"] == "process_completed":
                    completed = True
                    break
            # One message for the explicit prog(0, ...) call, then one per
            # prog.tqdm step (including the final index == length message).
            assert progress_updates == [
                [
                    {
                        "index": None,
                        "length": None,
                        "unit": "steps",
                        "progress": 0.0,
                        "desc": "start",
                    }
                ],
                [{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}],
            ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm(self):
        """With track_tqdm=True the plain tqdm loop is also reported, using its desc."""
        with gr.Blocks() as demo:
            name = gr.Textbox()
            greeting = gr.Textbox()
            button = gr.Button(value="Greet")

            def greet(s, prog=gr.Progress(track_tqdm=True)):
                prog(0, desc="start")
                time.sleep(0.15)
                for _ in prog.tqdm(range(4), unit="iter"):
                    time.sleep(0.15)
                time.sleep(0.15)
                # This plain tqdm loop IS tracked because track_tqdm=True.
                for _ in tqdm(["a", "b", "c"], desc="alphabet"):
                    time.sleep(0.15)
                return f"Hello, {s}!"

            button.click(greet, name, greeting)
        demo.queue(max_size=1).launch(prevent_thread_lock=True)

        async with websockets.connect(
            f"{demo.local_url.replace('http', 'ws')}queue/join"
        ) as ws:
            completed = False
            progress_updates = []
            while not completed:
                msg = json.loads(await ws.recv())
                if msg["msg"] == "send_data":
                    await ws.send(json.dumps({"data": [0], "fn_index": 0}))
                if msg["msg"] == "send_hash":
                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
                if (
                    msg["msg"] == "progress" and msg["progress_data"]
                ):  # Ignore empty lists which sometimes appear on Windows
                    progress_updates.append(msg["progress_data"])
                if msg["msg"] == "process_completed":
                    completed = True
                    break
            # Explicit call + 5 prog.tqdm messages, then one per tracked tqdm step.
            assert progress_updates == [
                [
                    {
                        "index": None,
                        "length": None,
                        "unit": "steps",
                        "progress": 0.0,
                        "desc": "start",
                    }
                ],
                [{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}],
                [
                    {
                        "index": 0,
                        "length": 3,
                        "unit": "steps",
                        "progress": None,
                        "desc": "alphabet",
                    }
                ],
                [
                    {
                        "index": 1,
                        "length": 3,
                        "unit": "steps",
                        "progress": None,
                        "desc": "alphabet",
                    }
                ],
                [
                    {
                        "index": 2,
                        "length": 3,
                        "unit": "steps",
                        "progress": None,
                        "desc": "alphabet",
                    }
                ],
            ]

    @pytest.mark.asyncio
    async def test_progress_bar_track_tqdm_without_iterable(self):
        """Manual tqdm(total=...) with .update() calls is tracked too (no iterable needed)."""

        def greet(s, _=gr.Progress(track_tqdm=True)):
            with tqdm(total=len(s)) as progress_bar:
                for _c in s:
                    progress_bar.update()
                    time.sleep(0.15)
            return f"Hello, {s}!"

        demo = gr.Interface(greet, "text", "text")
        demo.queue().launch(prevent_thread_lock=True)

        async with websockets.connect(
            f"{demo.local_url.replace('http', 'ws')}queue/join"
        ) as ws:
            completed = False
            progress_updates = []
            while not completed:
                msg = json.loads(await ws.recv())
                if msg["msg"] == "send_data":
                    await ws.send(json.dumps({"data": ["abc"], "fn_index": 0}))
                if msg["msg"] == "send_hash":
                    await ws.send(json.dumps({"fn_index": 0, "session_hash": "shdce"}))
                if (
                    msg["msg"] == "progress" and msg["progress_data"]
                ):  # Ignore empty lists which sometimes appear on Windows
                    progress_updates.append(msg["progress_data"])
                if msg["msg"] == "process_completed":
                    completed = True
                    break
            # One message per .update() call: indices 1..3 for the 3-char input.
            assert progress_updates == [
                [
                    {
                        "index": 1,
                        "length": 3,
                        "unit": "steps",
                        "progress": None,
                        "desc": None,
                    }
                ],
                [
                    {
                        "index": 2,
                        "length": 3,
                        "unit": "steps",
                        "progress": None,
                        "desc": None,
                    }
                ],
                [
                    {
                        "index": 3,
                        "length": 3,
                        "unit": "steps",
                        "progress": None,
                        "desc": None,
                    }
                ],
            ]