tabbed-interface-rewritten (#958)

This commit is contained in:
Ömer Faruk Özdemir 2022-04-08 12:13:56 +03:00 committed by GitHub
parent 3c876c49ef
commit dc6175a21d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 478 additions and 2 deletions

27
demo/sst_or_tts/run.py Normal file
View File

@ -0,0 +1,27 @@
import gradio as gr
title = "GPT-J-6B"
tts_examples = [
"I love learning machine learning",
"How do you do?",
]
tts_demo = gr.Interface.load(
"huggingface/facebook/fastspeech2-en-ljspeech",
title=None,
examples=tts_examples,
description="Give me something to say!"
)
stt_demo = gr.Interface.load(
"huggingface/facebook/wav2vec2-base-960h",
title=None,
inputs="mic",
description="Let me try to guess what you're saying!"
)
demo = gr.TabbedInterface([tts_demo, stt_demo], ["Text-to-speech", "Speech-to-text"])
if __name__ == "__main__":
demo.launch()

View File

@ -36,7 +36,7 @@ from gradio.flagging import (
HuggingFaceDatasetSaver,
SimpleCSVLogger,
)
from gradio.interface import Interface, close_all, reset_all
from gradio.interface import Interface, TabbedInterface, close_all, reset_all
from gradio.mix import Parallel, Series
from gradio.routes import get_state, set_state

View File

@ -590,6 +590,25 @@ class Interface(Blocks):
repr += "\n|-{}".format(str(component))
return repr
def render_basic_interface(self):
Interface(
fn=self.predict,
inputs=self.input_components,
outputs=self.output_components,
examples=self.examples,
examples_per_page=self.examples_per_page,
live=self.live,
layout=self.layout,
interpretation=self.interpretation,
num_shap=self.num_shap,
title=self.title,
description=self.description,
article=self.article,
allow_flagging=self.allow_flagging,
flagging_options=self.flagging_options,
flagging_dir=self.flagging_dir,
)
def run_prediction(
self,
processed_input: List[Any],
@ -769,6 +788,20 @@ class Interface(Blocks):
utils.integration_analytics(data)
class TabbedInterface(Blocks):
def __init__(
self, interface_list: List[Interface], tab_names: Optional[List[str]] = None
):
if tab_names is None:
tab_names = ["Tab {}".format(i) for i in range(len(interface_list))]
super().__init__()
with self:
with Tabs():
for (interface, tab_name) in zip(interface_list, tab_names):
with TabItem(label=tab_name):
interface.render_basic_interface()
def close_all(verbose: bool = True) -> None:
for io in Interface.get_instances():
io.close(verbose)

View File

@ -159,3 +159,327 @@ XRAY_CONFIG = {
},
],
}
XRAY_CONFIG_DIFF_IDS = {
"mode": "blocks",
"components": [
{
"id": 1,
"type": "markdown",
"props": {
"default_value": "<h1>Detect Disease From Scan</h1>\n<p>With this model you can lorem ipsum</p>\n<ul>\n<li>ipsum 1</li>\n<li>ipsum 2</li>\n</ul>\n",
"name": "markdown",
"label": None,
"css": {},
},
},
{
"id": 22,
"type": "checkboxgroup",
"props": {
"choices": ["Covid", "Malaria", "Lung Cancer"],
"default_value": [],
"name": "checkboxgroup",
"label": "Disease to Scan For",
"css": {},
},
},
{"id": 3, "type": "tabs", "props": {"css": {}, "default_value": True}},
{
"id": 444,
"type": "tabitem",
"props": {"label": "X-ray", "css": {}, "default_value": True},
},
{
"id": 5,
"type": "row",
"props": {"type": "row", "css": {}, "default_value": True},
},
{
"id": 6,
"type": "image",
"props": {
"image_mode": "RGB",
"shape": None,
"source": "upload",
"tool": "editor",
"default_value": None,
"name": "image",
"label": None,
"css": {},
},
},
{
"id": 7,
"type": "json",
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
},
{
"id": 8888,
"type": "button",
"props": {
"default_value": "Run",
"name": "button",
"label": None,
"css": {"background-color": "red", "--hover-color": "orange"},
},
},
{
"id": 9,
"type": "tabitem",
"props": {"label": "CT Scan", "css": {}, "default_value": True},
},
{
"id": 10,
"type": "row",
"props": {"type": "row", "css": {}, "default_value": True},
},
{
"id": 11,
"type": "image",
"props": {
"image_mode": "RGB",
"shape": None,
"source": "upload",
"tool": "editor",
"default_value": None,
"name": "image",
"label": None,
"css": {},
},
},
{
"id": 12,
"type": "json",
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
},
{
"id": 13,
"type": "button",
"props": {
"default_value": "Run",
"name": "button",
"label": None,
"css": {},
},
},
{
"id": 141,
"type": "textbox",
"props": {
"lines": 1,
"placeholder": None,
"default_value": "",
"name": "textbox",
"label": None,
"css": {},
},
},
],
"theme": "default",
"layout": {
"id": 0,
"children": [
{"id": 1},
{"id": 22},
{
"id": 3,
"children": [
{
"id": 444,
"children": [
{"id": 5, "children": [{"id": 6}, {"id": 7}]},
{"id": 8888},
],
},
{
"id": 9,
"children": [
{"id": 10, "children": [{"id": 11}, {"id": 12}]},
{"id": 13},
],
},
],
},
{"id": 141},
],
},
"dependencies": [
{
"targets": [8888],
"trigger": "click",
"inputs": [22, 6],
"outputs": [7],
"queue": False,
},
{
"targets": [13],
"trigger": "click",
"inputs": [22, 11],
"outputs": [12],
"queue": False,
},
],
}
XRAY_CONFIG_WITH_MISTAKE = {
"mode": "blocks",
"components": [
{
"id": 1,
"type": "markdown",
"props": {
"default_value": "<h1>Detect Disease From Scan</h1>\n<p>With this model you can lorem ipsum</p>\n<ul>\n<li>ipsum 1</li>\n<li>ipsum 2</li>\n</ul>\n",
"name": "markdown",
"label": None,
"css": {},
},
},
{
"id": 2,
"type": "checkboxgroup",
"props": {
"choices": ["Covid", "Malaria", "Lung Cancer"],
"default_value": [],
"name": "checkboxgroup",
"label": "Disease to Scan For",
"css": {},
},
},
{"id": 3, "type": "tabs", "props": {"css": {}, "default_value": True}},
{
"id": 4,
"type": "tabitem",
"props": {"label": "X-ray", "css": {}, "default_value": True},
},
{
"id": 5,
"type": "row",
"props": {"type": "row", "css": {}, "default_value": True},
},
{
"id": 6,
"type": "image",
"props": {
"image_mode": "RGB",
"shape": None,
"source": "upload",
"tool": "editor",
"default_value": None,
"name": "image",
"label": None,
"css": {},
},
},
{
"id": 7,
"type": "json",
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
},
{
"id": 8,
"type": "button",
"props": {
"default_value": "Run",
"name": "button",
"label": None,
"css": {"background-color": "red", "--hover-color": "orange"},
},
},
{
"id": 9,
"type": "tabitem",
"props": {"label": "CT Scan", "css": {}, "default_value": True},
},
{
"id": 10,
"type": "row",
"props": {"type": "row", "css": {}, "default_value": True},
},
{
"id": 11,
"type": "image",
"props": {
"image_mode": "RGB",
"shape": None,
"source": "upload",
"tool": "editor",
"default_value": None,
"name": "image",
"label": None,
"css": {},
},
},
{
"id": 12,
"type": "json",
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
},
{
"id": 13,
"type": "button",
"props": {
"default_value": "Run",
"name": "button",
"label": None,
"css": {},
},
},
{
"id": 14,
"type": "textbox",
"props": {
"lines": 1,
"placeholder": None,
"default_value": "",
"name": "textbox",
"label": None,
"css": {},
},
},
],
"theme": "default",
"layout": {
"id": 0,
"children": [
{"id": 1},
{"id": 2},
{
"id": 3,
"children": [
{
"id": 4,
"children": [
{"id": 5, "children": [{"id": 6}, {"id": 7}]},
{"id": 8},
],
},
{
"id": 9,
"children": [
{"id": 10, "children": [{"id": 12}, {"id": 11}]},
{"id": 13},
],
},
],
},
{"id": 14},
],
},
"dependencies": [
{
"targets": [8],
"trigger": "click",
"inputs": [2, 6],
"outputs": [7],
"queue": False,
},
{
"targets": [13],
"trigger": "click",
"inputs": [2, 11],
"outputs": [12],
"queue": False,
},
],
}

