mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-23 11:39:17 +08:00
tabbed-interface-rewritten (#958)
This commit is contained in:
parent
3c876c49ef
commit
dc6175a21d
27
demo/sst_or_tts/run.py
Normal file
27
demo/sst_or_tts/run.py
Normal file
@ -0,0 +1,27 @@
|
||||
import gradio as gr
|
||||
|
||||
title = "GPT-J-6B"
|
||||
|
||||
tts_examples = [
|
||||
"I love learning machine learning",
|
||||
"How do you do?",
|
||||
]
|
||||
|
||||
tts_demo = gr.Interface.load(
|
||||
"huggingface/facebook/fastspeech2-en-ljspeech",
|
||||
title=None,
|
||||
examples=tts_examples,
|
||||
description="Give me something to say!"
|
||||
)
|
||||
|
||||
stt_demo = gr.Interface.load(
|
||||
"huggingface/facebook/wav2vec2-base-960h",
|
||||
title=None,
|
||||
inputs="mic",
|
||||
description="Let me try to guess what you're saying!"
|
||||
)
|
||||
|
||||
demo = gr.TabbedInterface([tts_demo, stt_demo], ["Text-to-speech", "Speech-to-text"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -36,7 +36,7 @@ from gradio.flagging import (
|
||||
HuggingFaceDatasetSaver,
|
||||
SimpleCSVLogger,
|
||||
)
|
||||
from gradio.interface import Interface, close_all, reset_all
|
||||
from gradio.interface import Interface, TabbedInterface, close_all, reset_all
|
||||
from gradio.mix import Parallel, Series
|
||||
from gradio.routes import get_state, set_state
|
||||
|
||||
|
@ -590,6 +590,25 @@ class Interface(Blocks):
|
||||
repr += "\n|-{}".format(str(component))
|
||||
return repr
|
||||
|
||||
def render_basic_interface(self):
|
||||
Interface(
|
||||
fn=self.predict,
|
||||
inputs=self.input_components,
|
||||
outputs=self.output_components,
|
||||
examples=self.examples,
|
||||
examples_per_page=self.examples_per_page,
|
||||
live=self.live,
|
||||
layout=self.layout,
|
||||
interpretation=self.interpretation,
|
||||
num_shap=self.num_shap,
|
||||
title=self.title,
|
||||
description=self.description,
|
||||
article=self.article,
|
||||
allow_flagging=self.allow_flagging,
|
||||
flagging_options=self.flagging_options,
|
||||
flagging_dir=self.flagging_dir,
|
||||
)
|
||||
|
||||
def run_prediction(
|
||||
self,
|
||||
processed_input: List[Any],
|
||||
@ -769,6 +788,20 @@ class Interface(Blocks):
|
||||
utils.integration_analytics(data)
|
||||
|
||||
|
||||
class TabbedInterface(Blocks):
|
||||
def __init__(
|
||||
self, interface_list: List[Interface], tab_names: Optional[List[str]] = None
|
||||
):
|
||||
if tab_names is None:
|
||||
tab_names = ["Tab {}".format(i) for i in range(len(interface_list))]
|
||||
super().__init__()
|
||||
with self:
|
||||
with Tabs():
|
||||
for (interface, tab_name) in zip(interface_list, tab_names):
|
||||
with TabItem(label=tab_name):
|
||||
interface.render_basic_interface()
|
||||
|
||||
|
||||
def close_all(verbose: bool = True) -> None:
|
||||
for io in Interface.get_instances():
|
||||
io.close(verbose)
|
||||
|
@ -159,3 +159,327 @@ XRAY_CONFIG = {
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
XRAY_CONFIG_DIFF_IDS = {
|
||||
"mode": "blocks",
|
||||
"components": [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "markdown",
|
||||
"props": {
|
||||
"default_value": "<h1>Detect Disease From Scan</h1>\n<p>With this model you can lorem ipsum</p>\n<ul>\n<li>ipsum 1</li>\n<li>ipsum 2</li>\n</ul>\n",
|
||||
"name": "markdown",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 22,
|
||||
"type": "checkboxgroup",
|
||||
"props": {
|
||||
"choices": ["Covid", "Malaria", "Lung Cancer"],
|
||||
"default_value": [],
|
||||
"name": "checkboxgroup",
|
||||
"label": "Disease to Scan For",
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{"id": 3, "type": "tabs", "props": {"css": {}, "default_value": True}},
|
||||
{
|
||||
"id": 444,
|
||||
"type": "tabitem",
|
||||
"props": {"label": "X-ray", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"type": "row",
|
||||
"props": {"type": "row", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
"shape": None,
|
||||
"source": "upload",
|
||||
"tool": "editor",
|
||||
"default_value": None,
|
||||
"name": "image",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "json",
|
||||
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
|
||||
},
|
||||
{
|
||||
"id": 8888,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"default_value": "Run",
|
||||
"name": "button",
|
||||
"label": None,
|
||||
"css": {"background-color": "red", "--hover-color": "orange"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"type": "tabitem",
|
||||
"props": {"label": "CT Scan", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"type": "row",
|
||||
"props": {"type": "row", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
"shape": None,
|
||||
"source": "upload",
|
||||
"tool": "editor",
|
||||
"default_value": None,
|
||||
"name": "image",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"type": "json",
|
||||
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"default_value": "Run",
|
||||
"name": "button",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 141,
|
||||
"type": "textbox",
|
||||
"props": {
|
||||
"lines": 1,
|
||||
"placeholder": None,
|
||||
"default_value": "",
|
||||
"name": "textbox",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
],
|
||||
"theme": "default",
|
||||
"layout": {
|
||||
"id": 0,
|
||||
"children": [
|
||||
{"id": 1},
|
||||
{"id": 22},
|
||||
{
|
||||
"id": 3,
|
||||
"children": [
|
||||
{
|
||||
"id": 444,
|
||||
"children": [
|
||||
{"id": 5, "children": [{"id": 6}, {"id": 7}]},
|
||||
{"id": 8888},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"children": [
|
||||
{"id": 10, "children": [{"id": 11}, {"id": 12}]},
|
||||
{"id": 13},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
{"id": 141},
|
||||
],
|
||||
},
|
||||
"dependencies": [
|
||||
{
|
||||
"targets": [8888],
|
||||
"trigger": "click",
|
||||
"inputs": [22, 6],
|
||||
"outputs": [7],
|
||||
"queue": False,
|
||||
},
|
||||
{
|
||||
"targets": [13],
|
||||
"trigger": "click",
|
||||
"inputs": [22, 11],
|
||||
"outputs": [12],
|
||||
"queue": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
XRAY_CONFIG_WITH_MISTAKE = {
|
||||
"mode": "blocks",
|
||||
"components": [
|
||||
{
|
||||
"id": 1,
|
||||
"type": "markdown",
|
||||
"props": {
|
||||
"default_value": "<h1>Detect Disease From Scan</h1>\n<p>With this model you can lorem ipsum</p>\n<ul>\n<li>ipsum 1</li>\n<li>ipsum 2</li>\n</ul>\n",
|
||||
"name": "markdown",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "checkboxgroup",
|
||||
"props": {
|
||||
"choices": ["Covid", "Malaria", "Lung Cancer"],
|
||||
"default_value": [],
|
||||
"name": "checkboxgroup",
|
||||
"label": "Disease to Scan For",
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{"id": 3, "type": "tabs", "props": {"css": {}, "default_value": True}},
|
||||
{
|
||||
"id": 4,
|
||||
"type": "tabitem",
|
||||
"props": {"label": "X-ray", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"type": "row",
|
||||
"props": {"type": "row", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
"shape": None,
|
||||
"source": "upload",
|
||||
"tool": "editor",
|
||||
"default_value": None,
|
||||
"name": "image",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "json",
|
||||
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"default_value": "Run",
|
||||
"name": "button",
|
||||
"label": None,
|
||||
"css": {"background-color": "red", "--hover-color": "orange"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"type": "tabitem",
|
||||
"props": {"label": "CT Scan", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"type": "row",
|
||||
"props": {"type": "row", "css": {}, "default_value": True},
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"type": "image",
|
||||
"props": {
|
||||
"image_mode": "RGB",
|
||||
"shape": None,
|
||||
"source": "upload",
|
||||
"tool": "editor",
|
||||
"default_value": None,
|
||||
"name": "image",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"type": "json",
|
||||
"props": {"default_value": '""', "name": "json", "label": None, "css": {}},
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"type": "button",
|
||||
"props": {
|
||||
"default_value": "Run",
|
||||
"name": "button",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": 14,
|
||||
"type": "textbox",
|
||||
"props": {
|
||||
"lines": 1,
|
||||
"placeholder": None,
|
||||
"default_value": "",
|
||||
"name": "textbox",
|
||||
"label": None,
|
||||
"css": {},
|
||||
},
|
||||
},
|
||||
],
|
||||
"theme": "default",
|
||||
"layout": {
|
||||
"id": 0,
|
||||
"children": [
|
||||
{"id": 1},
|
||||
{"id": 2},
|
||||
{
|
||||
"id": 3,
|
||||
"children": [
|
||||
{
|
||||
"id": 4,
|
||||
"children": [
|
||||
{"id": 5, "children": [{"id": 6}, {"id": 7}]},
|
||||
{"id": 8},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"children": [
|
||||
{"id": 10, "children": [{"id": 12}, {"id": 11}]},
|
||||
{"id": 13},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
{"id": 14},
|
||||
],
|
||||
},
|
||||
"dependencies": [
|
||||
{
|
||||
"targets": [8],
|
||||
"trigger": "click",
|
||||
"inputs": [2, 6],
|
||||
"outputs": [7],
|
||||
"queue": False,
|
||||
},
|
||||
{
|
||||
"targets": [13],
|
||||
"trigger": "click",
|
||||
"inputs": [2, 11],
|
||||
"outputs": [12],
|
||||
"queue": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import json.decoder
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from distutils.version import StrictVersion
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||
|
||||
@ -286,3 +287,45 @@ def get_default_args(func: Callable) -> Dict[str, Any]:
|
||||
v.default if v.default is not inspect.Parameter.empty else None
|
||||
for v in signature.parameters.values()
|
||||
]
|
||||
|
||||
|
||||
def assert_configs_are_equivalent_besides_ids(config1, config2):
|
||||
"""Allows you to test if two different Blocks configs produce the same demo."""
|
||||
assert config1["mode"] == config2["mode"], "Modes are different"
|
||||
assert config1["theme"] == config2["theme"], "Themes are different"
|
||||
assert len(config1["components"]) == len(
|
||||
config2["components"]
|
||||
), "# of components are different"
|
||||
|
||||
mapping = {}
|
||||
|
||||
for c1, c2 in zip(config1["components"], config2["components"]):
|
||||
c1, c2 = deepcopy(c1), deepcopy(c2)
|
||||
mapping[c1["id"]] = c2["id"]
|
||||
c1.pop("id")
|
||||
c2.pop("id")
|
||||
assert c1 == c2, "{} does not match {}".format(c1, c2)
|
||||
|
||||
def same_children_recursive(children1, chidren2, mapping):
|
||||
for child1, child2 in zip(children1, chidren2):
|
||||
assert mapping[child1["id"]] == child2["id"], "{} does not match {}".format(
|
||||
child1, child2
|
||||
)
|
||||
if "children" in child1 or "children" in child2:
|
||||
same_children_recursive(child1["children"], child2["children"], mapping)
|
||||
|
||||
children1 = config1["layout"]["children"]
|
||||
children2 = config2["layout"]["children"]
|
||||
same_children_recursive(children1, children2, mapping)
|
||||
|
||||
for d1, d2 in zip(config1["dependencies"], config2["dependencies"]):
|
||||
for t1, t2 in zip(d1["targets"], d2["targets"]):
|
||||
assert mapping[t1] == t2, "{} does not match {}".format(d1, d2)
|
||||
assert d1["trigger"] == d2["trigger"], "{} does not match {}".format(d1, d2)
|
||||
for i1, i2 in zip(d1["inputs"], d2["inputs"]):
|
||||
assert mapping[i1] == i2, "{} does not match {}".format(d1, d2)
|
||||
for o1, o2 in zip(d1["outputs"], d2["outputs"]):
|
||||
assert mapping[o1] == o2, "{} does not match {}".format(d1, d2)
|
||||
assert d1["queue"] == d2["queue"], "{} does not match {}".format(d1, d2)
|
||||
|
||||
return True
|
||||
|
@ -8,7 +8,9 @@ import mlflow
|
||||
import requests
|
||||
import wandb
|
||||
|
||||
from gradio.interface import Interface, close_all, os
|
||||
from gradio.blocks import Blocks, TabItem, Tabs
|
||||
from gradio.interface import Interface, TabbedInterface, close_all, os
|
||||
from gradio.utils import assert_configs_are_equivalent_besides_ids
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -184,5 +186,28 @@ class TestInterface(unittest.TestCase):
|
||||
mock_post.assert_called_once()
|
||||
|
||||
|
||||
class TestTabbedInterface(unittest.TestCase):
|
||||
def test_tabbed_interface_config_matches_manual_tab(self):
|
||||
interface1 = Interface(lambda x: x, "textbox", "textbox")
|
||||
interface2 = Interface(lambda x: x, "image", "image")
|
||||
|
||||
with Blocks() as demo:
|
||||
with Tabs():
|
||||
with TabItem(label="tab1"):
|
||||
interface1.render_basic_interface()
|
||||
with TabItem(label="tab2"):
|
||||
interface2.render_basic_interface()
|
||||
|
||||
interface3 = Interface(lambda x: x, "textbox", "textbox")
|
||||
interface4 = Interface(lambda x: x, "image", "image")
|
||||
tabbed_interface = TabbedInterface([interface3, interface4], ["tab1", "tab2"])
|
||||
|
||||
self.assertTrue(
|
||||
assert_configs_are_equivalent_besides_ids(
|
||||
demo.get_config_file(), tabbed_interface.get_config_file()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -7,7 +7,13 @@ import warnings
|
||||
import pkg_resources
|
||||
import requests
|
||||
|
||||
from gradio.test_data.blocks_configs import (
|
||||
XRAY_CONFIG,
|
||||
XRAY_CONFIG_DIFF_IDS,
|
||||
XRAY_CONFIG_WITH_MISTAKE,
|
||||
)
|
||||
from gradio.utils import (
|
||||
assert_configs_are_equivalent_besides_ids,
|
||||
colab_check,
|
||||
error_analytics,
|
||||
get_local_ip_address,
|
||||
@ -116,5 +122,23 @@ class TestIPAddress(unittest.TestCase):
|
||||
self.assertEqual(ip, "No internet connection")
|
||||
|
||||
|
||||
class TestAssertConfigsEquivalent(unittest.TestCase):
|
||||
def test_same_configs(self):
|
||||
self.assertTrue(
|
||||
assert_configs_are_equivalent_besides_ids(XRAY_CONFIG, XRAY_CONFIG)
|
||||
)
|
||||
|
||||
def test_equivalent_configs(self):
|
||||
self.assertTrue(
|
||||
assert_configs_are_equivalent_besides_ids(XRAY_CONFIG, XRAY_CONFIG_DIFF_IDS)
|
||||
)
|
||||
|
||||
def test_different_configs(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
assert_configs_are_equivalent_besides_ids(
|
||||
XRAY_CONFIG_WITH_MISTAKE, XRAY_CONFIG
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user