mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
c60ad4d34a
* Gallery as input * add changeset * Add keep_captions parameter * revert notebooks * fix docstring typo * use tuples * remove keep_captions * Update js/app/test/gallery_component_events.spec.ts Co-authored-by: pngwn <hello@pngwn.io> * more precise type hint --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: pngwn <hello@pngwn.io>
2996 lines
102 KiB
Python
2996 lines
102 KiB
Python
"""
|
|
Tests for all of the components defined in components.py. Tests are divided into two types:
|
|
1. test_component_functions() are unit tests that check essential functions of a component, the functions that are checked are documented in the docstring.
|
|
2. test_in_interface() are functional tests that check a component's functionalities inside an Interface. Please do not use Interface.launch() in this file, as it slows down the tests.
|
|
"""
|
|
|
|
import filecmp
|
|
import inspect
|
|
import json
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
from copy import deepcopy
|
|
from difflib import SequenceMatcher
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import PIL
|
|
import pytest
|
|
import vega_datasets
|
|
from gradio_client import media_data
|
|
from gradio_client import utils as client_utils
|
|
from scipy.io import wavfile
|
|
|
|
# Both branches of the former try/except imported the identical name from
# `typing`, so the ImportError fallback could never do anything different.
from typing import cast
|
|
|
|
import gradio as gr
|
|
from gradio import processing_utils, utils
|
|
from gradio.components.dataframe import DataframeData
|
|
from gradio.components.file_explorer import FileExplorerData
|
|
from gradio.components.image_editor import EditorData
|
|
from gradio.components.video import VideoData
|
|
from gradio.data_classes import FileData, ListFiles
|
|
|
|
# Disable Gradio analytics for the test process.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
|
|
|
|
|
|
class TestGettingComponents:
    def test_component_function(self):
        """Looking up "textarea" by name yields a TextArea template instance."""
        textarea = gr.components.component("textarea", render=False)
        assert isinstance(textarea, gr.templates.TextArea)

    @pytest.mark.parametrize(
        "component, render, unrender, should_be_rendered",
        [
            (gr.Textbox(render=True), False, True, False),
            (gr.Textbox(render=False), False, False, False),
            (gr.Textbox(render=False), True, False, True),
            ("textbox", False, False, False),
            ("textbox", True, False, True),
        ],
    )
    def test_get_component_instance_rendering(
        self, component, render, unrender, should_be_rendered
    ):
        """`get_component_instance` applies the `render` / `unrender` flags.

        Covers both existing component instances and string component names,
        created inside a Blocks context.
        """
        with gr.Blocks():
            instance = gr.components.get_component_instance(
                component, render=render, unrender=unrender
            )
        assert instance.is_rendered == should_be_rendered
|
|
|
|
|
|
class TestTextbox:
    def test_component_functions(self):
        """
        Preprocess, postprocess, get_config
        """
        text_input = gr.Textbox()
        assert text_input.preprocess("Hello World!") == "Hello World!"
        assert text_input.postprocess("Hello World!") == "Hello World!"
        assert text_input.postprocess(None) is None
        assert text_input.postprocess("Ali") == "Ali"
        # Non-string outputs are stringified
        assert text_input.postprocess(2) == "2"
        assert text_input.postprocess(2.14) == "2.14"
        assert text_input.get_config() == {
            "lines": 1,
            "max_lines": 20,
            "placeholder": None,
            "value": "",
            "name": "textbox",
            "show_copy_button": False,
            "show_label": True,
            "type": "text",
            "label": None,
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "rtl": False,
            "text_align": None,
            "autofocus": False,
            "_selectable": False,
            "info": None,
            "autoscroll": True,
        }

    def test_in_interface_as_input(self):
        """
        Interface, process

        Calling an Interface is synchronous, so this test does not need to be a
        coroutine or carry the asyncio marker.
        """
        iface = gr.Interface(lambda x: x[::-1], "textbox", "textbox")
        assert iface("Hello") == "olleH"

    def test_in_interface_as_output(self):
        """
        Interface, process
        """
        iface = gr.Interface(lambda x: x[-1], "textbox", gr.Textbox())
        assert iface("Hello") == "o"
        iface = gr.Interface(lambda x: x / 2, "number", gr.Textbox())
        assert iface(10) == "5.0"

    def test_static(self):
        """
        postprocess
        """
        component = gr.Textbox("abc")
        assert component.get_config().get("value") == "abc"

    def test_override_template(self):
        """
        override template
        """
        component = gr.TextArea(value="abc")
        assert component.get_config().get("value") == "abc"
        # TextArea's template default is 7 lines...
        assert component.get_config().get("lines") == 7
        component = gr.TextArea(value="abc", lines=4)
        assert component.get_config().get("value") == "abc"
        # ...but an explicit `lines` argument overrides it
        assert component.get_config().get("lines") == 4

    def test_faulty_type(self):
        with pytest.raises(
            ValueError, match='`type` must be one of "text", "password", or "email".'
        ):
            gr.Textbox(type="boo")

    def test_max_lines(self):
        # Single-line input types are capped at one line
        assert gr.Textbox(type="password").get_config().get("max_lines") == 1
        assert gr.Textbox(type="email").get_config().get("max_lines") == 1
        assert gr.Textbox(type="text").get_config().get("max_lines") == 20
        assert gr.Textbox().get_config().get("max_lines") == 20
|
|
|
|
|
|
class TestNumber:
    def test_component_functions(self):
        """
        A float Number passes values through pre/postprocess unchanged and
        reports its elem_id / elem_classes in get_config.
        """
        number = gr.Number(elem_id="num", elem_classes="first")
        assert number.preprocess(3) == 3.0
        assert number.preprocess(None) is None
        assert number.postprocess(3) == 3
        assert number.postprocess(3) == 3.0
        assert number.postprocess(2.14) == 2.14
        assert number.postprocess(None) is None
        expected_config = {
            "value": None,
            "name": "number",
            "show_label": True,
            "step": 1,
            "label": None,
            "minimum": None,
            "maximum": None,
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": "num",
            "elem_classes": ["first"],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "info": None,
            "precision": None,
            "_selectable": False,
        }
        assert number.get_config() == expected_config

    def test_component_functions_integer(self):
        """
        With precision=0, values are rounded to integers on postprocess.
        """
        integer_number = gr.Number(precision=0, value=42)
        assert integer_number.preprocess(3) == 3
        assert integer_number.preprocess(None) is None
        assert integer_number.postprocess(3) == 3
        assert integer_number.postprocess(2.85) == 3
        assert integer_number.postprocess(None) is None
        expected_config = {
            "value": 42,
            "name": "number",
            "show_label": True,
            "step": 1,
            "label": None,
            "minimum": None,
            "maximum": None,
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "info": None,
            "precision": 0,
            "_selectable": False,
        }
        assert integer_number.get_config() == expected_config

    def test_component_functions_precision(self):
        """
        With precision=2, both directions round to two decimal places.
        """
        two_places = gr.Number(precision=2, value=42.3428)
        assert two_places.preprocess(3.231241) == 3.23
        assert two_places.preprocess(None) is None
        assert two_places.postprocess(-42.1241) == -42.12
        assert two_places.postprocess(5.6784) == 5.68
        assert two_places.postprocess(2.1421) == 2.14
        assert two_places.postprocess(None) is None

    def test_precision_none_with_integer(self):
        """
        precision=None keeps an int input as an int.
        """
        unrounded = gr.Number(precision=None)
        preprocessed = unrounded.preprocess(5)
        assert preprocessed == 5
        assert isinstance(unrounded.preprocess(5), int)
        postprocessed = unrounded.postprocess(5)
        assert postprocessed == 5
        assert isinstance(unrounded.postprocess(5), int)

    def test_precision_none_with_float(self):
        """
        precision=None keeps a float input as a float.
        """
        unrounded = gr.Number(value=5.5, precision=None)
        preprocessed = unrounded.preprocess(5.5)
        assert preprocessed == 5.5
        assert isinstance(unrounded.preprocess(5.5), float)
        postprocessed = unrounded.postprocess(5.5)
        assert postprocessed == 5.5
        assert isinstance(unrounded.postprocess(5.5), float)

    def test_in_interface_as_input(self):
        """
        A Number input feeds a Textbox output through an Interface.
        """
        demo = gr.Interface(lambda value: value**2, "number", "textbox")
        assert demo(2) == "4"

    def test_precision_0_in_interface(self):
        """
        An integer-precision Number input works inside an Interface.
        """
        demo = gr.Interface(lambda value: value**2, gr.Number(precision=0), "textbox")
        assert demo(2) == "4"

    def test_in_interface_as_output(self):
        """
        A Number output receives the function's numeric result.
        """
        demo = gr.Interface(lambda text: int(text) ** 2, "textbox", "number")
        assert demo(2) == 4.0

    def test_static(self):
        """
        The initial value shows up in get_config.
        """
        assert gr.Number().get_config().get("value") is None
        assert gr.Number(3).get_config().get("value") == 3.0
|
|
|
|
|
|
class TestSlider:
    def test_component_functions(self):
        """
        Preprocess, postprocess, get_config
        """
        slider_input = gr.Slider()
        assert slider_input.preprocess(3.0) == 3.0
        assert slider_input.postprocess(3) == 3
        # A missing value falls back to the slider's minimum (0 by default)
        assert slider_input.postprocess(None) == 0

        slider_input = gr.Slider(10, 20, value=15, step=1, label="Slide Your Input")
        assert slider_input.get_config() == {
            "minimum": 10,
            "maximum": 20,
            "step": 1,
            "value": 15,
            "name": "slider",
            "show_label": True,
            "label": "Slide Your Input",
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "info": None,
            "_selectable": False,
        }

    def test_in_interface(self):
        """
        Interface, process
        """
        iface = gr.Interface(lambda x: x**2, "slider", "textbox")
        assert iface(2) == "4"

    def test_static(self):
        """
        postprocess
        """
        component = gr.Slider(0, 100, 5)
        assert component.get_config().get("value") == 5
        component = gr.Slider(0, 100, None)
        assert component.get_config().get("value") == 0

    @patch("gradio.Slider.get_random_value", return_value=7)
    def test_slider_get_random_value_on_load(self, mock_get_random_value):
        slider = gr.Slider(minimum=-5, maximum=10, randomize=True)
        assert slider.value == 7
        # randomize=True registers a load event whose callback re-rolls the value
        assert slider.load_event_to_attach[0]() == 7
        assert slider.load_event_to_attach[1] is None

    @patch("random.randint", return_value=3)
    def test_slider_rounds_when_using_default_randomizer(self, mock_randint):
        slider = gr.Slider(minimum=0, maximum=1, randomize=True, step=0.1)
        # If get_random_value didn't round, this test would fail
        # because 0.30000000000000004 != 0.3
        assert slider.get_random_value() == 0.3
        mock_randint.assert_called()
|
|
|
|
|
|
class TestCheckbox:
    def test_component_functions(self):
        """
        Booleans pass through pre/postprocess; get_config reflects the
        constructor arguments.
        """
        checkbox = gr.Checkbox()
        assert checkbox.preprocess(True)
        assert checkbox.postprocess(True)

        labelled = gr.Checkbox(value=True, label="Check Your Input")
        expected_config = {
            "value": True,
            "name": "checkbox",
            "show_label": True,
            "label": "Check Your Input",
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "_selectable": False,
            "info": None,
        }
        assert labelled.get_config() == expected_config

    def test_in_interface(self):
        """
        A checked Checkbox input reaches the wrapped function as truthy.
        """
        demo = gr.Interface(lambda checked: 1 if checked else 0, "checkbox", "number")
        assert demo(True) == 1
