Gradio Blocks (#590)

Integrate initial work on Gradio Blocks by creating Blocks class and frontend Block code.
This commit is contained in:
aliabid94 2022-02-28 20:35:21 -08:00 committed by GitHub
parent faf55f2c9e
commit 11dda55352
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 806 additions and 331 deletions

View File

@ -19,7 +19,7 @@ iface = gr.Interface(
),
gr.inputs.Textbox(lines=3, default="The fast brown fox jumps over lazy dogs."),
],
gr.outputs.HighlightedText(color_map={'+': 'green', '-': 'pink'}),
gr.outputs.HighlightedText(color_map={"+": "green", "-": "pink"}),
)
if __name__ == "__main__":
iface.launch()

View File

@ -7,4 +7,4 @@ def greet(name):
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
if __name__ == "__main__":
iface.launch();
iface.launch()

View File

@ -12,4 +12,3 @@ iface = gr.Interface(
)
if __name__ == "__main__":
app, local_url, share_url = iface.launch()

View File

@ -14,5 +14,4 @@ iface = gr.Interface(
outputs=["text", "number"],
)
if __name__ == "__main__":
iface.launch();
iface.launch()

41
demo/xray_blocks/run.py Normal file
View File

@ -0,0 +1,41 @@
import gradio as gr
import random
xray_model = lambda diseases, img: {disease: random.random() for disease in diseases}
ct_model = lambda diseases, img: {disease: 0.1 for disease in diseases}
xray_blocks = gr.Blocks()
with xray_blocks:
gr.Markdown(
"""
# Detect Disease From Scan
With this model you can lorem ipsum
- ipsum 1
- ipsum 2
"""
)
disease = gr.inputs.CheckboxGroup(
["Covid", "Malaria", "Lung Cancer"], label="Disease to Scan For"
)
with gr.Tab("X-ray"):
with gr.Row():
xray_scan = gr.inputs.Image()
xray_results = gr.outputs.JSON()
xray_run = gr.Button("Run")
xray_run.click(xray_model, inputs=[disease, xray_scan], outputs=xray_results)
with gr.Tab("CT Scan"):
with gr.Row():
ct_scan = gr.inputs.Image()
ct_results = gr.outputs.JSON()
ct_run = gr.Button("Run")
ct_run.click(ct_model, inputs=[disease, ct_scan], outputs=ct_results)
overall_probability = gr.outputs.Textbox()
# TODO: remove later
print(xray_blocks.get_config_file())
xray_blocks.launch()

View File

@ -1,5 +1,6 @@
import pkg_resources
from gradio.blocks import Blocks, Column, Row, Tab
from gradio.flagging import (
CSVLogger,
FlaggingCallback,
@ -9,6 +10,7 @@ from gradio.flagging import (
from gradio.interface import Interface, close_all, reset_all
from gradio.mix import Parallel, Series
from gradio.routes import get_state, set_state
from gradio.static import Button, Markdown
current_pkg_version = pkg_resources.require("gradio")[0].version
__version__ = current_pkg_version

150
gradio/blocks.py Normal file
View File

@ -0,0 +1,150 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from gradio import utils
from gradio.context import Context
from gradio.launchable import Launchable
class Block:
def __init__(self):
self._id = Context.id
Context.id += 1
if Context.block is not None:
Context.block.children.append(self)
if Context.root_block is not None:
Context.root_block.blocks[self._id] = self
self.events = []
def click(self, fn, inputs, outputs):
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):
outputs = [outputs]
Context.root_block.fns.append(fn)
Context.root_block.dependencies.append(
{
"targets": [self._id],
"trigger": "click",
"inputs": [block._id for block in inputs],
"outputs": [block._id for block in outputs],
}
)
class BlockContext(Block):
def __init__(self):
self.children = []
super().__init__()
def __enter__(self):
self.parent = Context.block
Context.block = self
def __exit__(self, *args):
Context.block = self.parent
class Row(BlockContext):
def get_template_context(self):
return {"type": "row"}
class Column(BlockContext):
def get_template_context(self):
return {"type": "column"}
class Tab(BlockContext):
def __init__(self, name):
self.name = name
super(Tab, self).__init__()
def get_template_context(self):
return {"type": "tab", "name": self.name}
class Blocks(Launchable, BlockContext):
def __init__(self, theme="default"):
# Cleanup shared parameters with Interface
self.save_to = None
self.ip_address = utils.get_local_ip_address()
self.api_mode = False
self.analytics_enabled = True
self.theme = theme
self.requires_permissions = False # TODO: needs to be implemented
self.enable_queue = False
super().__init__()
Context.root_block = self
self.blocks = {}
self.fns = []
self.dependencies = []
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
raw_input = data["data"]
fn_index = data["fn_index"]
fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]
processed_input = [
self.blocks[input_id].preprocess(raw_input[i])
for i, input_id in enumerate(dependency["inputs"])
]
predictions = fn(*processed_input)
if len(dependency["outputs"]) == 1:
predictions = (predictions,)
processed_output = [
self.blocks[output_id].postprocess(predictions[i])
if predictions[i] is not None
else None
for i, output_id in enumerate(dependency["outputs"])
]
return {"data": processed_output}
def get_template_context(self):
return {"type": "column"}
def get_config_file(self):
from gradio.component import Component
config = {"mode": "blocks", "components": [], "theme": self.theme}
for _id, block in self.blocks.items():
if isinstance(block, Component):
config["components"].append(
{
"id": _id,
"type": block.component_type,
"props": block.get_template_context(),
}
)
def getLayout(block_context):
if not isinstance(block_context, BlockContext):
return block_context._id
children = []
running_tabs = []
for child in block_context.children:
if isinstance(child, Tab):
running_tabs.append(getLayout(child))
continue
if len(running_tabs):
children.append({"type": "tabset", "children": running_tabs})
running_tabs = []
children.append(getLayout(child))
if len(running_tabs):
children.append({"type": "tabset", "children": running_tabs})
running_tabs = []
return {"children": children, **block_context.get_template_context()}
config["layout"] = getLayout(self)
config["dependencies"] = self.dependencies
return config
def __enter__(self):
BlockContext.__enter__(self)
Context.root_block = self
def __exit__(self, *args):
BlockContext.__exit__(self, *args)
Context.root_block = self.parent