View File

@ -9,6 +9,7 @@ import json.decoder
import os
import random
import warnings
from copy import deepcopy
from distutils.version import StrictVersion
from typing import TYPE_CHECKING, Any, Callable, Dict
@ -286,3 +287,45 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
v.default if v.default is not inspect.Parameter.empty else None
for v in signature.parameters.values()
]
def assert_configs_are_equivalent_besides_ids(config1, config2):
"""Allows you to test if two different Blocks configs produce the same demo."""
assert config1["mode"] == config2["mode"], "Modes are different"
assert config1["theme"] == config2["theme"], "Themes are different"
assert len(config1["components"]) == len(
config2["components"]
), "# of components are different"
mapping = {}
for c1, c2 in zip(config1["components"], config2["components"]):
c1, c2 = deepcopy(c1), deepcopy(c2)
mapping[c1["id"]] = c2["id"]
c1.pop("id")
c2.pop("id")
assert c1 == c2, "{} does not match {}".format(c1, c2)
def same_children_recursive(children1, chidren2, mapping):
for child1, child2 in zip(children1, chidren2):
assert mapping[child1["id"]] == child2["id"], "{} does not match {}".format(
child1, child2
)
if "children" in child1 or "children" in child2:
same_children_recursive(child1["children"], child2["children"], mapping)
children1 = config1["layout"]["children"]
children2 = config2["layout"]["children"]
same_children_recursive(children1, children2, mapping)
for d1, d2 in zip(config1["dependencies"], config2["dependencies"]):
for t1, t2 in zip(d1["targets"], d2["targets"]):
assert mapping[t1] == t2, "{} does not match {}".format(d1, d2)
assert d1["trigger"] == d2["trigger"], "{} does not match {}".format(d1, d2)
for i1, i2 in zip(d1["inputs"], d2["inputs"]):
assert mapping[i1] == i2, "{} does not match {}".format(d1, d2)
for o1, o2 in zip(d1["outputs"], d2["outputs"]):
assert mapping[o1] == o2, "{} does not match {}".format(d1, d2)
assert d1["queue"] == d2["queue"], "{} does not match {}".format(d1, d2)
return True