|
|
|
|
|
|
class TestCheckboxGroup:
    def test_component_functions(self):
        """
        Value/index preprocessing, choice normalization, config, and scalar
        value handling for CheckboxGroup.
        """
        by_value = gr.CheckboxGroup(["a", "b", "c"])
        assert by_value.preprocess(["a", "c"]) == ["a", "c"]
        assert by_value.postprocess(["a", "c"]) == ["a", "c"]

        by_index = gr.CheckboxGroup(["a", "b"], type="index")
        assert by_index.preprocess(["a"]) == [0]
        assert by_index.preprocess(["a", "b"]) == [0, 1]
        # A selection that is not among the choices maps to None
        assert by_index.preprocess(["a", "b", "c"]) == [0, 1, None]

        # Apps loaded through gr.load receive (label, value) pairs as lists
        # instead of tuples; those must still be normalized to tuples.
        loaded = gr.CheckboxGroup(["a", "b", ["c", "c full"]])  # type: ignore
        assert loaded.choices == [("a", "a"), ("b", "b"), ("c", "c full")]

        labelled = gr.CheckboxGroup(
            value=["a", "c"],
            choices=["a", "b", "c"],
            label="Check Your Inputs",
        )
        expected_config = {
            "choices": [("a", "a"), ("b", "b"), ("c", "c")],
            "value": ["a", "c"],
            "name": "checkboxgroup",
            "show_label": True,
            "label": "Check Your Inputs",
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "_selectable": False,
            "type": "value",
            "info": None,
        }
        assert labelled.get_config() == expected_config
        with pytest.raises(ValueError):
            gr.CheckboxGroup(["a"], type="unknown")

        # A scalar value is wrapped into a one-element list
        scalar = gr.CheckboxGroup(choices=["a", "b"], value="c")
        assert scalar.get_config()["value"] == ["c"]
        assert scalar.postprocess("a") == ["a"]
        assert scalar.process_example("a") == ["a"]

    def test_in_interface(self):
        """
        Selected choices arrive as a list inside an Interface.
        """
        choices_input = gr.CheckboxGroup(["a", "b", "c"])
        demo = gr.Interface(lambda picked: "|".join(picked), choices_input, "textbox")
        assert demo(["a", "c"]) == "a|c"
        assert demo([]) == ""
        # Constructing an index-typed group must not raise
        gr.CheckboxGroup(["a", "b", "c"], type="index")
|
|
|
|
|
|
class TestRadio:
    def test_component_functions(self):
        """
        Value/index preprocessing, choice normalization and config for Radio.
        """
        radio = gr.Radio(["a", "b", "c"])
        assert radio.preprocess("c") == "c"
        assert radio.postprocess("a") == "a"

        labelled = gr.Radio(
            choices=["a", "b", "c"], value="a", label="Pick Your One Input"
        )
        expected_config = {
            "choices": [("a", "a"), ("b", "b"), ("c", "c")],
            "value": "a",
            "name": "radio",
            "show_label": True,
            "label": "Pick Your One Input",
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "_selectable": False,
            "type": "value",
            "info": None,
        }
        assert labelled.get_config() == expected_config

        by_index = gr.Radio(choices=["a", "b"], type="index")
        assert by_index.preprocess("a") == 0
        assert by_index.preprocess("b") == 1
        # A value outside the choices maps to None
        assert by_index.preprocess("c") is None

        # Apps loaded through gr.load receive (label, value) pairs as lists
        # instead of tuples; those must still be normalized to tuples.
        loaded = gr.Radio(["a", "b", ["c", "c full"]])  # type: ignore
        assert loaded.choices == [("a", "a"), ("b", "b"), ("c", "c full")]

        with pytest.raises(ValueError):
            gr.Radio(["a", "b"], type="unknown")

    def test_in_interface(self):
        """
        The selected choice reaches the wrapped function as a string.
        """
        choice_input = gr.Radio(["a", "b", "c"])
        demo = gr.Interface(lambda choice: 2 * choice, choice_input, "textbox")
        assert demo("c") == "cc"
|
|
|
|
|
|
class TestDropdown:
    def test_component_functions(self):
        """
        Single- and multi-select Dropdown pre/postprocessing, index typing,
        choice normalization and config.
        """
        multi = gr.Dropdown(["a", "b", ("c", "c full")], multiselect=True)
        assert multi.preprocess("a") == "a"
        # Multiselect postprocess wraps a scalar into a list
        assert multi.postprocess("a") == ["a"]
        assert multi.preprocess("c full") == "c full"
        assert multi.postprocess("c full") == ["c full"]

        # Apps loaded through gr.load receive (label, value) pairs as lists
        # instead of tuples; those must still be normalized to tuples.
        loaded = gr.Dropdown(["a", "b", ["c", "c full"]])  # type: ignore
        assert loaded.choices == [("a", "a"), ("b", "b"), ("c", "c full")]

        single_index = gr.Dropdown(choices=["a", "b"], type="index")
        assert single_index.preprocess("a") == 0
        assert single_index.preprocess("b") == 1
        assert single_index.preprocess("c") is None

        multi_index = gr.Dropdown(choices=["a", "b"], type="index", multiselect=True)
        assert multi_index.preprocess(["a"]) == [0]
        assert multi_index.preprocess(["a", "b"]) == [0, 1]
        assert multi_index.preprocess(["a", "b", "c"]) == [0, 1, None]

        # Note: this dropdown is NOT multiselect; lists still pass through.
        single_select = gr.Dropdown(["a", "b", ("c", "c full")])
        assert single_select.preprocess(["a", "c full"]) == ["a", "c full"]
        assert single_select.postprocess(["a", "c full"]) == [
            "a",
            "c full",
        ]

        multi_select = gr.Dropdown(
            value=["a", "c"],
            choices=["a", "b", ("c", "c full")],
            label="Select Your Inputs",
            multiselect=True,
            max_choices=2,
        )
        expected_config = {
            "allow_custom_value": False,
            "choices": [("a", "a"), ("b", "b"), ("c", "c full")],
            "value": ["a", "c"],
            "name": "dropdown",
            "show_label": True,
            "label": "Select Your Inputs",
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "multiselect": True,
            "filterable": True,
            "max_choices": 2,
            "_selectable": False,
            "type": "value",
            "info": None,
        }
        assert multi_select.get_config() == expected_config
        with pytest.raises(ValueError):
            gr.Dropdown(["a"], type="unknown")

        out_of_choices = gr.Dropdown(choices=["a", "b"], value="c")
        assert out_of_choices.get_config()["value"] == "c"
        assert out_of_choices.postprocess("a") == "a"

    def test_in_interface(self):
        """
        A list passed to a Dropdown input reaches the wrapped function intact.
        """
        choice_input = gr.Dropdown(["a", "b", "c"])
        demo = gr.Interface(lambda picked: "|".join(picked), choice_input, "textbox")
        assert demo(["a", "c"]) == "a|c"
        assert demo([]) == ""
|
|
|
|
|
|
class TestImageEditor:
    def test_component_functions(self):
        """
        ImageEditor converts EditorData payloads in both directions, accepts a
        bare filepath as a value, and exposes its default config.
        """
        bus_path = "test/test_files/bus.png"
        bus_file = FileData(path=bus_path)
        full_editor_data = EditorData(
            background=bus_file, layers=[bus_file, bus_file], composite=bus_file
        )
        raw_payload = {
            "background": bus_path,
            "layers": [bus_path, bus_path],
            "composite": bus_path,
        }

        editor = gr.ImageEditor()

        assert isinstance(editor.preprocess(full_editor_data), dict)
        assert editor.postprocess(raw_payload) == full_editor_data

        # A bare filepath becomes both the background and the composite, with
        # no layers
        background_only = EditorData(
            background=bus_file, layers=[], composite=bus_file
        )
        assert editor.postprocess(bus_path) == background_only

        default_brush = {
            "default_size": "auto",
            "colors": [
                "rgb(204, 50, 50)",
                "rgb(173, 204, 50)",
                "rgb(50, 204, 112)",
                "rgb(50, 112, 204)",
                "rgb(173, 50, 204)",
            ],
            "default_color": "auto",
            "color_mode": "defaults",
        }
        expected_config = {
            "value": None,
            "height": None,
            "width": None,
            "image_mode": "RGBA",
            "sources": ("upload", "webcam", "clipboard"),
            "type": "numpy",
            "label": None,
            "show_label": True,
            "show_download_button": True,
            "container": True,
            "scale": None,
            "min_width": 160,
            "interactive": None,
            "visible": True,
            "elem_id": None,
            "elem_classes": [],
            "mirror_webcam": True,
            "show_share_button": False,
            "_selectable": False,
            "crop_size": None,
            "transforms": ("crop",),
            "eraser": {"default_size": "auto"},
            "brush": default_brush,
            "proxy_url": None,
            "name": "imageeditor",
        }
        assert editor.get_config() == expected_config

    def test_process_example(self):
        """process_example turns a filepath into EditorData with a background."""
        editor = gr.ImageEditor()
        example = editor.process_example("test/test_files/bus.png")
        assert isinstance(example, EditorData)
        assert example.background and example.background.path
|
|
|
|
|
|
class TestImage:
    def test_component_functions(self, gradio_temp_dir):
        """
        Preprocess, postprocess, get_config
        type: pil, filepath, numpy
        """
        img = FileData(path="test/test_files/bus.png")

        # type="filepath" copies the upload into the gradio temp dir
        image_input = gr.Image(type="filepath")
        image_temp_filepath = image_input.preprocess(img)
        assert image_temp_filepath in [
            str(f) for f in gradio_temp_dir.glob("**/*") if f.is_file()
        ]

        image_input = gr.Image(type="pil", label="Upload Your Image")
        assert image_input.get_config() == {
            "image_mode": "RGB",
            "sources": ["upload", "webcam", "clipboard"],
            "name": "image",
            "show_share_button": False,
            "show_download_button": True,
            "streaming": False,
            "show_label": True,
            "label": "Upload Your Image",
            "container": True,
            "min_width": 160,
            "scale": None,
            "height": None,
            "width": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "interactive": None,
            "proxy_url": None,
            "mirror_webcam": True,
            "_selectable": False,
            "streamable": False,
            "type": "pil",
        }
        assert image_input.preprocess(None) is None

        # Default type
        image_input = gr.Image()
        assert image_input.preprocess(img) is not None

        file_image = gr.Image(type="filepath")
        assert isinstance(file_image.preprocess(img), str)
        with pytest.raises(ValueError):
            gr.Image(type="unknown")

        # A single source string is normalized to a list
        string_source = gr.Image(sources="upload")
        assert string_source.sources == ["upload"]

        # Output functionalities
        image_output = gr.Image(type="pil")
        with PIL.Image.open(img.path) as pil_image:
            processed_image = image_output.postprocess(pil_image).model_dump()
        assert processed_image is not None
        with PIL.Image.open(processed_image["path"]) as processed, PIL.Image.open(
            img.path
        ) as source:
            assert processed.size == source.size

    def test_in_interface_as_output(self):
        """
        Interface, process
        """

        def generate_noise(height, width):
            return np.random.randint(0, 256, (height, width, 3))

        iface = gr.Interface(generate_noise, ["slider", "slider"], "image")
        assert iface(10, 20).endswith(".png")

    def test_static(self):
        """
        postprocess
        """
        component = gr.Image("test/test_files/bus.png")
        value = component.get_config().get("value")
        base64 = client_utils.encode_file_to_base64(value["path"])
        assert base64 == media_data.BASE64_IMAGE
        component = gr.Image(None)
        assert component.get_config().get("value") is None

    def test_images_upright_after_preprocess(self):
        component = gr.Image(type="pil")
        file_path = "test/test_files/rotated_image.jpeg"
        im = PIL.Image.open(file_path)
        # EXIF tag 274 is Orientation; 1 means "already upright"
        assert im.getexif().get(274) != 1
        image = component.preprocess(FileData(path=file_path))
        assert image == PIL.ImageOps.exif_transpose(im)