View File

@ -3,9 +3,10 @@ import shutil
from typing import Any, Dict, Optional
from gradio import processing_utils
from gradio.blocks import Block
class Component:
class Component(Block):
"""
A class for defining the methods that all gradio input and output components should have.
"""
@ -13,6 +14,7 @@ class Component:
def __init__(self, label, requires_permissions=False):
self.label = label
self.requires_permissions = requires_permissions
super().__init__()
def __str__(self):
return self.__repr__()

4
gradio/context.py Normal file
View File

@ -0,0 +1,4 @@
class Context:
root_block = None
block = None
id = 0

View File

@ -36,6 +36,7 @@ class InputComponent(Component):
"""
Constructs an input component.
"""
self.component_type = "input"
self.set_interpret_parameters()
self.optional = optional
super().__init__(label, requires_permissions)

View File

@ -6,38 +6,36 @@ including various methods for constructing an interface and then launching it.
from __future__ import annotations
import copy
import getpass
import os
import random
import re
import sys
import time
import warnings
import weakref
import webbrowser
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from markdown_it import MarkdownIt
from mdit_py_plugins.footnote import footnote_plugin
from gradio import networking # type: ignore
from gradio import encryptor, interpretation, queueing, strings, utils
from gradio import interpretation, utils
from gradio.external import load_from_pipeline, load_interface # type: ignore
from gradio.flagging import CSVLogger, FlaggingCallback # type: ignore
from gradio.inputs import InputComponent
from gradio.inputs import State as i_State # type: ignore
from gradio.inputs import get_input_instance
from gradio.launchable import Launchable
from gradio.outputs import OutputComponent
from gradio.outputs import State as o_State # type: ignore
from gradio.outputs import get_output_instance
from gradio.process_examples import cache_interface_examples
from gradio.process_examples import load_from_cache, process_example
from gradio.routes import predict
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import flask
import transformers
class Interface:
class Interface(Launchable):
"""
Gradio interfaces are created by constructing a `Interface` object
with a locally-defined function, or with `Interface.load()` with the path
@ -546,6 +544,34 @@ class Interface:
else:
return predictions
def process_api(self, data: Dict[str, Any], username: str = None) -> Dict[str, Any]:
flag_index = None
if data.get("example_id") is not None:
example_id = data["example_id"]
if self.cache_examples:
prediction = load_from_cache(self, example_id)
durations = None
else:
prediction, durations = process_example(self, example_id)
else:
raw_input = data["data"]
prediction, durations = self.process(raw_input)
if self.allow_flagging == "auto":
flag_index = self.flagging_callback.flag(
self,
raw_input,
prediction,
flag_option="" if self.flagging_options else None,
username=username,
)
return {
"data": prediction,
"durations": durations,
"avg_durations": self.config.get("avg_durations"),
"flag_index": flag_index,
}
def process(self, raw_input: List[Any]) -> Tuple[List[Any], List[float]]:
"""
First preprocesses the input, then runs prediction using
@ -585,19 +611,6 @@ class Interface:
def interpret(self, raw_input: List[Any]) -> List[Any]:
return interpretation.run_interpret(self, raw_input)
def block_thread(
self,
) -> None:
"""Block main thread until interrupted by user."""
try:
while True:
time.sleep(0.1)
except (KeyboardInterrupt, OSError):
print("Keyboard interruption in main thread... closing server.")
self.server.close()
if self.enable_queue:
queueing.close()
def test_launch(self) -> None:
for predict_fn in self.predict:
print("Test launch: {}()...".format(predict_fn.__name__), end=" ")
@ -613,219 +626,10 @@ class Interface:
print("PASSED")
continue
def launch(
self,
inline: bool = None,
inbrowser: bool = None,
share: bool = False,
debug: bool = False,
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
auth_message: Optional[str] = None,
private_endpoint: Optional[str] = None,
prevent_thread_lock: bool = False,
show_error: bool = True,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
show_tips: bool = False,
enable_queue: bool = False,
height: int = 500,
width: int = 900,
encrypt: bool = False,
cache_examples: bool = False,
favicon_path: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_keyfile_password: Optional[str] = None,
) -> Tuple[flask.Flask, str, str]:
"""
Launches the webserver that serves the UI for the interface.
Parameters:
inline (bool): whether to display in the interface inline on python notebooks.
inbrowser (bool): whether to automatically launch the interface in a new tab on the default browser.
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
auth (Callable, Union[Tuple[str, str], List[Tuple[str, str]]]): If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
auth_message (str): If provided, HTML message provided on login page.
private_endpoint (str): If provided, the public URL of the interface will be this endpoint (should generally be unchanged).
prevent_thread_lock (bool): If True, the interface will block the main thread while the server is running.
show_error (bool): If True, any errors in the interface will be printed in the browser console log
server_port (int): will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
server_name (str): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
cache_examples (bool): If True, examples outputs will be processed and cached in a folder, and will be used if a user uses an example input.
favicon_path (str): If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
ssl_keyfile (str): If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile (str): If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password (str): If a password is provided, will use this with the ssl certificate for https.
Returns:
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
self.config = self.get_config_file()
self.cache_examples = cache_examples
if (
auth
and not callable(auth)
and not isinstance(auth[0], tuple)
and not isinstance(auth[0], list)
):
auth = [auth]
self.auth = auth
self.auth_message = auth_message
self.show_tips = show_tips
self.show_error = show_error
self.height = self.height or height
self.width = self.width or width
self.favicon_path = favicon_path
if self.encrypt is None:
self.encrypt = encrypt
if self.encrypt:
self.encryption_key = encryptor.get_key(
getpass.getpass("Enter key for encryption: ")
)
if self.enable_queue is None:
self.enable_queue = enable_queue
def launch(self, **args):
if self.allow_flagging != "never":
self.flagging_callback.setup(self.flagging_dir)
config = self.get_config_file()
self.config = config
if self.cache_examples:
cache_interface_examples(self)
server_port, path_to_local_server, app, server = networking.start_server(
self,
server_name,
server_port,
ssl_keyfile,
ssl_certfile,
ssl_keyfile_password,
)
self.local_url = path_to_local_server
self.server_port = server_port
self.status = "RUNNING"
self.server_app = app
self.server = server
utils.launch_counter()
# If running in a colab or not able to access localhost,
# automatically create a shareable link.
is_colab = utils.colab_check()
if is_colab or not (networking.url_ok(path_to_local_server)):
share = True
if is_colab:
if debug:
print(strings.en["COLAB_DEBUG_TRUE"])
else:
print(strings.en["COLAB_DEBUG_FALSE"])
else:
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
if is_colab and self.requires_permissions:
print(strings.en["MEDIA_PERMISSIONS_IN_COLAB"])
if private_endpoint is not None:
share = True
if share:
if self.is_space:
raise RuntimeError("Share is not supported when you are in Spaces")
try:
share_url = networking.setup_tunnel(server_port, private_endpoint)
self.share_url = share_url
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
if private_endpoint:
print(strings.en["PRIVATE_LINK_MESSAGE"])
else:
print(strings.en["SHARE_LINK_MESSAGE"])
except RuntimeError:
if self.analytics_enabled:
utils.error_analytics(self.ip_address, "Not able to set up tunnel")
share_url = None
share = False
print(strings.en["COULD_NOT_GET_SHARE_LINK"])
else:
print(strings.en["PUBLIC_SHARE_TRUE"])
share_url = None
self.share = share
if inbrowser:
link = share_url if share else path_to_local_server
webbrowser.open(link)
# Check if running in a Python notebook in which case, display inline
if inline is None:
inline = utils.ipython_check() and (auth is None)
if inline:
if auth is not None:
print(
"Warning: authentication is not supported inline. Please"
"click the link to access the interface in a new tab."
)
try:
from IPython.display import IFrame, display # type: ignore
if share:
while not networking.url_ok(share_url):
time.sleep(1)
display(IFrame(share_url, width=self.width, height=self.height))
else:
display(
IFrame(
path_to_local_server, width=self.width, height=self.height
)
)
except ImportError:
pass
data = {
"launch_method": "browser" if inbrowser else "inline",
"is_google_colab": is_colab,
"is_sharing_on": share,
"share_url": share_url,
"ip_address": self.ip_address,
"enable_queue": self.enable_queue,
"show_tips": self.show_tips,
"api_mode": self.api_mode,
"server_name": server_name,
"server_port": server_port,
"is_spaces": self.is_space,
}
if self.analytics_enabled:
utils.launch_analytics(data)
utils.show_tip(self)
# Block main thread if debug==True
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
self.block_thread()
# Block main thread if running in a script to stop script from exiting
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
if not prevent_thread_lock and not is_in_interactive_mode:
self.block_thread()
return app, path_to_local_server, share_url
def close(self, verbose: bool = True) -> None:
"""
Closes the Interface that was launched and frees the port.
"""
try:
self.server.close()
if verbose:
print("Closing server running on port: {}".format(self.server_port))
except (AttributeError, OSError): # can't close if not running
pass
return super().launch(**args)
def integrate(self, comet_ml=None, wandb=None, mlflow=None) -> None:
"""