View File

@ -8,7 +8,9 @@ import mlflow
import requests
import wandb
from gradio.interface import Interface, close_all, os
from gradio.blocks import Blocks, TabItem, Tabs
from gradio.interface import Interface, TabbedInterface, close_all, os
from gradio.utils import assert_configs_are_equivalent_besides_ids
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
@ -184,5 +186,28 @@ class TestInterface(unittest.TestCase):
mock_post.assert_called_once()
class TestTabbedInterface(unittest.TestCase):
def test_tabbed_interface_config_matches_manual_tab(self):
interface1 = Interface(lambda x: x, "textbox", "textbox")
interface2 = Interface(lambda x: x, "image", "image")
with Blocks() as demo:
with Tabs():
with TabItem(label="tab1"):
interface1.render_basic_interface()
with TabItem(label="tab2"):
interface2.render_basic_interface()
interface3 = Interface(lambda x: x, "textbox", "textbox")
interface4 = Interface(lambda x: x, "image", "image")
tabbed_interface = TabbedInterface([interface3, interface4], ["tab1", "tab2"])
self.assertTrue(
assert_configs_are_equivalent_besides_ids(
demo.get_config_file(), tabbed_interface.get_config_file()
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -7,7 +7,13 @@ import warnings
import pkg_resources
import requests
from gradio.test_data.blocks_configs import (
XRAY_CONFIG,
XRAY_CONFIG_DIFF_IDS,
XRAY_CONFIG_WITH_MISTAKE,
)
from gradio.utils import (
assert_configs_are_equivalent_besides_ids,
colab_check,
error_analytics,
get_local_ip_address,
@ -116,5 +122,23 @@ class TestIPAddress(unittest.TestCase):
self.assertEqual(ip, "No internet connection")
class TestAssertConfigsEquivalent(unittest.TestCase):
def test_same_configs(self):
self.assertTrue(
assert_configs_are_equivalent_besides_ids(XRAY_CONFIG, XRAY_CONFIG)
)
def test_equivalent_configs(self):
self.assertTrue(
assert_configs_are_equivalent_besides_ids(XRAY_CONFIG, XRAY_CONFIG_DIFF_IDS)
)
def test_different_configs(self):
with self.assertRaises(AssertionError):
assert_configs_are_equivalent_besides_ids(
XRAY_CONFIG_WITH_MISTAKE, XRAY_CONFIG
)
if __name__ == "__main__":
unittest.main()