|
|
|
|
|
|
class TestPlot:
    @pytest.mark.asyncio
    async def test_in_interface_as_output(self):
        """
        A matplotlib figure returned through an Interface is encoded as a PNG
        data URL.
        """

        def make_figure(num):
            import matplotlib.pyplot as plt

            figure = plt.figure()
            plt.plot(range(num), range(num))
            return figure

        demo = gr.Interface(make_figure, "slider", "plot")
        with utils.MatplotlibBackendMananger():
            result = await demo.process_api(fn_index=0, inputs=[10], state={})
        plot_payload = result["data"][0]
        assert plot_payload["type"] == "matplotlib"
        assert plot_payload["plot"].startswith("data:image/png;base64")

    def test_static(self):
        """
        A figure given as the initial value appears in get_config.
        """
        with utils.MatplotlibBackendMananger():
            import matplotlib.pyplot as plt

            figure = plt.figure()
            plt.plot([1, 2, 3], [1, 2, 3])

            assert gr.Plot(figure).get_config().get("value") is not None
            assert gr.Plot(None).get_config().get("value") is None

    def test_postprocess_altair(self):
        """An Altair chart is postprocessed into the chart's JSON spec."""
        import altair as alt
        from vega_datasets import data

        cars_frame = data.cars()
        encoded = alt.Chart(cars_frame).mark_point().encode(
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Origin",
        )
        payload = gr.Plot().postprocess(encoded).model_dump()
        plot_json = payload["plot"]
        assert isinstance(plot_json, str)
        assert plot_json == encoded.to_json()
|
|
|
|
|
|
class TestAudio:
    def test_component_functions(self, gradio_temp_dir):
        """
        Numpy and filepath preprocessing, get_config for input and output
        usage, and filepath postprocessing for str and Path values.
        """
        sample = FileData(path=media_data.BASE64_AUDIO["path"])
        numpy_audio = gr.Audio()
        decoded = numpy_audio.preprocess(sample)
        assert decoded[0] == 8000
        assert decoded[1].shape == (8046,)

        cached_sample = processing_utils.move_files_to_cache([sample], numpy_audio)[0]
        filepath_audio = gr.Audio(type="filepath")
        cached_path = filepath_audio.preprocess(cached_sample)
        assert Path(cached_path).name.endswith("audio_sample.wav")

        waveform_defaults = {
            "sample_rate": 44100,
            "show_controls": False,
            "show_recording_waveform": True,
            "skip_length": 5,
            "waveform_color": "#9ca3af",
            "waveform_progress_color": "#f97316",
        }

        labelled = gr.Audio(label="Upload Your Audio")
        expected_input_config = {
            "autoplay": False,
            "sources": ["upload", "microphone"],
            "name": "audio",
            "show_download_button": None,
            "show_share_button": False,
            "streaming": False,
            "show_label": True,
            "label": "Upload Your Audio",
            "container": True,
            "editable": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "interactive": None,
            "proxy_url": None,
            "type": "numpy",
            "format": "wav",
            "streamable": False,
            "max_length": None,
            "min_length": None,
            "waveform_options": waveform_defaults,
            "_selectable": False,
        }
        assert labelled.get_config() == expected_input_config
        assert labelled.preprocess(None) is None

        filepath_audio = gr.Audio(type="filepath")
        assert isinstance(filepath_audio.preprocess(cached_sample), str)
        with pytest.raises(ValueError):
            gr.Audio(type="unknown")

        # Audio accepts a (sample_rate, numpy array) tuple as its value
        gr.Audio((100, np.random.random(size=(1000, 2))), label="Play your audio")

        # Output functionalities
        decoded_file = client_utils.decode_base64_to_file(
            deepcopy(media_data.BASE64_AUDIO)["data"]
        )
        output_audio = gr.Audio(type="filepath")
        served_path = output_audio.postprocess(decoded_file.name).model_dump()["path"]
        assert filecmp.cmp(decoded_file.name, served_path)

        expected_output_config = {
            "autoplay": False,
            "name": "audio",
            "show_download_button": None,
            "show_share_button": False,
            "streaming": False,
            "show_label": True,
            "label": None,
            "max_length": None,
            "min_length": None,
            "container": True,
            "editable": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "interactive": None,
            "proxy_url": None,
            "type": "filepath",
            "format": "wav",
            "streamable": False,
            "sources": ["upload", "microphone"],
            "waveform_options": waveform_defaults,
            "_selectable": False,
        }
        assert output_audio.get_config() == expected_output_config

        # str and pathlib.Path inputs must postprocess identically
        from_str = output_audio.postprocess(decoded_file.name).model_dump()
        from_path = output_audio.postprocess(Path(decoded_file.name)).model_dump()
        assert from_str == from_path

    def test_default_value_postprocess(self):
        """A filepath default value is moved into the component's cache."""
        audio_payload = deepcopy(media_data.BASE64_AUDIO)
        component = gr.Audio(value=audio_payload["path"])
        assert utils.is_in_or_equal(component.value["path"], component.GRADIO_CACHE)

    def test_in_interface(self):
        """Reversing an audio clip twice reproduces (almost exactly) the original."""

        def reverse_audio(audio):
            rate, samples = audio
            return (rate, np.flipud(samples))

        demo = gr.Interface(reverse_audio, "audio", "audio")
        once_reversed = demo("test/test_files/audio_sample.wav")
        twice_reversed = demo(once_reversed)
        twice_reversed_b64 = client_utils.encode_url_or_file_to_base64(twice_reversed)
        ratio = SequenceMatcher(
            a=twice_reversed_b64, b=media_data.BASE64_AUDIO["data"]
        ).ratio()
        assert ratio > 0.99

    def test_in_interface_as_output(self):
        """
        A (rate, int16 array) tuple returned by the function is written as a .wav.
        """

        def generate_noise(duration):
            return 48000, np.random.randint(-256, 256, (duration, 3)).astype(np.int16)

        demo = gr.Interface(generate_noise, "slider", "audio")
        assert demo(100).endswith(".wav")

    def test_audio_preprocess_can_be_read_by_scipy(self, gradio_temp_dir):
        """Preprocessed microphone audio is a WAV file scipy can read."""
        cached_path = processing_utils.save_base64_to_cache(
            media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir
        )
        microphone_file = FileData(path=cached_path)
        filepath_audio = gr.Audio(type="filepath")
        wavfile.read(filepath_audio.preprocess(microphone_file))

    def test_prepost_process_to_mp3(self, gradio_temp_dir):
        """format="mp3" is honoured by both preprocess and postprocess."""
        cached_path = processing_utils.save_base64_to_cache(
            media_data.BASE64_MICROPHONE["data"], cache_dir=gradio_temp_dir
        )
        microphone_file = FileData(path=cached_path)
        mp3_audio = gr.Audio(type="filepath", format="mp3")
        assert mp3_audio.preprocess(microphone_file).endswith("mp3")
        noise = np.random.randint(-256, 256, (5, 3)).astype(np.int16)
        postprocessed = mp3_audio.postprocess((48000, noise)).model_dump()
        assert postprocessed["path"].endswith("mp3")
|
|
|
|
|
|
class TestFile:
    def test_component_functions(self):
        """
        Preprocess, get_config, value
        """
        x_file = FileData(path=media_data.BASE64_FILE["path"])
        file_input = gr.File()
        output = file_input.preprocess(x_file)
        assert isinstance(output, str)

        input1 = file_input.preprocess(x_file)
        input2 = file_input.preprocess(x_file)
        assert input1 == input1.name  # Testing backwards compatibility
        assert input1 == input2
        assert Path(input1).name == "sample_file.pdf"

        file_input = gr.File(label="Upload Your File")
        assert file_input.get_config() == {
            "file_count": "single",
            "file_types": None,
            "name": "file",
            "show_label": True,
            "label": "Upload Your File",
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "interactive": None,
            "proxy_url": None,
            "_selectable": False,
            "height": None,
            "type": "filepath",
        }
        assert file_input.preprocess(None) is None
        assert file_input.preprocess(x_file) is not None

        # Preprocessing a zero-size FileData must not create a file on disk
        zero_size_file = FileData(path="document.txt", size=0)
        temp_file = file_input.preprocess(zero_size_file)
        assert not Path(temp_file.name).exists()

        file_input = gr.File(type="binary")
        output = file_input.preprocess(x_file)
        assert isinstance(output, bytes)

        output1 = file_input.postprocess("test/test_files/sample_file.pdf")
        output2 = file_input.postprocess("test/test_files/sample_file.pdf")
        assert output1 == output2

    def test_preprocess_with_multiple_files(self):
        file_data = FileData(path=media_data.BASE64_FILE["path"])
        list_file_data = ListFiles(root=[file_data, file_data])
        file_input = gr.File(file_count="directory")
        output = file_input.preprocess(list_file_data)
        assert isinstance(output, list)
        assert isinstance(output[0], str)

    def test_file_type_must_be_list(self):
        with pytest.raises(
            ValueError, match="Parameter file_types must be a list. Received str"
        ):
            gr.File(file_types=".json")

    def test_in_interface_as_input(self):
        """
        Interface, process
        """
        x_file = media_data.BASE64_FILE["path"]

        def get_size_of_file(file_obj):
            return os.path.getsize(file_obj.name)

        iface = gr.Interface(get_size_of_file, "file", "number")
        assert iface(x_file) == 10558

    def test_as_component_as_output(self, tmp_path):
        """
        Interface, process

        The output file is written under pytest's per-test `tmp_path` rather
        than the current working directory, so running the test does not leave
        a stray `test.txt` behind in the repository.
        """
        output_path = tmp_path / "test.txt"

        def write_file(content):
            output_path.write_text(content, encoding="utf-8")
            return str(output_path)

        iface = gr.Interface(write_file, "text", "file")
        assert iface("hello world").endswith(".txt")
|
|
|
|
|
|
class TestUploadButton:
    """Tests for `gr.UploadButton` preprocessing and argument validation."""

    def test_component_functions(self):
        """
        preprocess
        """
        x_file = FileData(path=media_data.BASE64_FILE["path"])
        upload_input = gr.UploadButton()
        # Named `uploaded` rather than `input` to avoid shadowing the builtin.
        uploaded = upload_input.preprocess(x_file)
        assert isinstance(uploaded, str)

        input1 = upload_input.preprocess(x_file)
        input2 = upload_input.preprocess(x_file)
        # The returned string also exposes a `.name` attribute equal to itself.
        assert input1 == input1.name  # Testing backwards compatibility
        assert input1 == input2

    def test_raises_if_file_types_is_not_list(self):
        """Passing `file_types` as a non-list (here an int) raises ValueError."""
        with pytest.raises(
            ValueError, match="Parameter file_types must be a list. Received int"
        ):
            gr.UploadButton(file_types=2)

    def test_preprocess_with_multiple_files(self):
        """A `file_count="directory"` upload preprocesses to a list of path strings."""
        file_data = FileData(path=media_data.BASE64_FILE["path"])
        list_file_data = ListFiles(root=[file_data, file_data])
        upload_input = gr.UploadButton(file_count="directory")
        output = upload_input.preprocess(list_file_data)
        assert isinstance(output, list)
        assert isinstance(output[0], str)


