mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
Gradio Blocks (#590)
Integrate initial work on Gradio Blocks by creating Blocks class and frontend Block code.
This commit is contained in:
parent
faf55f2c9e
commit
11dda55352
@ -19,7 +19,7 @@ iface = gr.Interface(
|
||||
),
|
||||
gr.inputs.Textbox(lines=3, default="The fast brown fox jumps over lazy dogs."),
|
||||
],
|
||||
gr.outputs.HighlightedText(color_map={'+': 'green', '-': 'pink'}),
|
||||
gr.outputs.HighlightedText(color_map={"+": "green", "-": "pink"}),
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
|
@ -7,4 +7,4 @@ def greet(name):
|
||||
|
||||
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
||||
if __name__ == "__main__":
|
||||
iface.launch();
|
||||
iface.launch()
|
||||
|
@ -12,4 +12,3 @@ iface = gr.Interface(
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
app, local_url, share_url = iface.launch()
|
||||
|
||||
|
@ -14,5 +14,4 @@ iface = gr.Interface(
|
||||
outputs=["text", "number"],
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
iface.launch();
|
||||
|
||||
iface.launch()
|
||||
|
41
demo/xray_blocks/run.py
Normal file
41
demo/xray_blocks/run.py
Normal file
@ -0,0 +1,41 @@
|
||||
import gradio as gr
|
||||
|
||||
import random
|
||||
|
||||
xray_model = lambda diseases, img: {disease: random.random() for disease in diseases}
|
||||
ct_model = lambda diseases, img: {disease: 0.1 for disease in diseases}
|
||||
|
||||
xray_blocks = gr.Blocks()
|
||||
|
||||
with xray_blocks:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# Detect Disease From Scan
|
||||
With this model you can lorem ipsum
|
||||
- ipsum 1
|
||||
- ipsum 2
|
||||
"""
|
||||
)
|
||||
disease = gr.inputs.CheckboxGroup(
|
||||
["Covid", "Malaria", "Lung Cancer"], label="Disease to Scan For"
|
||||
)
|
||||
|
||||
with gr.Tab("X-ray"):
|
||||
with gr.Row():
|
||||
xray_scan = gr.inputs.Image()
|
||||
xray_results = gr.outputs.JSON()
|
||||
xray_run = gr.Button("Run")
|
||||
xray_run.click(xray_model, inputs=[disease, xray_scan], outputs=xray_results)
|
||||
|
||||
with gr.Tab("CT Scan"):
|
||||
with gr.Row():
|
||||
ct_scan = gr.inputs.Image()
|
||||
ct_results = gr.outputs.JSON()
|
||||
ct_run = gr.Button("Run")
|
||||
ct_run.click(ct_model, inputs=[disease, ct_scan], outputs=ct_results)
|
||||
|
||||
overall_probability = gr.outputs.Textbox()
|
||||
|
||||
# TODO: remove later
|
||||
print(xray_blocks.get_config_file())
|
||||
xray_blocks.launch()
|
@ -1,5 +1,6 @@
|
||||
import pkg_resources
|
||||
|
||||
from gradio.blocks import Blocks, Column, Row, Tab
|
||||
from gradio.flagging import (
|
||||
CSVLogger,
|
||||
FlaggingCallback,
|
||||
@ -9,6 +10,7 @@ from gradio.flagging import (
|
||||
from gradio.interface import Interface, close_all, reset_all
|
||||
from gradio.mix import Parallel, Series
|
||||
from gradio.routes import get_state, set_state
|
||||
from gradio.static import Button, Markdown
|
||||
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
__version__ = current_pkg_version
|
||||
|
150
gradio/blocks.py
Normal file
150
gradio/blocks.py
Normal file
@ -0,0 +1,150 @@
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from gradio import utils
|
||||
from gradio.context import Context
|
||||
from gradio.launchable import Launchable
|
||||
|
||||
|
||||
class Block:
|
||||
def __init__(self):
|
||||
self._id = Context.id
|
||||
Context.id += 1
|
||||
if Context.block is not None:
|
||||
Context.block.children.append(self)
|
||||
if Context.root_block is not None:
|
||||
Context.root_block.blocks[self._id] = self
|
||||
self.events = []
|
||||
|
||||
def click(self, fn, inputs, outputs):
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs]
|
||||
Context.root_block.fns.append(fn)
|
||||
Context.root_block.dependencies.append(
|
||||
{
|
||||
"targets": [self._id],
|
||||
"trigger": "click",
|
||||
"inputs": [block._id for block in inputs],
|
||||
"outputs": [block._id for block in outputs],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BlockContext(Block):
|
||||
def __init__(self):
|
||||
self.children = []
|
||||
super().__init__()
|
||||
|
||||
def __enter__(self):
|
||||
self.parent = Context.block
|
||||
Context.block = self
|
||||
|
||||
def __exit__(self, *args):
|
||||
Context.block = self.parent
|
||||
|
||||
|
||||
class Row(BlockContext):
|
||||
def get_template_context(self):
|
||||
return {"type": "row"}
|
||||
|
||||
|
||||
class Column(BlockContext):
|
||||
def get_template_context(self):
|
||||
return {"type": "column"}
|
||||
|
||||
|
||||
class Tab(BlockContext):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
super(Tab, self).__init__()
|
||||
|
||||
def get_template_context(self):
|
||||
return {"type": "tab", "name": self.name}
|
||||
|
||||
|
||||
class Blocks(Launchable, BlockContext):
|
||||
def __init__(self, theme="default"):
|
||||
# Cleanup shared parameters with Interface
|
||||
self.save_to = None
|
||||
self.ip_address = utils.get_local_ip_address()
|
||||
self.api_mode = False
|
||||
self.analytics_enabled = True
|
||||
self.theme = theme
|
||||
self.requires_permissions = False # TODO: needs to be implemented
|
||||
self.enable_queue = False
|
||||
|
||||
super().__init__()
|
||||
Context.root_block = self
|
||||
self.blocks = {}
|
||||
self.fns = []
|
||||
self.dependencies = []
|
||||
|
||||
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
|
||||
raw_input = data["data"]
|
||||
fn_index = data["fn_index"]
|
||||
fn = self.fns[fn_index]
|
||||
dependency = self.dependencies[fn_index]
|
||||
|
||||
processed_input = [
|
||||
self.blocks[input_id].preprocess(raw_input[i])
|
||||
for i, input_id in enumerate(dependency["inputs"])
|
||||
]
|
||||
predictions = fn(*processed_input)
|
||||
if len(dependency["outputs"]) == 1:
|
||||
predictions = (predictions,)
|
||||
processed_output = [
|
||||
self.blocks[output_id].postprocess(predictions[i])
|
||||
if predictions[i] is not None
|
||||
else None
|
||||
for i, output_id in enumerate(dependency["outputs"])
|
||||
]
|
||||
return {"data": processed_output}
|
||||
|
||||
def get_template_context(self):
|
||||
return {"type": "column"}
|
||||
|
||||
def get_config_file(self):
|
||||
from gradio.component import Component
|
||||
|
||||
config = {"mode": "blocks", "components": [], "theme": self.theme}
|
||||
for _id, block in self.blocks.items():
|
||||
if isinstance(block, Component):
|
||||
config["components"].append(
|
||||
{
|
||||
"id": _id,
|
||||
"type": block.component_type,
|
||||
"props": block.get_template_context(),
|
||||
}
|
||||
)
|
||||
|
||||
def getLayout(block_context):
|
||||
if not isinstance(block_context, BlockContext):
|
||||
return block_context._id
|
||||
children = []
|
||||
running_tabs = []
|
||||
for child in block_context.children:
|
||||
if isinstance(child, Tab):
|
||||
running_tabs.append(getLayout(child))
|
||||
continue
|
||||
if len(running_tabs):
|
||||
children.append({"type": "tabset", "children": running_tabs})
|
||||
running_tabs = []
|
||||
|
||||
children.append(getLayout(child))
|
||||
if len(running_tabs):
|
||||
children.append({"type": "tabset", "children": running_tabs})
|
||||
running_tabs = []
|
||||
return {"children": children, **block_context.get_template_context()}
|
||||
|
||||
config["layout"] = getLayout(self)
|
||||
config["dependencies"] = self.dependencies
|
||||
return config
|
||||
|
||||
def __enter__(self):
|
||||
BlockContext.__enter__(self)
|
||||
Context.root_block = self
|
||||
|
||||
def __exit__(self, *args):
|
||||
BlockContext.__exit__(self, *args)
|
||||
Context.root_block = self.parent
|
@ -3,9 +3,10 @@ import shutil
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from gradio import processing_utils
|
||||
from gradio.blocks import Block
|
||||
|
||||
|
||||
class Component:
|
||||
class Component(Block):
|
||||
"""
|
||||
A class for defining the methods that all gradio input and output components should have.
|
||||
"""
|
||||
@ -13,6 +14,7 @@ class Component:
|
||||
def __init__(self, label, requires_permissions=False):
|
||||
self.label = label
|
||||
self.requires_permissions = requires_permissions
|
||||
super().__init__()
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
4
gradio/context.py
Normal file
4
gradio/context.py
Normal file
@ -0,0 +1,4 @@
|
||||
class Context:
|
||||
root_block = None
|
||||
block = None
|
||||
id = 0
|
@ -36,6 +36,7 @@ class InputComponent(Component):
|
||||
"""
|
||||
Constructs an input component.
|
||||
"""
|
||||
self.component_type = "input"
|
||||
self.set_interpret_parameters()
|
||||
self.optional = optional
|
||||
super().__init__(label, requires_permissions)
|
||||
|
@ -6,38 +6,36 @@ including various methods for constructing an interface and then launching it.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import getpass
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
import weakref
|
||||
import webbrowser
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from markdown_it import MarkdownIt
|
||||
from mdit_py_plugins.footnote import footnote_plugin
|
||||
|
||||
from gradio import networking # type: ignore
|
||||
from gradio import encryptor, interpretation, queueing, strings, utils
|
||||
from gradio import interpretation, utils
|
||||
from gradio.external import load_from_pipeline, load_interface # type: ignore
|
||||
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.inputs import State as i_State # type: ignore
|
||||
from gradio.inputs import get_input_instance
|
||||
from gradio.launchable import Launchable
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio.outputs import State as o_State # type: ignore
|
||||
from gradio.outputs import get_output_instance
|
||||
from gradio.process_examples import cache_interface_examples
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
from gradio.routes import predict
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import flask
|
||||
import transformers
|
||||
|
||||
|
||||
class Interface:
|
||||
class Interface(Launchable):
|
||||
"""
|
||||
Gradio interfaces are created by constructing a `Interface` object
|
||||
with a locally-defined function, or with `Interface.load()` with the path
|
||||
@ -546,6 +544,34 @@ class Interface:
|
||||
else:
|
||||
return predictions
|
||||
|
||||
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
|
||||
flag_index = None
|
||||
if data.get("example_id") is not None:
|
||||
example_id = data["example_id"]
|
||||
if self.cache_examples:
|
||||
prediction = load_from_cache(self, example_id)
|
||||
durations = None
|
||||
else:
|
||||
prediction, durations = process_example(self, example_id)
|
||||
else:
|
||||
raw_input = data["data"]
|
||||
prediction, durations = self.process(raw_input)
|
||||
if self.allow_flagging == "auto":
|
||||
flag_index = self.flagging_callback.flag(
|
||||
self,
|
||||
raw_input,
|
||||
prediction,
|
||||
flag_option="" if self.flagging_options else None,
|
||||
username=username,
|
||||
)
|
||||
|
||||
return {
|
||||
"data": prediction,
|
||||
"durations": durations,
|
||||
"avg_durations": self.config.get("avg_durations"),
|
||||
"flag_index": flag_index,
|
||||
}
|
||||
|
||||
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
|
||||
"""
|
||||
First preprocesses the input, then runs prediction using
|
||||
@ -585,19 +611,6 @@ class Interface:
|
||||
def interpret(self, raw_input: List[Any]) -> List[Any]:
|
||||
return interpretation.run_interpret(self, raw_input)
|
||||
|
||||
def block_thread(
|
||||
self,
|
||||
) -> None:
|
||||
"""Block main thread until interrupted by user."""
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
except (KeyboardInterrupt, OSError):
|
||||
print("Keyboard interruption in main thread... closing server.")
|
||||
self.server.close()
|
||||
if self.enable_queue:
|
||||
queueing.close()
|
||||
|
||||
def test_launch(self) -> None:
|
||||
for predict_fn in self.predict:
|
||||
print("Test launch: {}()...".format(predict_fn.__name__), end=" ")
|
||||
@ -613,219 +626,10 @@ class Interface:
|
||||
print("PASSED")
|
||||
continue
|
||||
|
||||
def launch(
|
||||
self,
|
||||
inline: bool = None,
|
||||
inbrowser: bool = None,
|
||||
share: bool = False,
|
||||
debug: bool = False,
|
||||
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
|
||||
auth_message: Optional[str] = None,
|
||||
private_endpoint: Optional[str] = None,
|
||||
prevent_thread_lock: bool = False,
|
||||
show_error: bool = True,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
show_tips: bool = False,
|
||||
enable_queue: bool = False,
|
||||
height: int = 500,
|
||||
width: int = 900,
|
||||
encrypt: bool = False,
|
||||
cache_examples: bool = False,
|
||||
favicon_path: Optional[str] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile_password: Optional[str] = None,
|
||||
) -> Tuple[flask.Flask, str, str]:
|
||||
"""
|
||||
Launches the webserver that serves the UI for the interface.
|
||||
Parameters:
|
||||
inline (bool): whether to display in the interface inline on python notebooks.
|
||||
inbrowser (bool): whether to automatically launch the interface in a new tab on the default browser.
|
||||
share (bool): whether to create a publicly shareable link from your computer for the interface.
|
||||
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
|
||||
auth (Callable, Union[Tuple[str, str], List[Tuple[str, str]]]): If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
|
||||
auth_message (str): If provided, HTML message provided on login page.
|
||||
private_endpoint (str): If provided, the public URL of the interface will be this endpoint (should generally be unchanged).
|
||||
prevent_thread_lock (bool): If True, the interface will block the main thread while the server is running.
|
||||
show_error (bool): If True, any errors in the interface will be printed in the browser console log
|
||||
server_port (int): will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
|
||||
server_name (str): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
|
||||
height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
cache_examples (bool): If True, examples outputs will be processed and cached in a folder, and will be used if a user uses an example input.
|
||||
favicon_path (str): If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
|
||||
ssl_keyfile (str): If a path to a file is provided, will use this as the private key file to create a local server running on https.
|
||||
ssl_certfile (str): If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
|
||||
ssl_keyfile_password (str): If a password is provided, will use this with the ssl certificate for https.
|
||||
Returns:
|
||||
app (flask.Flask): Flask app object
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
"""
|
||||
self.config = self.get_config_file()
|
||||
self.cache_examples = cache_examples
|
||||
if (
|
||||
auth
|
||||
and not callable(auth)
|
||||
and not isinstance(auth[0], tuple)
|
||||
and not isinstance(auth[0], list)
|
||||
):
|
||||
auth = [auth]
|
||||
self.auth = auth
|
||||
self.auth_message = auth_message
|
||||
self.show_tips = show_tips
|
||||
self.show_error = show_error
|
||||
self.height = self.height or height
|
||||
self.width = self.width or width
|
||||
self.favicon_path = favicon_path
|
||||
|
||||
if self.encrypt is None:
|
||||
self.encrypt = encrypt
|
||||
if self.encrypt:
|
||||
self.encryption_key = encryptor.get_key(
|
||||
getpass.getpass("Enter key for encryption: ")
|
||||
)
|
||||
|
||||
if self.enable_queue is None:
|
||||
self.enable_queue = enable_queue
|
||||
def launch(self, **args):
|
||||
if self.allow_flagging != "never":
|
||||
self.flagging_callback.setup(self.flagging_dir)
|
||||
|
||||
config = self.get_config_file()
|
||||
self.config = config
|
||||
|
||||
if self.cache_examples:
|
||||
cache_interface_examples(self)
|
||||
|
||||
server_port, path_to_local_server, app, server = networking.start_server(
|
||||
self,
|
||||
server_name,
|
||||
server_port,
|
||||
ssl_keyfile,
|
||||
ssl_certfile,
|
||||
ssl_keyfile_password,
|
||||
)
|
||||
|
||||
self.local_url = path_to_local_server
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.server_app = app
|
||||
self.server = server
|
||||
|
||||
utils.launch_counter()
|
||||
|
||||
# If running in a colab or not able to access localhost,
|
||||
# automatically create a shareable link.
|
||||
is_colab = utils.colab_check()
|
||||
if is_colab or not (networking.url_ok(path_to_local_server)):
|
||||
share = True
|
||||
if is_colab:
|
||||
if debug:
|
||||
print(strings.en["COLAB_DEBUG_TRUE"])
|
||||
else:
|
||||
print(strings.en["COLAB_DEBUG_FALSE"])
|
||||
else:
|
||||
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
|
||||
if is_colab and self.requires_permissions:
|
||||
print(strings.en["MEDIA_PERMISSIONS_IN_COLAB"])
|
||||
|
||||
if private_endpoint is not None:
|
||||
share = True
|
||||
|
||||
if share:
|
||||
if self.is_space:
|
||||
raise RuntimeError("Share is not supported when you are in Spaces")
|
||||
try:
|
||||
share_url = networking.setup_tunnel(server_port, private_endpoint)
|
||||
self.share_url = share_url
|
||||
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
|
||||
if private_endpoint:
|
||||
print(strings.en["PRIVATE_LINK_MESSAGE"])
|
||||
else:
|
||||
print(strings.en["SHARE_LINK_MESSAGE"])
|
||||
except RuntimeError:
|
||||
if self.analytics_enabled:
|
||||
utils.error_analytics(self.ip_address, "Not able to set up tunnel")
|
||||
share_url = None
|
||||
share = False
|
||||
print(strings.en["COULD_NOT_GET_SHARE_LINK"])
|
||||
else:
|
||||
print(strings.en["PUBLIC_SHARE_TRUE"])
|
||||
share_url = None
|
||||
|
||||
self.share = share
|
||||
|
||||
if inbrowser:
|
||||
link = share_url if share else path_to_local_server
|
||||
webbrowser.open(link)
|
||||
|
||||
# Check if running in a Python notebook in which case, display inline
|
||||
if inline is None:
|
||||
inline = utils.ipython_check() and (auth is None)
|
||||
if inline:
|
||||
if auth is not None:
|
||||
print(
|
||||
"Warning: authentication is not supported inline. Please"
|
||||
"click the link to access the interface in a new tab."
|
||||
)
|
||||
try:
|
||||
from IPython.display import IFrame, display # type: ignore
|
||||
|
||||
if share:
|
||||
while not networking.url_ok(share_url):
|
||||
time.sleep(1)
|
||||
display(IFrame(share_url, width=self.width, height=self.height))
|
||||
else:
|
||||
display(
|
||||
IFrame(
|
||||
path_to_local_server, width=self.width, height=self.height
|
||||
)
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
data = {
|
||||
"launch_method": "browser" if inbrowser else "inline",
|
||||
"is_google_colab": is_colab,
|
||||
"is_sharing_on": share,
|
||||
"share_url": share_url,
|
||||
"ip_address": self.ip_address,
|
||||
"enable_queue": self.enable_queue,
|
||||
"show_tips": self.show_tips,
|
||||
"api_mode": self.api_mode,
|
||||
"server_name": server_name,
|
||||
"server_port": server_port,
|
||||
"is_spaces": self.is_space,
|
||||
}
|
||||
if self.analytics_enabled:
|
||||
utils.launch_analytics(data)
|
||||
|
||||
utils.show_tip(self)
|
||||
|
||||
# Block main thread if debug==True
|
||||
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
|
||||
self.block_thread()
|
||||
# Block main thread if running in a script to stop script from exiting
|
||||
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
|
||||
if not prevent_thread_lock and not is_in_interactive_mode:
|
||||
self.block_thread()
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
def close(self, verbose: bool = True) -> None:
|
||||
"""
|
||||
Closes the Interface that was launched and frees the port.
|
||||
"""
|
||||
try:
|
||||
self.server.close()
|
||||
if verbose:
|
||||
print("Closing server running on port: {}".format(self.server_port))
|
||||
except (AttributeError, OSError): # can't close if not running
|
||||
pass
|
||||
return super().launch(**args)
|
||||
|
||||
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
|
||||
"""
|
||||
|
245
gradio/launchable.py
Normal file
245
gradio/launchable.py
Normal file
@ -0,0 +1,245 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
from gradio import encryptor, networking, queueing, strings, utils # type: ignore
|
||||
from gradio.process_examples import cache_interface_examples
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
import flask
|
||||
|
||||
|
||||
class Launchable:
|
||||
"""
|
||||
Gradio launchables can be launched to serve content to a port.
|
||||
"""
|
||||
|
||||
def launch(
|
||||
self,
|
||||
inline: bool = None,
|
||||
inbrowser: bool = None,
|
||||
share: bool = False,
|
||||
debug: bool = False,
|
||||
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
|
||||
auth_message: Optional[str] = None,
|
||||
private_endpoint: Optional[str] = None,
|
||||
prevent_thread_lock: bool = False,
|
||||
show_error: bool = True,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
show_tips: bool = False,
|
||||
enable_queue: bool = False,
|
||||
height: int = 500,
|
||||
width: int = 900,
|
||||
encrypt: bool = False,
|
||||
cache_examples: bool = False,
|
||||
favicon_path: Optional[str] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
ssl_certfile: Optional[str] = None,
|
||||
ssl_keyfile_password: Optional[str] = None,
|
||||
) -> Tuple[flask.Flask, str, str]:
|
||||
"""
|
||||
Launches the webserver that serves the UI for the interface.
|
||||
Parameters:
|
||||
inline (bool): whether to display in the interface inline on python notebooks.
|
||||
inbrowser (bool): whether to automatically launch the interface in a new tab on the default browser.
|
||||
share (bool): whether to create a publicly shareable link from your computer for the interface.
|
||||
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
|
||||
auth (Callable, Union[Tuple[str, str], List[Tuple[str, str]]]): If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
|
||||
auth_message (str): If provided, HTML message provided on login page.
|
||||
private_endpoint (str): If provided, the public URL of the interface will be this endpoint (should generally be unchanged).
|
||||
prevent_thread_lock (bool): If True, the interface will block the main thread while the server is running.
|
||||
show_error (bool): If True, any errors in the interface will be printed in the browser console log
|
||||
server_port (int): will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
|
||||
server_name (str): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
|
||||
height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
cache_examples (bool): If True, examples outputs will be processed and cached in a folder, and will be used if a user uses an example input.
|
||||
favicon_path (str): If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
|
||||
ssl_keyfile (str): If a path to a file is provided, will use this as the private key file to create a local server running on https.
|
||||
ssl_certfile (str): If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
|
||||
ssl_keyfile_password (str): If a password is provided, will use this with the ssl certificate for https.
|
||||
Returns:
|
||||
app (flask.Flask): Flask app object
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
"""
|
||||
self.config = self.get_config_file()
|
||||
self.cache_examples = cache_examples
|
||||
if (
|
||||
auth
|
||||
and not callable(auth)
|
||||
and not isinstance(auth[0], tuple)
|
||||
and not isinstance(auth[0], list)
|
||||
):
|
||||
auth = [auth]
|
||||
self.auth = auth
|
||||
self.auth_message = auth_message
|
||||
self.show_tips = show_tips
|
||||
self.show_error = show_error
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.favicon_path = favicon_path
|
||||
|
||||
if hasattr(self, "encrypt") and self.encrypt is None:
|
||||
self.encrypt = encrypt
|
||||
if hasattr(self, "encrypt") and self.encrypt:
|
||||
self.encryption_key = encryptor.get_key(
|
||||
getpass.getpass("Enter key for encryption: ")
|
||||
)
|
||||
|
||||
if hasattr(self, "enable_queue") and self.enable_queue is None:
|
||||
self.enable_queue = enable_queue
|
||||
|
||||
config = self.get_config_file()
|
||||
self.config = config
|
||||
|
||||
if self.cache_examples:
|
||||
cache_interface_examples(self)
|
||||
|
||||
server_port, path_to_local_server, app, server = networking.start_server(
|
||||
self,
|
||||
server_name,
|
||||
server_port,
|
||||
ssl_keyfile,
|
||||
ssl_certfile,
|
||||
ssl_keyfile_password,
|
||||
)
|
||||
|
||||
self.local_url = path_to_local_server
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.server_app = app
|
||||
self.server = server
|
||||
|
||||
utils.launch_counter()
|
||||
|
||||
# If running in a colab or not able to access localhost,
|
||||
# automatically create a shareable link.
|
||||
is_colab = utils.colab_check()
|
||||
if is_colab or not (networking.url_ok(path_to_local_server)):
|
||||
share = True
|
||||
if is_colab:
|
||||
if debug:
|
||||
print(strings.en["COLAB_DEBUG_TRUE"])
|
||||
else:
|
||||
print(strings.en["COLAB_DEBUG_FALSE"])
|
||||
else:
|
||||
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
|
||||
if is_colab and self.requires_permissions:
|
||||
print(strings.en["MEDIA_PERMISSIONS_IN_COLAB"])
|
||||
|
||||
if private_endpoint is not None:
|
||||
share = True
|
||||
|
||||
if share:
|
||||
if self.is_space:
|
||||
raise RuntimeError("Share is not supported when you are in Spaces")
|
||||
try:
|
||||
share_url = networking.setup_tunnel(server_port, private_endpoint)
|
||||
self.share_url = share_url
|
||||
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
|
||||
if private_endpoint:
|
||||
print(strings.en["PRIVATE_LINK_MESSAGE"])
|
||||
else:
|
||||
print(strings.en["SHARE_LINK_MESSAGE"])
|
||||
except RuntimeError:
|
||||
if self.analytics_enabled:
|
||||
utils.error_analytics(self.ip_address, "Not able to set up tunnel")
|
||||
share_url = None
|
||||
share = False
|
||||
print(strings.en["COULD_NOT_GET_SHARE_LINK"])
|
||||
else:
|
||||
print(strings.en["PUBLIC_SHARE_TRUE"])
|
||||
share_url = None
|
||||
|
||||
self.share = share
|
||||
|
||||
if inbrowser:
|
||||
link = share_url if share else path_to_local_server
|
||||
webbrowser.open(link)
|
||||
|
||||
# Check if running in a Python notebook in which case, display inline
|
||||
if inline is None:
|
||||
inline = utils.ipython_check() and (auth is None)
|
||||
if inline:
|
||||
if auth is not None:
|
||||
print(
|
||||
"Warning: authentication is not supported inline. Please"
|
||||
"click the link to access the interface in a new tab."
|
||||
)
|
||||
try:
|
||||
from IPython.display import IFrame, display # type: ignore
|
||||
|
||||
if share:
|
||||
while not networking.url_ok(share_url):
|
||||
time.sleep(1)
|
||||
display(IFrame(share_url, width=self.width, height=self.height))
|
||||
else:
|
||||
display(
|
||||
IFrame(
|
||||
path_to_local_server, width=self.width, height=self.height
|
||||
)
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
data = {
|
||||
"launch_method": "browser" if inbrowser else "inline",
|
||||
"is_google_colab": is_colab,
|
||||
"is_sharing_on": share,
|
||||
"share_url": share_url,
|
||||
"ip_address": self.ip_address,
|
||||
"enable_queue": self.enable_queue,
|
||||
"show_tips": self.show_tips,
|
||||
"api_mode": self.api_mode,
|
||||
"server_name": server_name,
|
||||
"server_port": server_port,
|
||||
"is_spaces": self.is_space,
|
||||
}
|
||||
if self.analytics_enabled:
|
||||
utils.launch_analytics(data)
|
||||
|
||||
utils.show_tip(self)
|
||||
|
||||
# Block main thread if debug==True
|
||||
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
|
||||
self.block_thread()
|
||||
# Block main thread if running in a script to stop script from exiting
|
||||
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
|
||||
if not prevent_thread_lock and not is_in_interactive_mode:
|
||||
self.block_thread()
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
def close(self, verbose: bool = True) -> None:
|
||||
"""
|
||||
Closes the Interface that was launched and frees the port.
|
||||
"""
|
||||
try:
|
||||
self.server.close()
|
||||
if verbose:
|
||||
print("Closing server running on port: {}".format(self.server_port))
|
||||
except (AttributeError, OSError): # can't close if not running
|
||||
pass
|
||||
|
||||
def block_thread(
|
||||
self,
|
||||
) -> None:
|
||||
"""Block main thread until interrupted by user."""
|
||||
try:
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
except (KeyboardInterrupt, OSError):
|
||||
print("Keyboard interruption in main thread... closing server.")
|
||||
self.server.close()
|
||||
if self.enable_queue:
|
||||
queueing.close()
|
@ -20,7 +20,7 @@ from gradio.routes import app
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
from gradio import Interface
|
||||
from gradio.launchable import Launchable
|
||||
|
||||
|
||||
# By default, the local server will try to open on localhost, port 7860.
|
||||
@ -71,7 +71,7 @@ def get_first_available_port(initial: int, final: int) -> int:
|
||||
|
||||
|
||||
def start_server(
|
||||
interface: Interface,
|
||||
launchable: Launchable,
|
||||
server_name: Optional[str] = None,
|
||||
server_port: Optional[int] = None,
|
||||
ssl_keyfile: Optional[str] = None,
|
||||
@ -80,13 +80,13 @@ def start_server(
|
||||
) -> Tuple[int, str, fastapi.FastAPI, Server]:
|
||||
"""Launches a local server running the provided Interface
|
||||
Parameters:
|
||||
interface: The interface object to run on the server
|
||||
launchable: The launchable object to run on the server
|
||||
server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
|
||||
server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
|
||||
auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
|
||||
auth: If provided, username and password (or list of username-password tuples) required to access launchable. Can also provide function that takes username and password and returns True if valid login.
|
||||
ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
|
||||
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
|
||||
ssl_keyfile_password (str): If a password is provided, will use this with the ssl certificate for https.
|
||||
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
|
||||
Returns:
|
||||
port: the port number the server is running on
|
||||
path_to_local_server: the complete address that the local server can be accessed at
|
||||
@ -106,7 +106,7 @@ def start_server(
|
||||
s.close()
|
||||
except OSError:
|
||||
raise OSError(
|
||||
"Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all().".format(
|
||||
"Port {} is in use. If a gradio.Launchable is running on the port, you can close() it or gradio.close_all().".format(
|
||||
server_port
|
||||
)
|
||||
)
|
||||
@ -123,7 +123,7 @@ def start_server(
|
||||
else:
|
||||
path_to_local_server = "http://{}:{}/".format(url_host_name, port)
|
||||
|
||||
auth = interface.auth
|
||||
auth = launchable.auth
|
||||
if auth is not None:
|
||||
if not callable(auth):
|
||||
app.auth = {account[0]: account[1] for account in auth}
|
||||
@ -131,21 +131,21 @@ def start_server(
|
||||
app.auth = auth
|
||||
else:
|
||||
app.auth = None
|
||||
app.interface = interface
|
||||
app.launchable = launchable
|
||||
app.cwd = os.getcwd()
|
||||
app.favicon_path = interface.favicon_path
|
||||
app.favicon_path = launchable.favicon_path
|
||||
app.tokens = {}
|
||||
|
||||
if app.interface.enable_queue:
|
||||
if auth is not None or app.interface.encrypt:
|
||||
if app.launchable.enable_queue:
|
||||
if auth is not None or app.launchable.encrypt:
|
||||
raise ValueError("Cannot queue with encryption or authentication enabled.")
|
||||
queueing.init()
|
||||
app.queue_thread = threading.Thread(
|
||||
target=queueing.queue_thread, args=(path_to_local_server,)
|
||||
)
|
||||
app.queue_thread.start()
|
||||
if interface.save_to is not None: # Used for selenium tests
|
||||
interface.save_to["port"] = port
|
||||
if launchable.save_to is not None: # Used for selenium tests
|
||||
launchable.save_to["port"] = port
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
|
@ -32,6 +32,10 @@ class OutputComponent(Component):
|
||||
Output Component. All output components subclass this.
|
||||
"""
|
||||
|
||||
def __init__(self, label: str):
|
||||
self.component_type = "output"
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, y):
|
||||
"""
|
||||
Any postprocessing needed to be performed on function output.
|
||||
|
115
gradio/routes.py
115
gradio/routes.py
@ -24,7 +24,6 @@ from jinja2.exceptions import TemplateNotFound
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from gradio import encryptor, queueing, utils
|
||||
from gradio.process_examples import load_from_cache, process_example
|
||||
|
||||
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
|
||||
@ -114,9 +113,9 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def main(request: Request, user: str = Depends(get_current_user)):
|
||||
if app.auth is None or not (user is None):
|
||||
config = app.interface.config
|
||||
config = app.launchable.config
|
||||
else:
|
||||
config = {"auth_required": True, "auth_message": app.interface.auth_message}
|
||||
config = {"auth_required": True, "auth_message": app.launchable.auth_message}
|
||||
|
||||
try:
|
||||
return templates.TemplateResponse(
|
||||
@ -132,12 +131,12 @@ def main(request: Request, user: str = Depends(get_current_user)):
|
||||
@app.get("/config/", dependencies=[Depends(login_check)])
|
||||
@app.get("/config", dependencies=[Depends(login_check)])
|
||||
def get_config():
|
||||
return app.interface.config
|
||||
return app.launchable.config
|
||||
|
||||
|
||||
@app.get("/static/{path:path}")
|
||||
def static_resource(path: str):
|
||||
if app.interface.share:
|
||||
if app.launchable.share:
|
||||
return RedirectResponse(GRADIO_STATIC_ROOT + path)
|
||||
else:
|
||||
static_file = safe_join(STATIC_PATH_LIB, path)
|
||||
@ -148,7 +147,7 @@ def static_resource(path: str):
|
||||
|
||||
@app.get("/assets/{path:path}")
|
||||
def build_resource(path: str):
|
||||
if app.interface.share:
|
||||
if app.launchable.share:
|
||||
return RedirectResponse(GRADIO_BUILD_ROOT + path)
|
||||
else:
|
||||
build_file = safe_join(BUILD_PATH_LIB, path)
|
||||
@ -160,13 +159,13 @@ def build_resource(path: str):
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
def file(path):
|
||||
if (
|
||||
app.interface.encrypt
|
||||
and isinstance(app.interface.examples, str)
|
||||
and path.startswith(app.interface.examples)
|
||||
app.launchable.encrypt
|
||||
and isinstance(app.launchable.examples, str)
|
||||
and path.startswith(app.launchable.examples)
|
||||
):
|
||||
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
|
||||
encrypted_data = encrypted_file.read()
|
||||
file_data = encryptor.decrypt(app.interface.encryption_key, encrypted_data)
|
||||
file_data = encryptor.decrypt(app.launchable.encryption_key, encrypted_data)
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
|
||||
)
|
||||
@ -177,17 +176,17 @@ def file(path):
|
||||
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
|
||||
@app.get("/api/", response_class=HTMLResponse)
|
||||
def api_docs(request: Request):
|
||||
inputs = [type(inp) for inp in app.interface.input_components]
|
||||
outputs = [type(out) for out in app.interface.output_components]
|
||||
inputs = [type(inp) for inp in app.launchable.input_components]
|
||||
outputs = [type(out) for out in app.launchable.output_components]
|
||||
input_types_doc, input_types = get_types(inputs, "input")
|
||||
output_types_doc, output_types = get_types(outputs, "output")
|
||||
input_names = [type(inp).__name__ for inp in app.interface.input_components]
|
||||
output_names = [type(out).__name__ for out in app.interface.output_components]
|
||||
if app.interface.examples is not None:
|
||||
sample_inputs = app.interface.examples[0]
|
||||
input_names = [type(inp).__name__ for inp in app.launchable.input_components]
|
||||
output_names = [type(out).__name__ for out in app.launchable.output_components]
|
||||
if app.launchable.examples is not None:
|
||||
sample_inputs = app.launchable.examples[0]
|
||||
else:
|
||||
sample_inputs = [
|
||||
inp.generate_sample() for inp in app.interface.input_components
|
||||
inp.generate_sample() for inp in app.launchable.input_components
|
||||
]
|
||||
docs = {
|
||||
"inputs": input_names,
|
||||
@ -201,9 +200,9 @@ def api_docs(request: Request):
|
||||
"input_types_doc": input_types_doc,
|
||||
"output_types_doc": output_types_doc,
|
||||
"sample_inputs": sample_inputs,
|
||||
"auth": app.interface.auth,
|
||||
"local_login_url": urllib.parse.urljoin(app.interface.local_url, "login"),
|
||||
"local_api_url": urllib.parse.urljoin(app.interface.local_url, "api/predict"),
|
||||
"auth": app.launchable.auth,
|
||||
"local_login_url": urllib.parse.urljoin(app.launchable.local_url, "login"),
|
||||
"local_api_url": urllib.parse.urljoin(app.launchable.local_url, "api/predict"),
|
||||
}
|
||||
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
|
||||
|
||||
@ -211,60 +210,26 @@ def api_docs(request: Request):
|
||||
@app.post("/api/predict/", dependencies=[Depends(login_check)])
|
||||
async def predict(request: Request, username: str = Depends(get_current_user)):
|
||||
body = await request.json()
|
||||
flag_index = None
|
||||
|
||||
if body.get("example_id") is not None:
|
||||
example_id = body["example_id"]
|
||||
if app.interface.cache_examples:
|
||||
prediction = await run_in_threadpool(
|
||||
load_from_cache, app.interface, example_id
|
||||
)
|
||||
durations = None
|
||||
try:
|
||||
output = await run_in_threadpool(app.launchable.process_api, body, username)
|
||||
except BaseException as error:
|
||||
if app.launchable.show_error:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"error": str(error)}, status_code=500)
|
||||
else:
|
||||
prediction, durations = await run_in_threadpool(
|
||||
process_example, app.interface, example_id
|
||||
)
|
||||
else:
|
||||
raw_input = body["data"]
|
||||
if app.interface.show_error:
|
||||
try:
|
||||
prediction, durations = await run_in_threadpool(
|
||||
app.interface.process, raw_input
|
||||
)
|
||||
except BaseException as error:
|
||||
traceback.print_exc()
|
||||
return JSONResponse(content={"error": str(error)}, status_code=500)
|
||||
else:
|
||||
prediction, durations = await run_in_threadpool(
|
||||
app.interface.process, raw_input
|
||||
)
|
||||
if app.interface.allow_flagging == "auto":
|
||||
flag_index = await run_in_threadpool(
|
||||
app.interface.flagging_callback.flag,
|
||||
app.interface,
|
||||
raw_input,
|
||||
prediction,
|
||||
flag_option="" if app.interface.flagging_options else None,
|
||||
username=username,
|
||||
)
|
||||
output = {
|
||||
"data": prediction,
|
||||
"durations": durations,
|
||||
"avg_durations": app.interface.config.get("avg_durations"),
|
||||
"flag_index": flag_index,
|
||||
}
|
||||
raise error
|
||||
return output
|
||||
|
||||
|
||||
@app.post("/api/flag/", dependencies=[Depends(login_check)])
|
||||
async def flag(request: Request, username: str = Depends(get_current_user)):
|
||||
if app.interface.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.interface.ip_address, "flag")
|
||||
if app.launchable.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
|
||||
body = await request.json()
|
||||
data = body["data"]
|
||||
await run_in_threadpool(
|
||||
app.interface.flagging_callback.flag,
|
||||
app.interface,
|
||||
app.launchable.flagging_callback.flag,
|
||||
app.launchable,
|
||||
data["input_data"],
|
||||
data["output_data"],
|
||||
flag_option=data.get("flag_option"),
|
||||
@ -276,12 +241,12 @@ async def flag(request: Request, username: str = Depends(get_current_user)):
|
||||
|
||||
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
|
||||
async def interpret(request: Request):
|
||||
if app.interface.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.interface.ip_address, "interpret")
|
||||
if app.launchable.analytics_enabled:
|
||||
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
|
||||
body = await request.json()
|
||||
raw_input = body["data"]
|
||||
interpretation_scores, alternative_outputs = await run_in_threadpool(
|
||||
app.interface.interpret, raw_input
|
||||
app.launchable.interpret, raw_input
|
||||
)
|
||||
return {
|
||||
"interpretation_scores": interpretation_scores,
|
||||
@ -368,20 +333,20 @@ def set_state(*args):
|
||||
if __name__ == "__main__": # Run directly for debugging: python app.py
|
||||
from gradio import Interface
|
||||
|
||||
app.interface = Interface(
|
||||
app.launchable = Interface(
|
||||
lambda x: "Hello, " + x, "text", "text", analytics_enabled=False
|
||||
)
|
||||
app.interface.favicon_path = None
|
||||
app.interface.config = app.interface.get_config_file()
|
||||
app.interface.show_error = True
|
||||
app.interface.flagging_callback.setup(app.interface.flagging_dir)
|
||||
app.launchable.favicon_path = None
|
||||
app.launchable.config = app.launchable.get_config_file()
|
||||
app.launchable.show_error = True
|
||||
app.launchable.flagging_callback.setup(app.launchable.flagging_dir)
|
||||
app.tokens = {}
|
||||
|
||||
auth = True
|
||||
if auth:
|
||||
app.interface.auth = ("a", "b")
|
||||
app.launchable.auth = ("a", "b")
|
||||
app.auth = {"a": "b"}
|
||||
app.interface.auth_message = None
|
||||
app.launchable.auth_message = None
|
||||
else:
|
||||
app.auth = None
|
||||
|
||||
|
23
gradio/static.py
Normal file
23
gradio/static.py
Normal file
@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from gradio.component import Component
|
||||
|
||||
|
||||
class StaticComponent(Component):
|
||||
def __init__(self, label: str):
|
||||
self.component_type = "static"
|
||||
super().__init__(label)
|
||||
|
||||
def process(self, y):
|
||||
"""
|
||||
Any processing needed to be performed on the default value.
|
||||
"""
|
||||
return y
|
||||
|
||||
|
||||
class Markdown(StaticComponent):
|
||||
pass
|
||||
|
||||
|
||||
class Button(StaticComponent):
|
||||
pass
|
82
ui/packages/app/src/Blocks.svelte
Normal file
82
ui/packages/app/src/Blocks.svelte
Normal file
@ -0,0 +1,82 @@
|
||||
<script lang="ts">
|
||||
import Pane from "./page_layouts/Pane.svelte";
|
||||
import { _ } from "svelte-i18n";
|
||||
import { setupi18n } from "./i18n";
|
||||
setupi18n();
|
||||
|
||||
interface Component {
|
||||
name: string;
|
||||
id: string;
|
||||
props: Record<string, unknown>;
|
||||
}
|
||||
|
||||
interface Layout {
|
||||
name: string;
|
||||
type: string;
|
||||
children: Layout | number;
|
||||
}
|
||||
|
||||
interface Dependency {
|
||||
trigger: "click" | "change";
|
||||
targets: Array<string>;
|
||||
inputs: Array<string>;
|
||||
outputs: Array<string>;
|
||||
}
|
||||
|
||||
export let fn: (...args: any) => Promise<unknown>;
|
||||
export let components: Array<Component>;
|
||||
export let layout: Layout;
|
||||
export let dependencies: Array<Dependency>;
|
||||
export let theme: string;
|
||||
export let static_src: string;
|
||||
|
||||
let values: Record<string, unknown> = {};
|
||||
let component_id_map: Record<string, Component> = {};
|
||||
let event_listener_map: Record<string, Array<number>> = {};
|
||||
for (let component of components) {
|
||||
component_id_map[component.id] = component;
|
||||
if (component.props && "default" in component.props) {
|
||||
values[component.id] = component.props.default;
|
||||
} else {
|
||||
values[component.id] = null;
|
||||
}
|
||||
event_listener_map[component.id] = [];
|
||||
}
|
||||
dependencies.forEach((dependency, i) => {
|
||||
if (dependency.trigger === "click") {
|
||||
for (let target of dependency.targets) {
|
||||
event_listener_map[target].push(i);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
const setValues = (i: string, value: unknown) => {
|
||||
values[i] = value;
|
||||
};
|
||||
const triggerTarget = (i: string) => {
|
||||
event_listener_map[i].forEach((fn_index: number) => {
|
||||
let dependency = dependencies[fn_index];
|
||||
fn("predict", {
|
||||
fn_index: fn_index,
|
||||
data: dependency.inputs.map((i) => values[i])
|
||||
}).then((output) => {
|
||||
output["data"].forEach((value, i) => {
|
||||
values[dependency.outputs[i]] = value;
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
</script>
|
||||
|
||||
<div class="mx-auto container p-4">
|
||||
<Pane
|
||||
{component_id_map}
|
||||
children={layout.children}
|
||||
{dependencies}
|
||||
{values}
|
||||
{setValues}
|
||||
{triggerTarget}
|
||||
{theme}
|
||||
{static_src}
|
||||
/>
|
||||
</div>
|
@ -26,6 +26,9 @@ import OutputVideo from "./output/Video/config.js";
|
||||
import OutputTimeSeries from "./output/TimeSeries/config.js";
|
||||
import OutputChatbot from "./output/Chatbot/config.js";
|
||||
|
||||
import StaticButton from "./static/Button/config.js";
|
||||
import StaticMarkdown from "./static/Markdown/config.js";
|
||||
|
||||
export const input_component_map = {
|
||||
audio: InputAudio,
|
||||
checkbox: InputCheckbox,
|
||||
@ -57,3 +60,14 @@ export const output_component_map = {
|
||||
video: OutputVideo,
|
||||
chatbot: OutputChatbot
|
||||
};
|
||||
|
||||
export const static_component_map = {
|
||||
button: StaticButton,
|
||||
markdown: StaticMarkdown
|
||||
};
|
||||
|
||||
export const all_components_map = {
|
||||
input: input_component_map,
|
||||
output: output_component_map,
|
||||
static: static_component_map
|
||||
};
|
||||
|
@ -0,0 +1,7 @@
|
||||
<script lang="ts">
|
||||
export let label: string;
|
||||
</script>
|
||||
|
||||
<button class="px-4 py-2 rounded bg-gray-100 hover:bg-gray-200 transition">
|
||||
{label}
|
||||
</button>
|
5
ui/packages/app/src/components/static/Button/config.js
Normal file
5
ui/packages/app/src/components/static/Button/config.js
Normal file
@ -0,0 +1,5 @@
|
||||
import Component from "./Button.svelte";
|
||||
|
||||
export default {
|
||||
component: Component
|
||||
};
|
@ -0,0 +1,5 @@
|
||||
<script lang="ts">
|
||||
export let label: string;
|
||||
</script>
|
||||
|
||||
<div>{@html label}</div>
|
5
ui/packages/app/src/components/static/Markdown/config.js
Normal file
5
ui/packages/app/src/components/static/Markdown/config.js
Normal file
@ -0,0 +1,5 @@
|
||||
import Component from "./Markdown.svelte";
|
||||
|
||||
export default {
|
||||
component: Component
|
||||
};
|
@ -1,4 +1,5 @@
|
||||
import App from "./App.svelte";
|
||||
import Blocks from "./Blocks.svelte";
|
||||
import Login from "./Login.svelte";
|
||||
import { fn } from "./api";
|
||||
|
||||
@ -36,6 +37,7 @@ interface Config {
|
||||
output_components: Array<Component>;
|
||||
layout: string;
|
||||
live: boolean;
|
||||
mode: "blocks" | "interface" | undefined;
|
||||
queue: boolean;
|
||||
root: string;
|
||||
show_input: boolean;
|
||||
@ -82,7 +84,7 @@ window.launchGradio = (config: Config, element_query: string) => {
|
||||
});
|
||||
} else {
|
||||
let url = new URL(window.location.toString());
|
||||
if (config.theme !== null && config.theme.startsWith("dark")) {
|
||||
if (config.theme && config.theme.startsWith("dark")) {
|
||||
target.classList.add("dark");
|
||||
config.dark = true;
|
||||
if (config.theme === "dark") {
|
||||
@ -95,10 +97,17 @@ window.launchGradio = (config: Config, element_query: string) => {
|
||||
target.classList.add("dark");
|
||||
}
|
||||
config.fn = fn.bind(null, config.root + "api/");
|
||||
new App({
|
||||
target: target,
|
||||
props: config
|
||||
});
|
||||
if (config.mode === "blocks") {
|
||||
new Blocks({
|
||||
target: target,
|
||||
props: config
|
||||
});
|
||||
} else {
|
||||
new App({
|
||||
target: target,
|
||||
props: config
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
55
ui/packages/app/src/page_layouts/Pane.svelte
Normal file
55
ui/packages/app/src/page_layouts/Pane.svelte
Normal file
@ -0,0 +1,55 @@
|
||||
<script>
|
||||
import TabSet from "./TabSet.svelte";
|
||||
import { all_components_map } from "../components/directory";
|
||||
|
||||
export let component_id_map,
|
||||
children,
|
||||
dependencies,
|
||||
type,
|
||||
values,
|
||||
setValues,
|
||||
triggerTarget,
|
||||
theme,
|
||||
static_src;
|
||||
</script>
|
||||
|
||||
<div class="flex gap-4" class:flex-col={type !== "row"}>
|
||||
{#each children as child}
|
||||
{#if typeof child === "object"}
|
||||
{#if child.type === "tabset"}
|
||||
<TabSet
|
||||
{component_id_map}
|
||||
{...child}
|
||||
{values}
|
||||
{setValues}
|
||||
{triggerTarget}
|
||||
{static_src}
|
||||
{theme}
|
||||
/>
|
||||
{:else}
|
||||
<svelte:self
|
||||
{component_id_map}
|
||||
{...child}
|
||||
{values}
|
||||
{setValues}
|
||||
{triggerTarget}
|
||||
{static_src}
|
||||
{theme}
|
||||
/>
|
||||
{/if}
|
||||
{:else if !(component_id_map[child].type === "output" && values[child] === null)}
|
||||
<div class:flex-1={type === "row"} on:click={() => triggerTarget(child)}>
|
||||
<svelte:component
|
||||
this={all_components_map[component_id_map[child].type][
|
||||
component_id_map[child].props.name
|
||||
].component}
|
||||
value={values[child]}
|
||||
setValue={setValues.bind(this, child)}
|
||||
{...component_id_map[child].props}
|
||||
{static_src}
|
||||
{theme}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
54
ui/packages/app/src/page_layouts/TabSet.svelte
Normal file
54
ui/packages/app/src/page_layouts/TabSet.svelte
Normal file
@ -0,0 +1,54 @@
|
||||
<script>
|
||||
import Pane from "./Pane.svelte";
|
||||
export let component_id_map,
|
||||
children,
|
||||
type,
|
||||
values,
|
||||
setValues,
|
||||
triggerTarget,
|
||||
theme,
|
||||
static_src;
|
||||
|
||||
let selected_tab = 0;
|
||||
console.log("tabs", children);
|
||||
</script>
|
||||
|
||||
<div class="flex flex-col">
|
||||
<div class="flex">
|
||||
{#each children as child, i}
|
||||
{#if i === selected_tab}
|
||||
<button
|
||||
class="px-4 py-2 font-semibold border-2 border-b-0 rounded-t border-gray-200"
|
||||
>
|
||||
{child.name}
|
||||
</button>
|
||||
{:else}
|
||||
<button
|
||||
class="px-4 py-2 border-b-2 border-gray-200"
|
||||
on:click={() => {
|
||||
selected_tab = i;
|
||||
}}
|
||||
>
|
||||
{child.name}
|
||||
</button>
|
||||
{/if}
|
||||
{/each}
|
||||
<div class="flex-1 border-b-2 border-gray-200" />
|
||||
</div>
|
||||
{#each children as child, i}
|
||||
<div
|
||||
class="p-2 border-2 border-t-0 border-gray-200"
|
||||
class:hidden={i !== selected_tab}
|
||||
>
|
||||
<Pane
|
||||
{component_id_map}
|
||||
{...child}
|
||||
{values}
|
||||
{setValues}
|
||||
{triggerTarget}
|
||||
{static_src}
|
||||
{theme}
|
||||
/>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
Loading…
Reference in New Issue
Block a user