245
gradio/launchable.py Normal file
View File

@ -0,0 +1,245 @@
from __future__ import annotations
import getpass
import os
import sys
import time
import webbrowser
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from gradio import encryptor, networking, queueing, strings, utils # type: ignore
from gradio.process_examples import cache_interface_examples
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
import flask
class Launchable:
"""
Gradio launchables can be launched to serve content to a port.
"""
def launch(
self,
inline: bool = None,
inbrowser: bool = None,
share: bool = False,
debug: bool = False,
auth: Optional[Callable | Tuple[str, str] | List[Tuple[str, str]]] = None,
auth_message: Optional[str] = None,
private_endpoint: Optional[str] = None,
prevent_thread_lock: bool = False,
show_error: bool = True,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
show_tips: bool = False,
enable_queue: bool = False,
height: int = 500,
width: int = 900,
encrypt: bool = False,
cache_examples: bool = False,
favicon_path: Optional[str] = None,
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
ssl_keyfile_password: Optional[str] = None,
) -> Tuple[flask.Flask, str, str]:
"""
Launches the webserver that serves the UI for the interface.
Parameters:
inline (bool): whether to display in the interface inline on python notebooks.
inbrowser (bool): whether to automatically launch the interface in a new tab on the default browser.
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
auth (Callable, Union[Tuple[str, str], List[Tuple[str, str]]]): If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
auth_message (str): If provided, HTML message provided on login page.
private_endpoint (str): If provided, the public URL of the interface will be this endpoint (should generally be unchanged).
prevent_thread_lock (bool): If True, the interface will block the main thread while the server is running.
show_error (bool): If True, any errors in the interface will be printed in the browser console log
server_port (int): will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
server_name (str): to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
width (int): The width in pixels of the iframe element containing the interface (used if inline=True)
height (int): The height in pixels of the iframe element containing the interface (used if inline=True)
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
cache_examples (bool): If True, examples outputs will be processed and cached in a folder, and will be used if a user uses an example input.
favicon_path (str): If a path to a file (.png, .gif, or .ico) is provided, it will be used as the favicon for the web page.
ssl_keyfile (str): If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile (str): If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password (str): If a password is provided, will use this with the ssl certificate for https.
Returns:
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
self.config = self.get_config_file()
self.cache_examples = cache_examples
if (
auth
and not callable(auth)
and not isinstance(auth[0], tuple)
and not isinstance(auth[0], list)
):
auth = [auth]
self.auth = auth
self.auth_message = auth_message
self.show_tips = show_tips
self.show_error = show_error
self.height = height
self.width = width
self.favicon_path = favicon_path
if hasattr(self, "encrypt") and self.encrypt is None:
self.encrypt = encrypt
if hasattr(self, "encrypt") and self.encrypt:
self.encryption_key = encryptor.get_key(
getpass.getpass("Enter key for encryption: ")
)
if hasattr(self, "enable_queue") and self.enable_queue is None:
self.enable_queue = enable_queue
config = self.get_config_file()
self.config = config
if self.cache_examples:
cache_interface_examples(self)
server_port, path_to_local_server, app, server = networking.start_server(
self,
server_name,
server_port,
ssl_keyfile,
ssl_certfile,
ssl_keyfile_password,
)
self.local_url = path_to_local_server
self.server_port = server_port
self.status = "RUNNING"
self.server_app = app
self.server = server
utils.launch_counter()
# If running in a colab or not able to access localhost,
# automatically create a shareable link.
is_colab = utils.colab_check()
if is_colab or not (networking.url_ok(path_to_local_server)):
share = True
if is_colab:
if debug:
print(strings.en["COLAB_DEBUG_TRUE"])
else:
print(strings.en["COLAB_DEBUG_FALSE"])
else:
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
if is_colab and self.requires_permissions:
print(strings.en["MEDIA_PERMISSIONS_IN_COLAB"])
if private_endpoint is not None:
share = True
if share:
if self.is_space:
raise RuntimeError("Share is not supported when you are in Spaces")
try:
share_url = networking.setup_tunnel(server_port, private_endpoint)
self.share_url = share_url
print(strings.en["SHARE_LINK_DISPLAY"].format(share_url))
if private_endpoint:
print(strings.en["PRIVATE_LINK_MESSAGE"])
else:
print(strings.en["SHARE_LINK_MESSAGE"])
except RuntimeError:
if self.analytics_enabled:
utils.error_analytics(self.ip_address, "Not able to set up tunnel")
share_url = None
share = False
print(strings.en["COULD_NOT_GET_SHARE_LINK"])
else:
print(strings.en["PUBLIC_SHARE_TRUE"])
share_url = None
self.share = share
if inbrowser:
link = share_url if share else path_to_local_server
webbrowser.open(link)
# Check if running in a Python notebook in which case, display inline
if inline is None:
inline = utils.ipython_check() and (auth is None)
if inline:
if auth is not None:
print(
"Warning: authentication is not supported inline. Please"
"click the link to access the interface in a new tab."
)
try:
from IPython.display import IFrame, display # type: ignore
if share:
while not networking.url_ok(share_url):
time.sleep(1)
display(IFrame(share_url, width=self.width, height=self.height))
else:
display(
IFrame(
path_to_local_server, width=self.width, height=self.height
)
)
except ImportError:
pass
data = {
"launch_method": "browser" if inbrowser else "inline",
"is_google_colab": is_colab,
"is_sharing_on": share,
"share_url": share_url,
"ip_address": self.ip_address,
"enable_queue": self.enable_queue,
"show_tips": self.show_tips,
"api_mode": self.api_mode,
"server_name": server_name,
"server_port": server_port,
"is_spaces": self.is_space,
}
if self.analytics_enabled:
utils.launch_analytics(data)
utils.show_tip(self)
# Block main thread if debug==True
if debug or int(os.getenv("GRADIO_DEBUG", 0)) == 1:
self.block_thread()
# Block main thread if running in a script to stop script from exiting
is_in_interactive_mode = bool(getattr(sys, "ps1", sys.flags.interactive))
if not prevent_thread_lock and not is_in_interactive_mode:
self.block_thread()
return app, path_to_local_server, share_url
def close(self, verbose: bool = True) -> None:
"""
Closes the Interface that was launched and frees the port.
"""
try:
self.server.close()
if verbose:
print("Closing server running on port: {}".format(self.server_port))
except (AttributeError, OSError): # can't close if not running
pass
def block_thread(
self,
) -> None:
"""Block main thread until interrupted by user."""
try:
while True:
time.sleep(0.1)
except (KeyboardInterrupt, OSError):
print("Keyboard interruption in main thread... closing server.")
self.server.close()
if self.enable_queue:
queueing.close()

View File

@ -20,7 +20,7 @@ from gradio.routes import app
from gradio.tunneling import create_tunnel
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio import Interface
from gradio.launchable import Launchable
# By default, the local server will try to open on localhost, port 7860.
@ -71,7 +71,7 @@ def get_first_available_port(initial: int, final: int) -> int:
def start_server(
interface: Interface,
launchable: Launchable,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
ssl_keyfile: Optional[str] = None,
@ -80,13 +80,13 @@ def start_server(
) -> Tuple[int, str, fastapi.FastAPI, Server]:
"""Launches a local server running the provided Interface
Parameters:
interface: The interface object to run on the server
launchable: The launchable object to run on the server
server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
auth: If provided, username and password (or list of username-password tuples) required to access interface. Can also provide function that takes username and password and returns True if valid login.
auth: If provided, username and password (or list of username-password tuples) required to access launchable. Can also provide function that takes username and password and returns True if valid login.
ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password (str): If a password is provided, will use this with the ssl certificate for https.
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
Returns:
port: the port number the server is running on
path_to_local_server: the complete address that the local server can be accessed at
@ -106,7 +106,7 @@ def start_server(
s.close()
except OSError:
raise OSError(
"Port {} is in use. If a gradio.Interface is running on the port, you can close() it or gradio.close_all().".format(
"Port {} is in use. If a gradio.Launchable is running on the port, you can close() it or gradio.close_all().".format(
server_port
)
)
@ -123,7 +123,7 @@ def start_server(
else:
path_to_local_server = "http://{}:{}/".format(url_host_name, port)
auth = interface.auth
auth = launchable.auth
if auth is not None:
if not callable(auth):
app.auth = {account[0]: account[1] for account in auth}
@ -131,21 +131,21 @@ def start_server(
app.auth = auth
else:
app.auth = None
app.interface = interface
app.launchable = launchable
app.cwd = os.getcwd()
app.favicon_path = interface.favicon_path
app.favicon_path = launchable.favicon_path
app.tokens = {}
if app.interface.enable_queue:
if auth is not None or app.interface.encrypt:
if app.launchable.enable_queue:
if auth is not None or app.launchable.encrypt:
raise ValueError("Cannot queue with encryption or authentication enabled.")
queueing.init()
app.queue_thread = threading.Thread(
target=queueing.queue_thread, args=(path_to_local_server,)
)
app.queue_thread.start()
if interface.save_to is not None: # Used for selenium tests
interface.save_to["port"] = port
if launchable.save_to is not None: # Used for selenium tests
launchable.save_to["port"] = port
config = uvicorn.Config(
app=app,

View File

@ -32,6 +32,10 @@ class OutputComponent(Component):
Output Component. All output components subclass this.
"""
def __init__(self, label: str):
self.component_type = "output"
super().__init__(label)
def postprocess(self, y):
"""
Any postprocessing needed to be performed on function output.

View File

@ -24,7 +24,6 @@ from jinja2.exceptions import TemplateNotFound
from starlette.responses import RedirectResponse
from gradio import encryptor, queueing, utils
from gradio.process_examples import load_from_cache, process_example
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "templates/frontend/static")
@ -114,9 +113,9 @@ def login(form_data: OAuth2PasswordRequestForm = Depends()):
@app.get("/", response_class=HTMLResponse)
def main(request: Request, user: str = Depends(get_current_user)):
if app.auth is None or not (user is None):
config = app.interface.config
config = app.launchable.config
else:
config = {"auth_required": True, "auth_message": app.interface.auth_message}
config = {"auth_required": True, "auth_message": app.launchable.auth_message}
try:
return templates.TemplateResponse(
@ -132,12 +131,12 @@ def main(request: Request, user: str = Depends(get_current_user)):
@app.get("/config/", dependencies=[Depends(login_check)])
@app.get("/config", dependencies=[Depends(login_check)])
def get_config():
return app.interface.config
return app.launchable.config
@app.get("/static/{path:path}")
def static_resource(path: str):
if app.interface.share:
if app.launchable.share:
return RedirectResponse(GRADIO_STATIC_ROOT + path)
else:
static_file = safe_join(STATIC_PATH_LIB, path)
@ -148,7 +147,7 @@ def static_resource(path: str):
@app.get("/assets/{path:path}")
def build_resource(path: str):
if app.interface.share:
if app.launchable.share:
return RedirectResponse(GRADIO_BUILD_ROOT + path)
else:
build_file = safe_join(BUILD_PATH_LIB, path)
@ -160,13 +159,13 @@ def build_resource(path: str):
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
def file(path):
if (
app.interface.encrypt
and isinstance(app.interface.examples, str)
and path.startswith(app.interface.examples)
app.launchable.encrypt
and isinstance(app.launchable.examples, str)
and path.startswith(app.launchable.examples)
):
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
encrypted_data = encrypted_file.read()
file_data = encryptor.decrypt(app.interface.encryption_key, encrypted_data)
file_data = encryptor.decrypt(app.launchable.encryption_key, encrypted_data)
return FileResponse(
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
)
@ -177,17 +176,17 @@ def file(path):
@app.get("/api", response_class=HTMLResponse) # Needed for Spaces
@app.get("/api/", response_class=HTMLResponse)
def api_docs(request: Request):
inputs = [type(inp) for inp in app.interface.input_components]
outputs = [type(out) for out in app.interface.output_components]
inputs = [type(inp) for inp in app.launchable.input_components]
outputs = [type(out) for out in app.launchable.output_components]
input_types_doc, input_types = get_types(inputs, "input")
output_types_doc, output_types = get_types(outputs, "output")
input_names = [type(inp).__name__ for inp in app.interface.input_components]
output_names = [type(out).__name__ for out in app.interface.output_components]
if app.interface.examples is not None:
sample_inputs = app.interface.examples[0]
input_names = [type(inp).__name__ for inp in app.launchable.input_components]
output_names = [type(out).__name__ for out in app.launchable.output_components]
if app.launchable.examples is not None:
sample_inputs = app.launchable.examples[0]
else:
sample_inputs = [
inp.generate_sample() for inp in app.interface.input_components
inp.generate_sample() for inp in app.launchable.input_components
]
docs = {
"inputs": input_names,
@ -201,9 +200,9 @@ def api_docs(request: Request):
"input_types_doc": input_types_doc,
"output_types_doc": output_types_doc,
"sample_inputs": sample_inputs,
"auth": app.interface.auth,
"local_login_url": urllib.parse.urljoin(app.interface.local_url, "login"),
"local_api_url": urllib.parse.urljoin(app.interface.local_url, "api/predict"),
"auth": app.launchable.auth,
"local_login_url": urllib.parse.urljoin(app.launchable.local_url, "login"),
"local_api_url": urllib.parse.urljoin(app.launchable.local_url, "api/predict"),
}
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})
@ -211,60 +210,26 @@ def api_docs(request: Request):
@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
flag_index = None
if body.get("example_id") is not None:
example_id = body["example_id"]
if app.interface.cache_examples:
prediction = await run_in_threadpool(
load_from_cache, app.interface, example_id
)
durations = None
try:
output = await run_in_threadpool(app.launchable.process_api, body, username)
except BaseException as error:
if app.launchable.show_error:
traceback.print_exc()
return JSONResponse(content={"error": str(error)}, status_code=500)
else:
prediction, durations = await run_in_threadpool(
process_example, app.interface, example_id
)
else:
raw_input = body["data"]
if app.interface.show_error:
try:
prediction, durations = await run_in_threadpool(
app.interface.process, raw_input
)
except BaseException as error:
traceback.print_exc()
return JSONResponse(content={"error": str(error)}, status_code=500)
else:
prediction, durations = await run_in_threadpool(
app.interface.process, raw_input
)
if app.interface.allow_flagging == "auto":
flag_index = await run_in_threadpool(
app.interface.flagging_callback.flag,
app.interface,
raw_input,
prediction,
flag_option="" if app.interface.flagging_options else None,
username=username,
)
output = {
"data": prediction,
"durations": durations,
"avg_durations": app.interface.config.get("avg_durations"),
"flag_index": flag_index,
}
raise error
return output
@app.post("/api/flag/", dependencies=[Depends(login_check)])
async def flag(request: Request, username: str = Depends(get_current_user)):
if app.interface.analytics_enabled:
await utils.log_feature_analytics(app.interface.ip_address, "flag")
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "flag")
body = await request.json()
data = body["data"]
await run_in_threadpool(
app.interface.flagging_callback.flag,
app.interface,
app.launchable.flagging_callback.flag,
app.launchable,
data["input_data"],
data["output_data"],
flag_option=data.get("flag_option"),
@ -276,12 +241,12 @@ async def flag(request: Request, username: str = Depends(get_current_user)):
@app.post("/api/interpret/", dependencies=[Depends(login_check)])
async def interpret(request: Request):
if app.interface.analytics_enabled:
await utils.log_feature_analytics(app.interface.ip_address, "interpret")
if app.launchable.analytics_enabled:
await utils.log_feature_analytics(app.launchable.ip_address, "interpret")
body = await request.json()
raw_input = body["data"]
interpretation_scores, alternative_outputs = await run_in_threadpool(
app.interface.interpret, raw_input
app.launchable.interpret, raw_input
)
return {
"interpretation_scores": interpretation_scores,
@ -368,20 +333,20 @@ def set_state(*args):
if __name__ == "__main__": # Run directly for debugging: python app.py
from gradio import Interface
app.interface = Interface(
app.launchable = Interface(
lambda x: "Hello, " + x, "text", "text", analytics_enabled=False
)
app.interface.favicon_path = None
app.interface.config = app.interface.get_config_file()
app.interface.show_error = True
app.interface.flagging_callback.setup(app.interface.flagging_dir)
app.launchable.favicon_path = None
app.launchable.config = app.launchable.get_config_file()
app.launchable.show_error = True
app.launchable.flagging_callback.setup(app.launchable.flagging_dir)
app.tokens = {}
auth = True
if auth:
app.interface.auth = ("a", "b")
app.launchable.auth = ("a", "b")
app.auth = {"a": "b"}
app.interface.auth_message = None
app.launchable.auth_message = None
else:
app.auth = None

23
gradio/static.py Normal file
View File

@ -0,0 +1,23 @@
from __future__ import annotations
from gradio.component import Component
class StaticComponent(Component):
def __init__(self, label: str):
self.component_type = "static"
super().__init__(label)
def process(self, y):
"""
Any processing needed to be performed on the default value.
"""
return y
class Markdown(StaticComponent):
pass
class Button(StaticComponent):
pass

View File

@ -0,0 +1,82 @@
<script lang="ts">
import Pane from "./page_layouts/Pane.svelte";
import { _ } from "svelte-i18n";
import { setupi18n } from "./i18n";
setupi18n();
interface Component {
name: string;
id: string;
props: Record<string, unknown>;
}
interface Layout {
name: string;
type: string;
children: Layout | number;
}
interface Dependency {
trigger: "click" | "change";
targets: Array<string>;
inputs: Array<string>;
outputs: Array<string>;
}
export let fn: (...args: any) => Promise<unknown>;
export let components: Array<Component>;
export let layout: Layout;
export let dependencies: Array<Dependency>;
export let theme: string;
export let static_src: string;
let values: Record<string, unknown> = {};
let component_id_map: Record<string, Component> = {};
let event_listener_map: Record<string, Array<number>> = {};
for (let component of components) {
component_id_map[component.id] = component;
if (component.props && "default" in component.props) {
values[component.id] = component.props.default;
} else {
values[component.id] = null;
}
event_listener_map[component.id] = [];
}
dependencies.forEach((dependency, i) => {
if (dependency.trigger === "click") {
for (let target of dependency.targets) {
event_listener_map[target].push(i);
}
}
});
const setValues = (i: string, value: unknown) => {
values[i] = value;
};
const triggerTarget = (i: string) => {
event_listener_map[i].forEach((fn_index: number) => {
let dependency = dependencies[fn_index];
fn("predict", {
fn_index: fn_index,
data: dependency.inputs.map((i) => values[i])
}).then((output) => {
output["data"].forEach((value, i) => {
values[dependency.outputs[i]] = value;
});
});
});
};
</script>
<div class="mx-auto container p-4">
<Pane
{component_id_map}
children={layout.children}
{dependencies}
{values}
{setValues}
{triggerTarget}
{theme}
{static_src}
/>
</div>

View File

@ -26,6 +26,9 @@ import OutputVideo from "./output/Video/config.js";
import OutputTimeSeries from "./output/TimeSeries/config.js";
import OutputChatbot from "./output/Chatbot/config.js";
import StaticButton from "./static/Button/config.js";
import StaticMarkdown from "./static/Markdown/config.js";
export const input_component_map = {
audio: InputAudio,
checkbox: InputCheckbox,
@ -57,3 +60,14 @@ export const output_component_map = {
video: OutputVideo,
chatbot: OutputChatbot
};
export const static_component_map = {
button: StaticButton,
markdown: StaticMarkdown
};
export const all_components_map = {
input: input_component_map,
output: output_component_map,
static: static_component_map
};

View File

@ -0,0 +1,7 @@
<script lang="ts">
export let label: string;
</script>
<button class="px-4 py-2 rounded bg-gray-100 hover:bg-gray-200 transition">
{label}
</button>

View File

@ -0,0 +1,5 @@
import Component from "./Button.svelte";
export default {
component: Component
};

View File

@ -0,0 +1,5 @@
<script lang="ts">
export let label: string;
</script>
<div>{@html label}</div>

View File

@ -0,0 +1,5 @@
import Component from "./Markdown.svelte";
export default {
component: Component
};

View File

@ -1,4 +1,5 @@
import App from "./App.svelte";
import Blocks from "./Blocks.svelte";
import Login from "./Login.svelte";
import { fn } from "./api";
@ -36,6 +37,7 @@ interface Config {
output_components: Array<Component>;
layout: string;
live: boolean;
mode: "blocks" | "interface" | undefined;
queue: boolean;
root: string;
show_input: boolean;
@ -82,7 +84,7 @@ window.launchGradio = (config: Config, element_query: string) => {
});
} else {
let url = new URL(window.location.toString());
if (config.theme !== null && config.theme.startsWith("dark")) {
if (config.theme && config.theme.startsWith("dark")) {
target.classList.add("dark");
config.dark = true;
if (config.theme === "dark") {
@ -95,10 +97,17 @@ window.launchGradio = (config: Config, element_query: string) => {
target.classList.add("dark");
}
config.fn = fn.bind(null, config.root + "api/");
new App({
target: target,
props: config
});
if (config.mode === "blocks") {
new Blocks({
target: target,
props: config
});
} else {
new App({
target: target,
props: config
});
}
}
};

View File

@ -0,0 +1,55 @@
<script>
import TabSet from "./TabSet.svelte";
import { all_components_map } from "../components/directory";
export let component_id_map,
children,
dependencies,
type,
values,
setValues,
triggerTarget,
theme,
static_src;
</script>
<div class="flex gap-4" class:flex-col={type !== "row"}>
{#each children as child}
{#if typeof child === "object"}
{#if child.type === "tabset"}
<TabSet
{component_id_map}
{...child}
{values}
{setValues}
{triggerTarget}
{static_src}
{theme}
/>
{:else}
<svelte:self
{component_id_map}
{...child}
{values}
{setValues}
{triggerTarget}
{static_src}
{theme}
/>
{/if}
{:else if !(component_id_map[child].type === "output" && values[child] === null)}
<div class:flex-1={type === "row"} on:click={() => triggerTarget(child)}>
<svelte:component
this={all_components_map[component_id_map[child].type][
component_id_map[child].props.name
].component}
value={values[child]}
setValue={setValues.bind(this, child)}
{...component_id_map[child].props}
{static_src}
{theme}
/>
</div>
{/if}
{/each}
</div>

View File

@ -0,0 +1,54 @@
<script>
import Pane from "./Pane.svelte";
export let component_id_map,
children,
type,
values,
setValues,
triggerTarget,
theme,
static_src;
let selected_tab = 0;
console.log("tabs", children);
</script>
<div class="flex flex-col">
<div class="flex">
{#each children as child, i}
{#if i === selected_tab}
<button
class="px-4 py-2 font-semibold border-2 border-b-0 rounded-t border-gray-200"
>
{child.name}
</button>
{:else}
<button
class="px-4 py-2 border-b-2 border-gray-200"
on:click={() => {
selected_tab = i;
}}
>
{child.name}
</button>
{/if}
{/each}
<div class="flex-1 border-b-2 border-gray-200" />
</div>
{#each children as child, i}
<div
class="p-2 border-2 border-t-0 border-gray-200"
class:hidden={i !== selected_tab}
>
<Pane
{component_id_map}
{...child}
{values}
{setValues}
{triggerTarget}
{static_src}
{theme}
/>
</div>
{/each}
</div>