class TestDataframe:
    """Tests for `gr.Dataframe`: pre/postprocessing, config and pandas Styler input."""

    @staticmethod
    def _expected_config(headers, label):
        """Expected get_config() of a 3-column Dataframe where only `headers`/`label` vary."""
        return {
            "value": {
                "headers": list(headers),
                "data": [["", "", ""]],
                "metadata": None,
            },
            "_selectable": False,
            "headers": list(headers),
            "row_count": (1, "dynamic"),
            "col_count": (3, "dynamic"),
            "datatype": ["str", "str", "str"],
            "type": "pandas",
            "label": label,
            "show_label": True,
            "scale": None,
            "min_width": 160,
            "interactive": None,
            "visible": True,
            "elem_id": None,
            "elem_classes": [],
            "wrap": False,
            "proxy_url": None,
            "name": "dataframe",
            "height": 500,
            "latex_delimiters": [{"display": True, "left": "$$", "right": "$$"}],
            "line_breaks": True,
            "column_widths": [],
        }

    def test_component_functions(self):
        """
        Preprocess, serialize, get_config
        """

        def make_payload(metadata=None):
            # Fresh lists on every call so no payload shares data with another.
            return DataframeData(
                data=[["Tim", 12, False], ["Jan", 24, True]],
                headers=["Name", "Age", "Member"],
                metadata=metadata,
            )

        named = gr.Dataframe(headers=["Name", "Age", "Member"])
        payload = make_payload()
        frame = named.preprocess(payload)
        assert frame["Age"][1] == 24
        assert not frame["Member"][0]
        # Round trip: postprocessing the DataFrame reproduces the payload.
        assert named.postprocess(frame) == payload

        labelled = gr.Dataframe(
            headers=["Name", "Age", "Member"], label="Dataframe Input"
        )
        assert labelled.get_config() == self._expected_config(
            ["Name", "Age", "Member"], "Dataframe Input"
        )

        unnamed = gr.Dataframe()
        frame = unnamed.preprocess(make_payload())
        assert frame["Age"][1] == 24

        # Metadata with explicit None fields is accepted as well.
        unnamed.preprocess(make_payload({"display_value": None, "styling": None}))

        with pytest.raises(ValueError):
            gr.Dataframe(type="unknown")

        default_output = gr.Dataframe()
        assert default_output.get_config() == self._expected_config(
            ["1", "2", "3"], None
        )

        # Integer widths are rendered as pixel strings; strings pass through.
        sized = gr.Dataframe(column_widths=["100px", 200, "50%"])
        assert sized.get_config()["column_widths"] == ["100px", "200px", "50%"]

    def test_postprocess(self):
        """
        postprocess
        """
        component = gr.Dataframe()
        cases = [
            ([], {"data": [[]], "headers": [], "metadata": None}),
            (
                np.zeros((2, 2)),
                {"data": [[0, 0], [0, 0]], "headers": ["1", "2"], "metadata": None},
            ),
            (
                [[1, 3, 5]],
                {"data": [[1, 3, 5]], "headers": ["1", "2", "3"], "metadata": None},
            ),
            (
                pd.DataFrame(
                    [[2, True], [3, True], [4, False]], columns=["num", "prime"]
                ),
                {
                    "headers": ["num", "prime"],
                    "data": [[2, True], [3, True], [4, False]],
                    "metadata": None,
                },
            ),
        ]
        for value, expected in cases:
            assert component.postprocess(value).model_dump() == expected

        with pytest.raises(ValueError):
            gr.Dataframe(type="unknown")

        # When the headers don't match the data
        mismatched = [
            (
                [[2, True], [3, True]],
                {
                    "headers": ["one", "two"],
                    "data": [[2, True], [3, True]],
                    "metadata": None,
                },
            ),
            (
                [[2, True, "ab", 4], [3, True, "cd", 5]],
                {
                    "headers": ["one", "two", "three", "4"],
                    "data": [[2, True, "ab", 4], [3, True, "cd", 5]],
                    "metadata": None,
                },
            ),
        ]
        for rows, expected in mismatched:
            headed = gr.Dataframe(headers=["one", "two", "three"])
            assert headed.postprocess(rows).model_dump() == expected

    def test_dataframe_postprocess_all_types(self):
        """Mixed column types: dates stay Timestamps, other values pass through."""
        df = pd.DataFrame(
            {
                "date_1": pd.date_range("2021-01-01", periods=2),
                "date_2": pd.date_range("2022-02-15", periods=2).strftime(
                    "%B %d, %Y, %r"
                ),
                "number": np.array([0.2233, 0.57281]),
                "number_2": np.array([84, 23]).astype(np.int64),
                "bool": [True, False],
                "markdown": ["# Hello", "# Goodbye"],
            }
        )
        column_types = ["date", "date", "number", "number", "bool", "markdown"]
        dumped = gr.Dataframe(datatype=column_types).postprocess(df).model_dump()

        first_row = [
            pd.Timestamp("2021-01-01 00:00:00"),
            "February 15, 2022, 12:00:00 AM",
            0.2233,
            84,
            True,
            "# Hello",
        ]
        second_row = [
            pd.Timestamp("2021-01-02 00:00:00"),
            "February 16, 2022, 12:00:00 AM",
            0.57281,
            23,
            False,
            "# Goodbye",
        ]
        assert dumped == {
            "headers": list(df.columns),
            "data": [first_row, second_row],
            "metadata": None,
        }

    def test_dataframe_postprocess_only_dates(self):
        """A frame made only of datetime columns postprocesses to Timestamps."""
        df = pd.DataFrame(
            {
                "date_1": pd.date_range("2021-01-01", periods=2),
                "date_2": pd.date_range("2022-02-15", periods=2),
            }
        )
        dumped = gr.Dataframe(datatype=["date", "date"]).postprocess(df).model_dump()
        assert dumped == {
            "headers": list(df.columns),
            "data": [
                [
                    pd.Timestamp("2021-01-01 00:00:00"),
                    pd.Timestamp("2022-02-15 00:00:00"),
                ],
                [
                    pd.Timestamp("2021-01-02 00:00:00"),
                    pd.Timestamp("2022-02-16 00:00:00"),
                ],
            ],
            "metadata": None,
        }

    def test_dataframe_postprocess_styler(self):
        """A pandas Styler yields raw data plus display strings and CSS metadata."""
        component = gr.Dataframe()
        df = pd.DataFrame(
            {
                "name": ["Adam", "Mike"] * 4,
                "gpa": [1.1, 1.12] * 4,
                "sat": [800, 800] * 4,
            }
        )
        formatted = df.style.format(precision=1, decimal=",")
        # `data` keeps the raw values; the formatted text (one decimal, comma
        # separator) is reported in metadata["display_value"].
        assert component.postprocess(formatted).model_dump() == {
            "data": [["Adam", 1.1, 800], ["Mike", 1.12, 800]] * 4,
            "headers": ["name", "gpa", "sat"],
            "metadata": {
                "display_value": [["Adam", "1,1", "800"], ["Mike", "1,1", "800"]]
                * 4,
                "styling": [["", "", ""]] * 8,
            },
        }

        df = pd.DataFrame(
            {
                "A": [14, 4, 5, 4, 1],
                "B": [5, 2, 54, 3, 2],
                "C": [20, 20, 7, 3, 8],
                "D": [14, 3, 6, 2, 6],
                "E": [23, 45, 64, 32, 23],
            }
        )
        highlighted = df.style.highlight_max(color="lightgreen", axis=0)
        green = "background-color: lightgreen"
        values = [
            [14, 5, 20, 14, 23],
            [4, 2, 20, 3, 45],
            [5, 54, 7, 6, 64],
            [4, 3, 3, 2, 32],
            [1, 2, 8, 6, 23],
        ]
        # highlight_max(axis=0) marks the maximum of each column (ties included).
        assert component.postprocess(highlighted).model_dump() == {
            "data": values,
            "headers": ["A", "B", "C", "D", "E"],
            "metadata": {
                "display_value": [[str(cell) for cell in row] for row in values],
                "styling": [
                    [green, "", green, green, ""],
                    ["", "", green, "", ""],
                    ["", green, "", "", green],
                    ["", "", "", "", ""],
                    ["", "", "", "", ""],
                ],
            },
        }


class TestDataset:
    """Tests for `gr.Dataset` preprocessing (value and index modes) and postprocessing."""

    @staticmethod
    def _component_names():
        """Component shorthands matching the five columns produced by `_samples`."""
        return ["number", "textbox", "image", "html", "markdown"]

    @staticmethod
    def _samples(image):
        """Return a two-row sample table, built fresh on every call so no list is
        shared between Dataset instances or expected values."""
        return [
            [5, "hello", image, "<b>Bold</b>", "**Bold**"],
            [15, "hi", image, "<i>Italics</i>", "*Italics*"],
        ]

    def test_preprocessing(self):
        bus = str((Path(__file__).parent / "test_files" / "bus.png").resolve())

        by_value = gr.Dataset(
            components=self._component_names(),
            samples=self._samples(bus),
        )
        number, text, image, html, markdown = by_value.preprocess(1)
        assert number == 15
        assert text == "hi"
        assert image["path"].endswith("bus.png")
        assert html == "<i>Italics</i>"
        assert markdown == "*Italics*"

        by_index = gr.Dataset(
            components=self._component_names(),
            samples=self._samples(bus),
            type="index",
        )
        # In index mode the selected row index itself is returned.
        assert by_index.preprocess(1) == 1

        radio = gr.Radio(choices=[("name 1", "value 1"), ("name 2", "value 2")])
        radio_dataset = gr.Dataset(
            samples=[["value 1"], ["value 2"]], components=[radio]
        )
        assert radio_dataset.samples == [["value 1"], ["value 2"]]

    def test_postprocessing(self):
        bus = Path(Path(__file__).parent, "test_files", "bus.png")

        dataset = gr.Dataset(components=self._component_names(), type="index")
        update = dataset.postprocess(samples=self._samples(bus))

        assert update == {
            "samples": self._samples(bus),
            "__type__": "update",
        }


class TestVideo:
    """Tests for `gr.Video`: preprocessing, config, postprocessing/subtitles and webcam flipping."""

    def test_component_functions(self):
        """
        Preprocess, serialize, deserialize, get_config
        """
        clip = VideoData(
            video=FileData(path=deepcopy(media_data.BASE64_VIDEO)["path"])
        )
        video_input = gr.Video()
        clip = processing_utils.move_files_to_cache([clip], video_input)[0]

        # Repeated preprocessing of the same payload is stable.
        first = video_input.preprocess(clip)
        assert isinstance(first, str)
        assert first == video_input.preprocess(clip)

        muted = gr.Video(include_audio=False)
        assert muted.preprocess(clip) == muted.preprocess(clip)

        labelled = gr.Video(label="Upload Your Video")
        assert labelled.get_config() == {
            "autoplay": False,
            "sources": ["upload", "webcam"],
            "name": "video",
            "show_share_button": False,
            "show_label": True,
            "label": "Upload Your Video",
            "container": True,
            "min_width": 160,
            "scale": None,
            "show_download_button": None,
            "height": None,
            "width": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "interactive": None,
            "proxy_url": None,
            "mirror_webcam": True,
            "include_audio": True,
            "format": None,
            "min_length": None,
            "max_length": None,
            "_selectable": False,
        }
        assert labelled.preprocess(None) is None

        avi_file = gr.Video(format="avi").preprocess(clip)
        assert avi_file[-3:] == "avi"
        assert "flip" not in avi_file

        # Output functionalities
        y_vid_path = "test/test_files/video_sample.mp4"
        subtitles_path = "test/test_files/s1.srt"
        video_output = gr.Video()
        cached_path = video_output.postprocess(y_vid_path).model_dump()["video"]["path"]
        assert cached_path.endswith("mp4")
        assert (
            cached_path
            == video_output.postprocess(y_vid_path).model_dump()["video"]["path"]
        )
        assert (
            video_output.postprocess(y_vid_path).model_dump()["video"]["orig_name"]
            == "video_sample.mp4"
        )
        with_subs = video_output.postprocess((y_vid_path, subtitles_path)).model_dump()
        # The .srt subtitle file comes back as a .vtt file.
        assert with_subs["subtitles"]["path"].endswith(".vtt")

        plain_component = gr.Video()
        subtitled_component = gr.Video()
        plain = plain_component.postprocess(Path(y_vid_path)).model_dump()
        subtitled = subtitled_component.postprocess(
            (Path(y_vid_path), Path(subtitles_path))
        ).model_dump()

        expected_video = {
            "path": "video_sample.mp4",
            "orig_name": "video_sample.mp4",
            "mime_type": None,
            "size": None,
            "url": None,
        }

        # Cached paths differ between runs, so only the basename is compared.
        plain["video"]["path"] = os.path.basename(plain["video"]["path"])
        assert plain == {"video": dict(expected_video), "subtitles": None}

        subtitled["video"]["path"] = os.path.basename(subtitled["video"]["path"])
        # The subtitle path is generated, so normalise it before comparing.
        if subtitled["subtitles"]["path"]:
            subtitled["subtitles"]["path"] = "s1.srt"
        assert subtitled == {
            "video": dict(expected_video),
            "subtitles": {
                "path": "s1.srt",
                "mime_type": None,
                "orig_name": None,
                "size": None,
                "url": None,
            },
        }

    def test_in_interface(self):
        """
        Interface, process
        """
        sample = media_data.BASE64_VIDEO["path"]
        echo = gr.Interface(lambda x: x, "video", "playable_video")
        assert echo({"video": sample})["video"].endswith(".mp4")

    def test_with_waveform(self):
        """
        Interface, process
        """
        sample = media_data.BASE64_AUDIO["path"]
        waveform = gr.Interface(lambda x: gr.make_waveform(x), "audio", "video")
        assert waveform(sample)["video"].endswith(".mp4")

    def test_video_postprocess_converts_to_playable_format(self):
        """Videos that are not browser-playable are converted on postprocess."""
        test_file_dir = Path(Path(__file__).parent, "test_files")
        unplayable_samples = [
            # This file has a playable container but not playable codec
            ("bad_video.mp4", "bad_video_sample.mp4"),
            # This file has a playable codec but not a playable container
            ("playable_but_bad_container.mkv", "playable_but_bad_container.mkv"),
        ]
        for tmp_suffix, sample_name in unplayable_samples:
            with tempfile.NamedTemporaryFile(
                suffix=tmp_suffix, delete=False
            ) as tmp_vid:
                source = str(test_file_dir / sample_name)
                assert not processing_utils.video_is_playable(source)
                shutil.copy(source, tmp_vid.name)
                converted = gr.Video().postprocess(tmp_vid.name).model_dump()
                assert processing_utils.video_is_playable(converted["video"]["path"])

    @patch("pathlib.Path.exists", MagicMock(return_value=False))
    @patch("gradio.components.video.FFmpeg")
    def test_video_preprocessing_flips_video_for_webcam(self, mock_ffmpeg):
        # Ensures that the cached temp video file is not used so that ffmpeg is called for each test
        clip = VideoData(video=FileData(path=media_data.BASE64_VIDEO["path"]))

        def first_output():
            """(filename, options) of the first entry in the first FFmpeg call's `outputs`."""
            # Dict mapping filename to FFmpeg options
            outputs = mock_ffmpeg.call_args_list[0][1]["outputs"]
            return list(outputs.keys())[0], list(outputs.values())[0]

        _ = gr.Video(sources=["webcam"]).preprocess(clip)
        target, options = first_output()
        assert "hflip" in options
        assert "flip" in target

        # Not mirrored and audio kept: no transcoding needed.
        mock_ffmpeg.reset_mock()
        _ = gr.Video(
            sources=["webcam"], mirror_webcam=False, include_audio=True
        ).preprocess(clip)
        mock_ffmpeg.assert_not_called()

        # Upload already in the requested format: no transcoding needed.
        mock_ffmpeg.reset_mock()
        _ = gr.Video(sources=["upload"], format="mp4", include_audio=True).preprocess(
            clip
        )
        mock_ffmpeg.assert_not_called()

        mock_ffmpeg.reset_mock()
        produced = gr.Video(
            sources=["webcam"], mirror_webcam=True, format="avi"
        ).preprocess(clip)
        target, options = first_output()
        assert "hflip" in options
        assert "flip" in target
        assert ".avi" in target
        assert ".avi" in produced

        mock_ffmpeg.reset_mock()
        produced = gr.Video(
            sources=["webcam"], mirror_webcam=False, format="avi", include_audio=False
        ).preprocess(clip)
        target, options = first_output()
        assert options == ["-an"]
        assert "flip" not in Path(target).name
        assert ".avi" in target
        assert ".avi" in produced


class TestNames:
    """Checks on component class names."""

    # This test ensures that `components.get_component_instance()` works correctly when instantiating from components
    def test_no_duplicate_uncased_names(self, io_components):
        """No two component classes may share a name when compared case-insensitively."""
        lowered_names = set()
        for component_cls in io_components:
            lowered_names.add(component_cls.__name__.lower())
        assert len(lowered_names) == len(io_components)


class TestLabel:
    """Tests for `gr.Label` postprocessing, config and Interface use."""

    def test_component_functions(self):
        """
        Process, postprocess, deserialize
        """
        component = gr.Label()
        assert component.postprocess("happy").model_dump() == {
            "label": "happy",
            "confidences": None,
        }

        scores = {3: 0.7, 1: 0.2, 0: 0.1}
        ranked = [
            {"label": 3, "confidence": 0.7},
            {"label": 1, "confidence": 0.2},
            {"label": 0, "confidence": 0.1},
        ]
        assert component.postprocess(scores).model_dump() == {
            "label": 3,
            "confidences": ranked,
        }

        # With num_top_classes=2 only the two highest-confidence entries remain.
        component = gr.Label(num_top_classes=2)
        assert component.postprocess(scores).model_dump() == {
            "label": 3,
            "confidences": ranked[:2],
        }

        with pytest.raises(ValueError):
            component.postprocess([1, 2, 3]).model_dump()

        json_path = str(Path(__file__).parent / "test_files" / "test_label_json.json")
        assert component.postprocess(json_path).model_dump()["label"] == "web site"

        assert component.get_config() == {
            "name": "label",
            "show_label": True,
            "num_top_classes": 2,
            "value": {},
            "label": None,
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "proxy_url": None,
            "color": None,
            "_selectable": False,
        }

    def test_color_argument(self):
        """The `color` argument is reported in the config."""
        colored = gr.Label(value=-10, color="red")
        assert colored.get_config()["color"] == "red"

    def test_in_interface(self):
        """
        Interface, process
        """
        image_path = "test/test_files/bus.png"

        def channel_shares(img):
            # Mean value per channel, normalised to sum to 1, rounded to 2 dp.
            shares = np.mean(img, axis=(0, 1))
            shares /= np.sum(shares)
            shares = np.round(shares, decimals=2)
            return {"red": shares[0], "green": shares[1], "blue": shares[2]}

        classifier = gr.Interface(channel_shares, "image", "label")
        assert classifier(image_path) == {
            "label": "red",
            "confidences": [
                {"label": "red", "confidence": 0.44},
                {"label": "green", "confidence": 0.28},
                {"label": "blue", "confidence": 0.28},
            ],
        }


class TestHighlightedText:
    """Tests for `gr.HighlightedText` postprocessing, config and Interface use."""

    def test_postprocess(self):
        """
        postprocess
        """

        def token(text, label=None):
            return {"token": text, "class_or_confidence": label}

        def from_entities(component, text, entities):
            return component.postprocess(
                {"text": text, "entities": entities}
            ).model_dump()

        component = gr.HighlightedText()
        expected = [
            token(""),
            token("Wolfgang", "PER"),
            token(" lives in "),
            token("Berlin", "LOC"),
            token(""),
        ]
        pairs = [
            ("", None),
            ("Wolfgang", "PER"),
            (" lives in ", None),
            ("Berlin", "LOC"),
            ("", None),
        ]
        assert component.postprocess(pairs).model_dump() == expected

        sentence = "Wolfgang lives in Berlin"
        assert (
            from_entities(
                component,
                sentence,
                [
                    {"entity": "PER", "start": 0, "end": 8},
                    {"entity": "LOC", "start": 18, "end": 24},
                ],
            )
            == expected
        )

        # `entity_group` yields the same result as `entity`.
        assert (
            from_entities(
                component,
                sentence,
                [
                    {"entity_group": "PER", "start": 0, "end": 8},
                    {"entity": "LOC", "start": 18, "end": 24},
                ],
            )
            == expected
        )

        # Test split entity is merged when combine adjacent is set
        split_entities = [
            {"entity": "PER", "start": 0, "end": 4},
            {"entity": "PER", "start": 4, "end": 8},
            {"entity": "LOC", "start": 18, "end": 24},
        ]
        # After a merge empty entries are stripped except the leading one
        merged = [
            token(""),
            token("Wolfgang", "PER"),
            token(" lives in "),
            token("Berlin", "LOC"),
        ]
        unmerged = from_entities(component, sentence, split_entities)
        assert expected != unmerged
        assert merged != unmerged

        combining = gr.HighlightedText(combine_adjacent=True)
        assert from_entities(combining, sentence, split_entities) == merged

        component = gr.HighlightedText()

        # Entities given out of order are placed by their offsets.
        assert (
            from_entities(
                component,
                sentence,
                [
                    {"entity": "LOC", "start": 18, "end": 24},
                    {"entity": "PER", "start": 0, "end": 8},
                ],
            )
            == expected
        )

        plain = "I live there"
        assert from_entities(component, plain, []) == [token(plain)]

        name = "Wolfgang"
        assert from_entities(
            component, name, [{"entity": "PER", "start": 0, "end": 8}]
        ) == [token(""), token(name, "PER"), token("")]

    def test_component_functions(self):
        """
        get_config
        """
        highlighted = gr.HighlightedText(color_map={"pos": "green", "neg": "red"})
        assert highlighted.get_config() == {
            "color_map": {"pos": "green", "neg": "red"},
            "name": "highlightedtext",
            "show_label": True,
            "label": None,
            "show_legend": False,
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "proxy_url": None,
            "_selectable": False,
            "combine_adjacent": False,
            "adjacent_separator": "",
            "interactive": None,
        }

    def test_in_interface(self):
        """
        Interface, process
        """

        def split_by_vowels(sentence):
            # Group consecutive letters into runs labelled "vowel" or "non".
            runs, current = [], ""
            vowels, kind = "aeiou", None
            for ch in sentence:
                ch_kind = "vowel" if ch in vowels else "non"
                if kind is None:
                    kind = ch_kind
                elif kind != ch_kind:
                    runs.append((current, kind))
                    current = ""
                    kind = ch_kind
                current += ch
            runs.append((current, kind))
            return runs

        highlighter = gr.Interface(split_by_vowels, "text", "highlight")
        assert highlighter("Helloooo") == [
            {"token": "H", "class_or_confidence": "non"},
            {"token": "e", "class_or_confidence": "vowel"},
            {"token": "ll", "class_or_confidence": "non"},
            {"token": "oooo", "class_or_confidence": "vowel"},
        ]


class TestAnnotatedImage:
    """Tests for `gr.AnnotatedImage` postprocessing, config and Interface use."""

    def test_postprocess(self):
        """
        postprocess
        """
        component = gr.AnnotatedImage()
        base = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
        box = [40, 40, 50, 50]
        square = np.zeros((100, 100), dtype=np.uint8)
        square[10:20, 10:20] = 1

        annotated = (base, [(box, "mask1"), (square, "mask2")])
        dumped = component.postprocess(annotated).model_dump()

        base_out = PIL.Image.open(dumped["image"]["path"])

        first_annotation = dumped["annotations"][0]
        assert first_annotation["label"] == "mask1"

        box_out = PIL.Image.open(first_annotation["image"]["path"])
        assert box_out.size == base_out.size
        box_pixels = np.array(box_out)
        # The box region contains 255-valued pixels; the area past it is 0.
        assert np.max(box_pixels[40:50, 40:50]) == 255
        assert np.max(box_pixels[50:60, 50:60]) == 0

    def test_component_functions(self):
        """get_config"""
        annotated = gr.AnnotatedImage(label="sections", show_legend=False)
        assert annotated.get_config() == {
            "name": "annotatedimage",
            "show_label": True,
            "label": "sections",
            "show_legend": False,
            "container": True,
            "min_width": 160,
            "scale": None,
            "color_map": None,
            "height": None,
            "width": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "proxy_url": None,
            "_selectable": False,
        }

    def test_in_interface(self):
        """Annotation images returned through an Interface match the input size."""

        def annotate(img):
            corner_box = [0, 0, 20, 20]
            noise = np.random.randint(0, 2, img.shape[:2])
            return (img, [(corner_box, "left corner"), (noise, "random")])

        iface = gr.Interface(annotate, "image", gr.AnnotatedImage())
        result = iface("test/test_files/bus.png")
        rendered_path, (corner, _) = result["image"], result["annotations"]
        source = PIL.Image.open("test/test_files/bus.png")
        rendered = PIL.Image.open(rendered_path)
        corner_img = PIL.Image.open(corner["image"])

        assert rendered.size == source.size
        assert corner_img.size == source.size


class TestChatbot:
    """Tests for `gr.Chatbot` postprocessing, config and avatar caching."""

    def test_component_functions(self):
        """
        Postprocess, get_config
        """
        chatbot = gr.Chatbot()
        text_history = [["You are **cool**\nand fun", "so are *you*"]]
        assert chatbot.postprocess(text_history).model_dump() == [
            ("You are **cool**\nand fun", "so are *you*")
        ]

        # Each media file appears once as a str path and once as a pathlib.Path.
        media_history = [
            [("test/test_files/video_sample.mp4",), "cool video"],
            [("test/test_files/audio_sample.wav",), "cool audio"],
            [("test/test_files/bus.png", "A bus"), "cool pic"],
            [(Path("test/test_files/video_sample.mp4"),), "cool video"],
            [(Path("test/test_files/audio_sample.wav"),), "cool audio"],
            [(Path("test/test_files/bus.png"), "A bus"), "cool pic"],
        ]
        for user_part, bot_part in chatbot.postprocess(media_history).model_dump():
            assert "file" in user_part
            assert bot_part in {"cool video", "cool audio", "cool pic"}
            extension = user_part["file"]["path"].split(".")[-1]
            assert extension in {"mp4", "wav", "png"}
            if user_part["alt_text"]:
                assert user_part["alt_text"] == "A bus"

        assert chatbot.get_config() == {
            "value": [],
            "label": None,
            "show_label": True,
            "name": "chatbot",
            "show_share_button": False,
            "visible": True,
            "elem_id": None,
            "elem_classes": [],
            "container": True,
            "min_width": 160,
            "scale": None,
            "height": None,
            "proxy_url": None,
            "_selectable": False,
            "latex_delimiters": [{"display": True, "left": "$$", "right": "$$"}],
            "likeable": False,
            "rtl": False,
            "show_copy_button": False,
            "avatar_images": [None, None],
            "sanitize_html": True,
            "render_markdown": True,
            "bubble_full_width": True,
            "line_breaks": True,
            "layout": None,
        }

    def test_avatar_images_are_moved_to_cache(self):
        """A file avatar is copied into the Gradio cache; a None avatar stays None."""
        chatbot = gr.Chatbot(avatar_images=("test/test_files/bus.png", None))
        user_avatar, bot_avatar = chatbot.avatar_images[0], chatbot.avatar_images[1]
        assert user_avatar
        assert utils.is_in_or_equal(user_avatar, chatbot.GRADIO_CACHE)
        assert bot_avatar is None


class TestJSON:
    """Tests for `gr.JSON` postprocessing, config and Interface use."""

    def test_component_functions(self):
        """
        Postprocess
        """
        js_output = gr.JSON()
        # Compare with `==`: in `assert x, y` the `y` is only the failure
        # message, so the previous form checked nothing but truthiness.
        assert js_output.postprocess('{"a":1, "b": 2}') == {"a": 1, "b": 2}
        assert js_output.get_config() == {
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "value": None,
            "show_label": True,
            "label": None,
            "name": "json",
            "proxy_url": None,
            "_selectable": False,
        }

    def test_chatbot_selectable_in_config(self):
        """A Chatbot with a `.like` listener is reported as likeable; one without is not.

        NOTE(review): this exercises `gr.Chatbot`, not `gr.JSON`.
        """
        with gr.Blocks() as demo:
            cb = gr.Chatbot(label="Chatbot")
            cb.like(lambda: print("foo"))
            gr.Chatbot(label="Chatbot2")

        assertion_count = 0
        for component in demo.config["components"]:
            if component["props"]["label"] == "Chatbot":
                assertion_count += 1
                assert component["props"]["likeable"]
            elif component["props"]["label"] == "Chatbot2":
                assertion_count += 1
                assert not component["props"]["likeable"]

        # Both chatbots must have been found and checked.
        assert assertion_count == 2

    @pytest.mark.asyncio
    async def test_in_interface(self):
        """
        Interface, process
        """

        def get_avg_age_per_gender(data):
            return {
                "M": int(data[data["gender"] == "M"]["age"].mean()),
                "F": int(data[data["gender"] == "F"]["age"].mean()),
                "O": int(data[data["gender"] == "O"]["age"].mean()),
            }

        iface = gr.Interface(
            get_avg_age_per_gender,
            gr.Dataframe(headers=["gender", "age"]),
            "json",
        )
        y_data = [
            ["M", 30],
            ["F", 20],
            ["M", 40],
            ["O", 20],
            ["F", 30],
        ]
        assert (
            await iface.process_api(
                0, [{"data": y_data, "headers": ["gender", "age"]}], state={}
            )
        )["data"][0] == {
            "M": 35,
            "F": 25,
            "O": 20,
        }


class TestHTML:
    """Tests for `gr.HTML` config and Interface use."""

    def test_component_functions(self):
        """
        get_config
        """
        component = gr.components.HTML("#Welcome onboard", label="HTML Input")
        expected = {
            "value": "#Welcome onboard",
            "label": "HTML Input",
            "show_label": True,
            "visible": True,
            "elem_id": None,
            "elem_classes": [],
            "proxy_url": None,
            "name": "html",
            "_selectable": False,
        }
        assert component.get_config() == expected

    def test_in_interface(self):
        """
        Interface, process
        """

        def wrap_in_strong(text):
            return f"<strong>{text}</strong>"

        renderer = gr.Interface(wrap_in_strong, "text", "html")
        assert renderer("test") == "<strong>test</strong>"


class TestMarkdown:
    def test_component_functions(self):
        """
        The initial markdown source is stored verbatim as the config value.
        """
        source = "# Let's learn about $x$"
        component = gr.Markdown(source, label="Markdown")
        assert component.get_config()["value"] == source

    def test_in_interface(self):
        """
        Interface, process
        """
        echo = gr.Interface(lambda x: x, "text", "markdown")
        raw = " Here's an [image](https://gradio.app/images/gradio_logo.png)"
        # The markdown output is expected to come back without the leading space.
        assert echo(raw) == raw.strip()
|
|
|
|
|
|
class TestModel3D:
    def test_component_functions(self):
        """
        get_config, and postprocess accepting both str and Path inputs.
        """
        component = gr.components.Model3D(None, label="Model")
        expected_config = {
            "value": None,
            "clear_color": [0, 0, 0, 0],
            "label": "Model",
            "show_label": True,
            "container": True,
            "scale": None,
            "min_width": 160,
            "visible": True,
            "elem_id": None,
            "elem_classes": [],
            "proxy_url": None,
            "interactive": None,
            "name": "model3d",
            "camera_position": (None, None, None),
            "height": None,
            "zoom_speed": 1,
            "pan_speed": 1,
            "_selectable": False,
        }
        assert component.get_config() == expected_config

        model_path = "test/test_files/Box.gltf"
        from_str = component.postprocess(model_path)
        from_path = component.postprocess(Path(model_path))
        assert from_str
        assert from_path
        # A str path and the equivalent Path must resolve to the same file name.
        assert Path(from_str.path).name == Path(from_path.path).name

    def test_in_interface(self):
        """
        Interface, process
        """
        echo = gr.Interface(lambda x: x, "model3d", "model3d")
        result = echo("test/test_files/Box.gltf")
        assert result.endswith(".gltf")
|
|
|
|
|
|
class TestColorPicker:
    def test_component_functions(self):
        """
        Preprocess and postprocess pass hex colour strings through unchanged;
        postprocess maps None to None; get_config reports the defaults.
        """
        picker = gr.ColorPicker()
        assert picker.preprocess("#000000") == "#000000"
        for color in ("#000000", "#FFFFFF"):
            assert picker.postprocess(color) == color
        assert picker.postprocess(None) is None

        expected_config = {
            "value": None,
            "show_label": True,
            "label": None,
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "name": "colorpicker",
            "info": None,
            "_selectable": False,
        }
        assert picker.get_config() == expected_config

    def test_in_interface_as_input(self):
        """
        Interface, process (colorpicker given by shortcut string on both sides)
        """
        echo = gr.Interface(lambda x: x, "colorpicker", "colorpicker")
        assert echo("#000000") == "#000000"

    def test_in_interface_as_output(self):
        """
        Interface, process (colorpicker output given as a component instance)
        """
        echo = gr.Interface(lambda x: x, "colorpicker", gr.ColorPicker())
        assert echo("#000000") == "#000000"

    def test_static(self):
        """
        An initial value passed to the constructor appears in the config.
        """
        picker = gr.ColorPicker("#000000")
        assert picker.get_config().get("value") == "#000000"
|
|
|
|
|
|
class TestGallery:
    def test_postprocess(self):
        """
        A remote URL used as a gallery value is kept as both path and url.
        """
        url = "https://huggingface.co/Norod78/SDXL-VintageMagStyle-Lora/resolve/main/Examples/00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg"
        gallery = gr.Gallery([url])
        assert gallery.get_config()["value"] == [
            {
                "image": {
                    "path": url,
                    "orig_name": "00015-20230906102032-7778-Wonderwoman VintageMagStyle _lora_SDXL-VintageMagStyle-Lora_1_, Very detailed, clean, high quality, sharp image.jpg",
                    "mime_type": None,
                    "size": None,
                    "url": url,
                },
                "caption": None,
            }
        ]

    def test_gallery(self):
        """
        postprocess accepts str or Path entries, with or without captions.
        """
        gallery = gr.Gallery()

        postprocessed_gallery = gallery.postprocess(
            [
                (str(Path("test/test_files/foo.png")), "foo_caption"),
                (Path("test/test_files/bar.png"), "bar_caption"),
                str(Path("test/test_files/baz.png")),
                Path("test/test_files/qux.png"),
            ]
        ).model_dump()

        def expected_entry(name, caption):
            # Using str(Path(...)) to ensure that the test passes on all platforms
            return {
                "image": {
                    "path": str(Path("test") / "test_files" / name),
                    "orig_name": name,
                    "mime_type": None,
                    "size": None,
                    "url": None,
                },
                "caption": caption,
            }

        assert postprocessed_gallery == [
            expected_entry("foo.png", "foo_caption"),
            expected_entry("bar.png", "bar_caption"),
            expected_entry("baz.png", None),
            expected_entry("qux.png", None),
        ]

    def test_gallery_preprocess(self):
        """
        preprocess returns (image, caption) pairs, with the image converted
        according to the gallery's `type`.
        """
        from gradio.components.gallery import GalleryData, GalleryImage

        bus_path = "test/test_files/bus.png"
        data = GalleryData(root=[GalleryImage(image=FileData(path=bus_path))])

        # Default type: the image element is the file path.
        assert gr.Gallery().preprocess(data)[0][0] == bus_path

        # Reference images are opened in a context manager (as Pillow
        # recommends) so the file is closed even if it is never fully loaded.
        with PIL.Image.open(bus_path) as reference:
            expected_array = np.array(reference)
        assert (gr.Gallery(type="numpy").preprocess(data)[0][0] == expected_array).all()

        with PIL.Image.open(bus_path) as reference:
            assert gr.Gallery(type="pil").preprocess(data)[0][0] == reference

        captioned = GalleryData(
            root=[GalleryImage(image=FileData(path=bus_path), caption="bus")]
        )
        result = gr.Gallery().preprocess(captioned)
        assert result[0] == ("test/test_files/bus.png", "bus")
|
|
|
|
|
|
class TestState:
    def test_as_component(self):
        """
        State forwards any payload unchanged and is flagged as stateful.
        """
        state = gr.State(value=5)
        for payload in (10, "abc"):
            assert state.preprocess(payload) == payload
        assert state.stateful

    def test_initial_value_deepcopy(self):
        # The initial value is deep-copied; modules cannot be, so this must raise.
        with pytest.raises(TypeError):
            gr.State(value=gr)

    @pytest.mark.asyncio
    async def test_in_interface(self):
        def concat(x, y=" def"):
            joined = x + y
            return joined, joined

        io = gr.Interface(concat, ["text", "state"], ["text", "state"])
        # First call: no state supplied, so the default suffix is used.
        first = await io.call_function(0, ["abc"])
        assert first["prediction"][0] == "abc def"
        # Second call: feed the previous output back in as the state.
        second = await io.call_function(0, ["abc", first["prediction"][0]])
        assert second["prediction"][0] == "abcabc def"

    @pytest.mark.asyncio
    async def test_in_blocks(self):
        with gr.Blocks() as demo:
            score = gr.State()
            gr.Button().click(lambda x: x + 1, score, score)

        previous = 0
        for expected in (1, 2):
            outcome = await demo.call_function(0, [previous])
            assert outcome["prediction"] == expected
            previous = outcome["prediction"]
|
|
|
|
|
|
def test_dataframe_process_example_converts_dataframes():
    """DataFrame and ndarray examples are both rendered as row-wise nested lists."""
    df_comp = gr.Dataframe()

    frame = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})
    expected_rows = [[1, 5], [2, 6], [3, 7], [4, 8]]
    assert df_comp.process_example(frame) == expected_rows

    matrix = np.array([[1, 2], [3, 4.0]])
    assert df_comp.process_example(matrix) == [[1.0, 2.0], [3.0, 4.0]]
|
|
|
|
|
|
@pytest.mark.parametrize("component", [gr.Model3D, gr.File, gr.Audio])
def test_process_example_returns_file_basename(component):
    """File-backed components show only the basename of an example path, and "" for None."""
    instance = component()
    example_path = "/home/freddy/sources/example.ext"
    assert instance.process_example(example_path) == "example.ext"
    assert instance.process_example(None) == ""
|
|
|
|
|
|
@patch(
    "gradio.components.Component.process_example",
    spec=gr.components.Component.process_example,
)
@patch("gradio.components.Image.process_example", spec=gr.Image.process_example)
@patch("gradio.components.File.process_example", spec=gr.File.process_example)
@patch("gradio.components.Dataframe.process_example", spec=gr.DataFrame.process_example)
@patch("gradio.components.Model3D.process_example", spec=gr.Model3D.process_example)
def test_dataset_calls_process_example(*mocks):
    """Building a Dataset calls every patched process_example at least once."""
    components = [gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D(), gr.Textbox()]
    sample = [
        pd.DataFrame({"a": np.array([1, 2, 3])}),
        "foo.png",
        "bar.jpeg",
        "duck.obj",
        "hello",
    ]
    gr.Dataset(components=components, samples=[sample])

    never_called = [mock for mock in mocks if not mock.called]
    assert not never_called
|
|
|
|
|
|
# Module-level datasets for the plot tests: three vega_datasets samples
# plus a small hand-written frame with a categorical and a numeric column.
cars = vega_datasets.data.cars()
stocks = vega_datasets.data.stocks()
barley = vega_datasets.data.barley()
simple = pd.DataFrame(
    {
        "a": list("ABCDEFGHI"),
        "b": [28, 55, 43, 91, 81, 53, 19, 87, 52],
    }
)
|
|
|
|
|
|
class TestScatterPlot:
    @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
    def test_get_config(self):
        # bokeh is replaced by a stub so the reported "bokeh_version" is fixed.
        assert gr.ScatterPlot().get_config() == {
            "caption": None,
            "elem_id": None,
            "elem_classes": [],
            "interactive": None,
            "label": None,
            "name": "plot",
            "bokeh_version": "3.0.3",
            "show_actions_button": False,
            "proxy_url": None,
            "show_label": True,
            "container": True,
            "min_width": 160,
            "scale": None,
            "value": None,
            "visible": True,
            "x": None,
            "y": None,
            "color": None,
            "size": None,
            "shape": None,
            "title": None,
            "tooltip": None,
            "x_title": None,
            "y_title": None,
            "color_legend_title": None,
            "size_legend_title": None,
            "shape_legend_title": None,
            "color_legend_position": None,
            "size_legend_position": None,
            "shape_legend_position": None,
            "height": None,
            "width": None,
            "x_lim": None,
            "y_lim": None,
            "x_label_angle": None,
            "y_label_angle": None,
            "_selectable": False,
        }

    def test_no_color(self):
        plot = gr.ScatterPlot(
            x="Horsepower",
            y="Miles_per_Gallon",
            tooltip="Name",
            title="Car Data",
            x_title="Horse",
        )
        output = plot.postprocess(cars).model_dump()
        assert sorted(output.keys()) == ["chart", "plot", "type"]
        config = json.loads(output["plot"])
        assert config["encoding"]["x"]["field"] == "Horsepower"
        assert config["encoding"]["x"]["title"] == "Horse"
        assert config["encoding"]["y"]["field"] == "Miles_per_Gallon"
        assert config["title"] == "Car Data"
        # No size keys are emitted unless height/width are passed explicitly.
        assert "height" not in config
        assert "width" not in config

    def test_no_interactive(self):
        plot = gr.ScatterPlot(
            x="Horsepower", y="Miles_per_Gallon", tooltip="Name", interactive=False
        )
        output = plot.postprocess(cars).model_dump()
        assert sorted(output.keys()) == ["chart", "plot", "type"]
        config = json.loads(output["plot"])
        assert "selection" not in config

    def test_height_width(self):
        plot = gr.ScatterPlot(
            x="Horsepower", y="Miles_per_Gallon", height=100, width=200
        )
        output = plot.postprocess(cars).model_dump()
        assert sorted(output.keys()) == ["chart", "plot", "type"]
        config = json.loads(output["plot"])
        assert config["height"] == 100
        assert config["width"] == 200

    def test_xlim_ylim(self):
        plot = gr.ScatterPlot(
            x="Horsepower", y="Miles_per_Gallon", x_lim=[200, 400], y_lim=[300, 500]
        )
        output = plot.postprocess(cars).model_dump()
        config = json.loads(output["plot"])
        assert config["encoding"]["x"]["scale"] == {"domain": [200, 400]}
        assert config["encoding"]["y"]["scale"] == {"domain": [300, 500]}

    def test_color_encoding(self):
        plot = gr.ScatterPlot(
            x="Horsepower",
            y="Miles_per_Gallon",
            tooltip="Name",
            title="Car Data",
            color="Origin",
        )
        output = plot.postprocess(cars).model_dump()
        config = json.loads(output["plot"])
        assert config["encoding"]["color"]["field"] == "Origin"
        assert config["encoding"]["color"]["scale"] == {
            "domain": ["USA", "Europe", "Japan"],
            "range": [0, 1, 2],
        }
        assert config["encoding"]["color"]["type"] == "nominal"

    def test_two_encodings(self):
        plot = gr.ScatterPlot(
            show_label=False,
            title="Two encodings",
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Acceleration",
            shape="Origin",
        )
        output = plot.postprocess(cars).model_dump()
        config = json.loads(output["plot"])
        # A numeric color column yields a quantitative min/max scale.
        assert config["encoding"]["color"]["field"] == "Acceleration"
        assert config["encoding"]["color"]["scale"] == {
            "domain": [cars.Acceleration.min(), cars.Acceleration.max()],
            "range": [0, 1],
        }
        assert config["encoding"]["color"]["type"] == "quantitative"

        assert config["encoding"]["shape"]["field"] == "Origin"
        assert config["encoding"]["shape"]["type"] == "nominal"

    def test_legend_position(self):
        plot = gr.ScatterPlot(
            show_label=False,
            title="Two encodings",
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Acceleration",
            color_legend_position="none",
            color_legend_title="Foo",
            shape="Origin",
            shape_legend_position="none",
            shape_legend_title="Bar",
            size="Acceleration",
            size_legend_title="Accel",
            size_legend_position="none",
        )
        output = plot.postprocess(cars).model_dump()
        config = json.loads(output["plot"])
        # A legend position of "none" removes the legend entirely.
        assert config["encoding"]["color"]["legend"] is None
        assert config["encoding"]["shape"]["legend"] is None
        assert config["encoding"]["size"]["legend"] is None

    def test_scatterplot_accepts_fn_as_value(self):
        plot = gr.ScatterPlot(
            value=lambda: cars.sample(frac=0.1, replace=False),
            x="Horsepower",
            y="Miles_per_Gallon",
            color="Origin",
        )
        assert isinstance(plot.value, dict)
        assert isinstance(plot.value["plot"], str)
|
|
|
|
|
|
class TestLinePlot:
    @staticmethod
    def _render(plot, data):
        """Postprocess *data* with *plot*; return the dumped output and its decoded plot spec."""
        output = plot.postprocess(data).model_dump()
        return output, json.loads(output["plot"])

    @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
    def test_get_config(self):
        # bokeh is stubbed so that "bokeh_version" is deterministic.
        expected = {
            "caption": None,
            "elem_id": None,
            "elem_classes": [],
            "interactive": None,
            "label": None,
            "name": "plot",
            "bokeh_version": "3.0.3",
            "show_actions_button": False,
            "proxy_url": None,
            "show_label": True,
            "container": True,
            "min_width": 160,
            "scale": None,
            "value": None,
            "visible": True,
            "x": None,
            "y": None,
            "color": None,
            "stroke_dash": None,
            "overlay_point": None,
            "title": None,
            "tooltip": None,
            "x_title": None,
            "y_title": None,
            "color_legend_title": None,
            "stroke_dash_legend_title": None,
            "color_legend_position": None,
            "stroke_dash_legend_position": None,
            "height": None,
            "width": None,
            "x_lim": None,
            "y_lim": None,
            "x_label_angle": None,
            "y_label_angle": None,
            "_selectable": False,
        }
        assert gr.LinePlot().get_config() == expected

    def test_no_color(self):
        plot = gr.LinePlot(
            x="date",
            y="price",
            tooltip=["symbol", "price"],
            title="Stock Performance",
            x_title="Trading Day",
        )
        output, config = self._render(plot, stocks)
        assert sorted(output) == ["chart", "plot", "type"]

        # Every layer (line and optional point overlay) shares the x/y encoding.
        for layer in config["layer"]:
            assert layer["mark"]["type"] in ("line", "point")
            encoding = layer["encoding"]
            assert encoding["x"]["field"] == "date"
            assert encoding["x"]["title"] == "Trading Day"
            assert encoding["y"]["field"] == "price"

        assert config["title"] == "Stock Performance"
        assert "height" not in config
        assert "width" not in config

    def test_height_width(self):
        output, config = self._render(
            gr.LinePlot(x="date", y="price", height=100, width=200), stocks
        )
        assert sorted(output) == ["chart", "plot", "type"]
        assert (config["height"], config["width"]) == (100, 200)

    def test_xlim_ylim(self):
        _, config = self._render(
            gr.LinePlot(x="date", y="price", x_lim=[200, 400], y_lim=[300, 500]),
            stocks,
        )
        for layer in config["layer"]:
            encoding = layer["encoding"]
            assert encoding["x"]["scale"] == {"domain": [200, 400]}
            assert encoding["y"]["scale"] == {"domain": [300, 500]}

    def test_color_encoding(self):
        plot = gr.LinePlot(
            x="date", y="price", tooltip="symbol", color="symbol", overlay_point=True
        )
        _, config = self._render(plot, stocks)
        for layer in config["layer"]:
            color = layer["encoding"]["color"]
            assert color["field"] == "symbol"
            assert color["scale"] == {
                "domain": ["MSFT", "AMZN", "IBM", "GOOG", "AAPL"],
                "range": [0, 1, 2, 3, 4],
            }
            assert color["type"] == "nominal"
            if layer["mark"]["type"] == "point":
                assert layer["encoding"]["opacity"] == {}

    def test_lineplot_accepts_fn_as_value(self):
        plot = gr.LinePlot(
            value=lambda: stocks.sample(frac=0.1, replace=False),
            x="date",
            y="price",
            color="symbol",
        )
        value = plot.value
        assert isinstance(value, dict)
        assert isinstance(value["plot"], str)
|
|
|
|
|
|
class TestBarPlot:
    @staticmethod
    def _render(plot, data):
        """Postprocess *data* with *plot*; return the dumped output and its decoded plot spec."""
        output = plot.postprocess(data).model_dump()
        return output, json.loads(output["plot"])

    @patch.dict("sys.modules", {"bokeh": MagicMock(__version__="3.0.3")})
    def test_get_config(self):
        # bokeh is stubbed so that "bokeh_version" is deterministic.
        expected = {
            "caption": None,
            "elem_id": None,
            "elem_classes": [],
            "interactive": None,
            "label": None,
            "name": "plot",
            "bokeh_version": "3.0.3",
            "show_actions_button": False,
            "proxy_url": None,
            "show_label": True,
            "container": True,
            "min_width": 160,
            "scale": None,
            "value": None,
            "visible": True,
            "x": None,
            "y": None,
            "color": None,
            "vertical": True,
            "group": None,
            "title": None,
            "tooltip": None,
            "x_title": None,
            "y_title": None,
            "color_legend_title": None,
            "group_title": None,
            "color_legend_position": None,
            "height": None,
            "width": None,
            "y_lim": None,
            "x_label_angle": None,
            "y_label_angle": None,
            "sort": None,
            "_selectable": False,
        }
        assert gr.BarPlot().get_config() == expected

    def test_no_color(self):
        plot = gr.BarPlot(
            x="a",
            y="b",
            tooltip=["a", "b"],
            title="Made Up Bar Plot",
            x_title="Variable A",
            sort="x",
        )
        output, config = self._render(plot, simple)
        assert sorted(output) == ["chart", "plot", "type"]
        assert output["chart"] == "bar"

        x_enc = config["encoding"]["x"]
        assert x_enc["sort"] == "x"
        assert x_enc["field"] == "a"
        assert x_enc["title"] == "Variable A"
        # Without y_title the axis title falls back to the column name.
        y_enc = config["encoding"]["y"]
        assert y_enc["field"] == "b"
        assert y_enc["title"] == "b"

        assert config["title"] == "Made Up Bar Plot"
        assert "height" not in config
        assert "width" not in config

    def test_height_width(self):
        output, config = self._render(
            gr.BarPlot(x="a", y="b", height=100, width=200), simple
        )
        assert sorted(output) == ["chart", "plot", "type"]
        assert (config["height"], config["width"]) == (100, 200)

    def test_ylim(self):
        _, config = self._render(gr.BarPlot(x="a", y="b", y_lim=[15, 100]), simple)
        assert config["encoding"]["y"]["scale"] == {"domain": [15, 100]}

    def test_horizontal(self):
        component = gr.BarPlot(
            simple,
            x="a",
            y="b",
            x_title="Variable A",
            y_title="Variable B",
            title="Simple Bar Plot with made up data",
            tooltip=["a", "b"],
            vertical=False,
            y_lim=[20, 100],
        )
        value = component.get_config()["value"]
        assert value["chart"] == "bar"
        encoding = json.loads(value["plot"])["encoding"]
        # With vertical=False the axes swap: the y column is drawn along x.
        assert encoding["x"]["field"] == "b"
        assert encoding["x"]["scale"] == {"domain": [20, 100]}
        assert encoding["x"]["title"] == "Variable B"

        assert encoding["y"]["field"] == "a"
        assert encoding["y"]["title"] == "Variable A"

    def test_barplot_accepts_fn_as_value(self):
        plot = gr.BarPlot(
            value=lambda: barley.sample(frac=0.1, replace=False),
            x="year",
            y="yield",
        )
        value = plot.value
        assert isinstance(value, dict)
        assert isinstance(value["plot"], str)
|
|
|
|
|
|
class TestCode:
    def test_component_functions(self):
        """
        Preprocess, postprocess, get_config
        """
        code = gr.Code()

        # preprocess passes source text through unchanged.
        assert code.preprocess("# hello friends") == "# hello friends"
        assert code.preprocess("def fn(a):\n return a") == "def fn(a):\n return a"

        # postprocess is expected to drop the surrounding blank lines while
        # keeping the snippet's internal indentation. Explicit escapes are used
        # instead of triple-quoted literals so the expected text does not depend
        # on how this test file itself happens to be indented.
        snippet = "\ndef fn(a):\n    return a\n"
        assert code.postprocess(snippet) == "def fn(a):\n    return a"

        test_file_dir = Path(Path(__file__).parent, "test_files")
        path = str(Path(test_file_dir, "test_label_json.json"))
        with open(path) as f:
            file_contents = f.read()
        # A bare string is treated as code, while a 1-tuple is treated as a
        # file path whose contents are returned.
        assert code.postprocess(path) == path
        assert code.postprocess((path,)) == file_contents

        assert code.get_config() == {
            "value": None,
            "language": None,
            "lines": 5,
            "name": "code",
            "show_label": True,
            "label": None,
            "container": True,
            "min_width": 160,
            "scale": None,
            "elem_id": None,
            "elem_classes": [],
            "visible": True,
            "interactive": None,
            "proxy_url": None,
            "_selectable": False,
        }
|
|
|
|
|
|
class TestFileExplorer:
    @staticmethod
    def _assert_default_config(explorer, file_count):
        """Check the config fields every FileExplorer reports by default."""
        config = explorer.get_config()
        assert config["glob"] == "**/*.*"
        assert config["value"] is None
        assert config["file_count"] == file_count
        assert config["server_fns"] == ["ls"]

    def test_component_functions(self):
        """
        Preprocess, get_config
        """
        single = gr.FileExplorer(file_count="single")
        self._assert_default_config(single, "single")

        # file_count="single": one path as a str, or None when nothing is chosen.
        picked = single.preprocess(
            FileExplorerData(root=[["test/test_files/bus.png"]])
        )
        assert isinstance(picked, str)
        assert Path(picked).name == "bus.png"
        assert single.preprocess(FileExplorerData(root=[])) is None

        multiple = gr.FileExplorer(file_count="multiple")
        self._assert_default_config(multiple, "multiple")

        # file_count="multiple": a list of paths, empty when nothing is chosen.
        picked = multiple.preprocess(
            FileExplorerData(root=[["test/test_files/bus.png"]])
        )
        assert isinstance(picked, list)
        assert Path(picked[0]).name == "bus.png"
        assert multiple.preprocess(FileExplorerData(root=[])) == []

    def test_file_explorer_dir_only_glob(self, tmpdir):
        root = Path(tmpdir)
        for name in ("foo", "bar", "baz"):
            tmpdir.mkdir(name)
        (root / "baz" / "qux").mkdir()
        (root / "foo" / "abc").mkdir()
        (root / "foo" / "abc" / "def").mkdir()
        (root / "foo" / "abc" / "def" / "file.txt").touch()

        tree = gr.FileExplorer(glob="**/", root=root).ls()

        def normalize(nodes):
            # Order siblings by path, recursively, so the comparison ignores
            # listing order.
            ordered = sorted(nodes, key=lambda node: node["path"])
            for node in ordered:
                if node["children"]:
                    node["children"] = normalize(node["children"])
            return ordered

        def folder(path, *subfolders):
            # The expected tree lists an empty-path file entry in every folder,
            # next to its sub-folders.
            return {
                "path": path,
                "type": "folder",
                "children": [
                    {"path": "", "type": "file", "children": None},
                    *subfolders,
                ],
            }

        expected = [
            folder("bar"),
            folder("baz", folder("qux")),
            folder("foo", folder("abc", folder("def"))),
        ]
        assert normalize(tree) == normalize(expected)
|
|
|
|
|
|
def test_component_class_ids():
    """component_class_id is stable per class, shared by Mic/Microphone, and distinct otherwise."""
    factories = {
        "button": gr.Button,
        "textbox": gr.Textbox,
        "json": gr.JSON,
        "mic": gr.Mic,
        "microphone": gr.Microphone,
        "audio": gr.Audio,
    }
    ids = {name: make().component_class_id for name, make in factories.items()}

    # A second instance of each class must report the same id.
    for name, make in factories.items():
        assert ids[name] == make().component_class_id
    assert ids["mic"] == ids["microphone"]

    # Make sure that the ids are unique
    distinct = {ids[name] for name in ("button", "textbox", "json", "microphone", "audio")}
    assert len(distinct) == 5
|
|
|
|
|
|
def test_constructor_args():
    """constructor_args holds exactly the keyword arguments that were passed."""
    textbox = gr.Textbox(max_lines=314)
    assert textbox.constructor_args == {"max_lines": 314}

    login = gr.LoginButton(visible=False, value="Log in please")
    expected = {"visible": False, "value": "Log in please"}
    assert login.constructor_args == expected
|
|
|
|
|
|
def test_template_component_configs(io_components):
    """A template component's config must contain every key of its direct parent's config."""
    for component in io_components:
        if not getattr(component, "is_template", False):
            continue
        parent_cls = inspect.getmro(component)[1]
        template_keys = set(component().get_config())
        parent_keys = set(parent_cls().get_config())
        assert parent_keys <= template_keys
|