cleaned up interface.py file

This commit is contained in:
Abubakar Abid 2021-12-16 08:24:24 -06:00
parent fd99d217c3
commit 97fc1d1dfd
3 changed files with 105 additions and 96 deletions

View File

@ -160,7 +160,6 @@ def get_huggingface_interface(model_name, api_key, alias):
'inputs': pipeline['inputs'],
'outputs': pipeline['outputs'],
'title': model_name,
'api_mode': True,
}
return interface_info
@ -212,7 +211,6 @@ def get_spaces_interface(model_name, api_key, alias):
fn.__name__ = alias if (alias is not None) else model_name
interface_info["fn"] = fn
interface_info["api_mode"] = True
return interface_info

View File

@ -6,12 +6,9 @@ interface using the input and output types.
import gradio
from gradio.inputs import get_input_instance
from gradio.outputs import get_output_instance
from gradio import networking, strings, utils
from gradio import networking, strings, utils, encryptor, queue
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
from gradio.external import load_interface
from gradio import encryptor
from gradio import queue
import analytics
import copy
import csv
import getpass
@ -29,10 +26,8 @@ import warnings
import webbrowser
import weakref
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
ip_address = networking.get_local_ip_address()
ip_address = networking.get_local_ip_address()
JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json")
@ -65,29 +60,30 @@ class Interface:
# create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable
# and that can cause some issues since the internal prediction function may rely on the original interface_info dict
kwargs = dict(interface_info, **kwargs)
return cls(**kwargs)
interface = cls(**kwargs)
interface.api_mode = True # set api mode to true so that the interface will not preprocess/postprocess
return interface
def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None,
examples_per_page=10, live=False,
layout="unaligned", show_input=True, show_output=True,
def __init__(self, fn, inputs=None, outputs=None, verbose=None, examples=None,
examples_per_page=10, live=False, layout="unaligned", show_input=True, show_output=True,
capture_session=None, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
title=None, description=None, article=None, thumbnail=None,
css=None, server_port=None, server_name=None, height=500, width=900,
allow_screenshot=True, allow_flagging=None, flagging_options=None, encrypt=False,
show_tips=False, flagging_dir="flagged", analytics_enabled=None, enable_queue=False, api_mode=False):
show_tips=None, flagging_dir="flagged", analytics_enabled=None, enable_queue=None, api_mode=None):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
inputs (Union[str, List[Union[str, InputComponent]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn.
outputs (Union[str, List[Union[str, OutputComponent]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn.
verbose (bool): whether to print detailed information during launch.
verbose (bool): DEPRECATED. Whether to print detailed information during launch.
examples (Union[List[List[Any]], str]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
examples_per_page (int): If examples are provided, how many to display per page.
live (bool): whether the interface should automatically reload on change.
layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically.
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results.
capture_session (bool): DEPRECATED. If True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function.
num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap".
title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components.
article (str): an expanded article explaining the interface; if provided, appears below the input and output components. Accepts Markdown and HTML content.
@ -99,21 +95,19 @@ class Interface:
flagging_options (List[str]): if not None, provides options a user must select when flagging.
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
flagging_dir (str): what to name the dir where flagged data is stored.
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
show_tips (bool): DEPRECATED. if True, will occasionally show tips about new Gradio features
enable_queue (bool): DEPRECATED. if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
api_mode (bool): DEPRECATED. If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
"""
if not isinstance(fn, list):
fn = [fn]
if isinstance(inputs, list):
self.input_components = [get_input_instance(i) for i in inputs]
else:
self.input_components = [get_input_instance(inputs)]
if isinstance(outputs, list):
self.output_components = [get_output_instance(i) for i in outputs]
else:
self.output_components = [get_output_instance(outputs)]
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(outputs, list):
outputs = [outputs]
self.input_components = [get_input_instance(i) for i in inputs]
self.output_components = [get_output_instance(o) for o in outputs]
if repeat_outputs_per_model:
self.output_components *= len(fn)
@ -128,7 +122,10 @@ class Interface:
self.predict_durations = [[0, 0]] * len(fn)
self.function_names = [func.__name__ for func in fn]
self.__name__ = ", ".join(self.function_names)
self.verbose = verbose
if verbose is not None:
warnings.warn("The `verbose` parameter in the `Interface` is deprecated and has no effect.")
self.status = "OFF"
self.live = live
self.layout = layout
@ -174,7 +171,7 @@ class Interface:
self.server_port = server_port
if server_name is not None or server_port is not None:
warnings.warn("The server_name and server_port parameters in the `Interface` class will be deprecated. Please provide them in the `launch()` method instead.")
self.simple_server = None
self.allow_screenshot = allow_screenshot
# For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True
@ -188,24 +185,20 @@ class Interface:
self.share = None
self.share_url = None
self.local_url = None
self.show_tips = show_tips
if show_tips is not None:
warnings.warn("The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead")
self.requires_permissions = any(
[component.requires_permissions for component in self.input_components])
self.enable_queue = enable_queue
self.api_mode = api_mode
if self.enable_queue is not None:
warnings.warn("The `enable_queue` parameter in the `Interface` will be deprecated. Please use the `enable_queue` parameter in `launch()` instead")
data = {'fn': fn,
'inputs': inputs,
'outputs': outputs,
'live': live,
'capture_session': capture_session,
'ip_address': ip_address,
'interpretation': interpretation,
'allow_flagging': allow_flagging,
'allow_screenshot': allow_screenshot,
'custom_css': self.css is not None,
'theme': self.theme
}
if api_mode is not None:
warnings.warn("The `api_mode` parameter in the `Interface` is deprecated.")
self.api_mode = False
if self.capture_session:
try:
@ -220,12 +213,21 @@ class Interface:
if self.allow_flagging:
os.makedirs(self.flagging_dir, exist_ok=True)
data = {'fn': fn,
'inputs': inputs,
'outputs': outputs,
'live': live,
'capture_session': capture_session,
'ip_address': ip_address,
'interpretation': interpretation,
'allow_flagging': allow_flagging,
'allow_screenshot': allow_screenshot,
'custom_css': self.css is not None,
'theme': self.theme
}
if self.analytics_enabled:
try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
data=data, timeout=3)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
utils.initiated_analytics(data)
# Alert user if a more recent version of the library exists
utils.version_check()
@ -551,7 +553,7 @@ class Interface:
def launch(self, inline=None, inbrowser=None, share=False, debug=False,
auth=None, auth_message=None, private_endpoint=None,
prevent_thread_lock=False, show_error=True, server_name=None,
server_port=None):
server_port=None, show_tips=False, enable_queue=False):
"""
Launches the webserver that serves the UI for the interface.
Parameters:
@ -566,11 +568,13 @@ class Interface:
show_error (bool): If True, any errors in the interface will be printed in the browser console log
server_port (int): will start gradio app on this port (if available)
server_name (str): to make app accessible on local network, set this to "0.0.0.0".
show_error (bool): show prediction errors in console
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
Returns:
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
show_error (bool): show prediction errors in console
"""
# Set up local flask server
@ -580,6 +584,7 @@ class Interface:
auth = [auth]
self.auth = auth
self.auth_message = auth_message
self.show_tips = show_tips
# Request key for encryption
if self.encrypt:
@ -588,6 +593,8 @@ class Interface:
server_name = server_name or self.server_name or networking.LOCALHOST_NAME
server_port = server_port or self.server_port or networking.INITIAL_PORT_VALUE
if self.enable_queue is None:
self.enable_queue = enable_queue
# Launch local flask server
server_port, path_to_local_server, app, thread = networking.start_server(
@ -631,7 +638,8 @@ class Interface:
else:
print(strings.en["SHARE_LINK_MESSAGE"])
except RuntimeError:
send_error_analytics(self.analytics_enabled)
if self.analytics_enabled:
utils.error_analytics("RuntimeError")
share_url = None
else:
print(strings.en["PUBLIC_SHARE_TRUE"])
@ -661,8 +669,20 @@ class Interface:
except ImportError:
pass # IPython is not available so does not print inline.
send_launch_analytics(analytics_enabled=self.analytics_enabled, inbrowser=inbrowser, is_colab=is_colab,
share=share, share_url=share_url)
data = {
'launch_method': 'browser' if inbrowser else 'inline',
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': ip_address,
'enable_queue': self.enable_queue,
'show_tips': self.show_tips,
'api_mode': self.api_mode,
'server_name': server_name,
'server_port': server_port,
}
if self.analytics_enabled:
utils.launch_analytics(data)
show_tip(self)
@ -728,23 +748,15 @@ class Interface:
else:
mlflow.log_param("Gradio Interface Local Link",
self.local_url)
if self.analytics_enabled:
if analytics_integration:
if self.analytics_enabled and analytics_integration:
data = {'integration': analytics_integration}
try:
requests.post(analytics_url +
'gradio-integration-analytics/',
data=data, timeout=3)
except (
requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
utils.integration_analytics(data)
def show_tip(io):
# Only show tip every other use.
if not(io.show_tips) or random.random() < 0.5:
return
print(random.choice(strings.en.TIPS))
if io.show_tips and random.random() < 0.5:
print(random.choice(strings.en.TIPS))
def launch_counter():
@ -765,35 +777,7 @@ def launch_counter():
pass
def send_error_analytics(analytics_enabled):
data = {'error': 'RuntimeError in launch method'}
if analytics_enabled:
try:
requests.post(analytics_url + 'gradio-error-analytics/',
data=data, timeout=3)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
def send_launch_analytics(analytics_enabled, inbrowser, is_colab, share, share_url):
launch_method = 'browser' if inbrowser else 'inline'
if analytics_enabled:
data = {
'launch_method': launch_method,
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': ip_address
}
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data, timeout=3)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
def close_all():
for io in Interface.get_instances():
io.close()
reset_all = close_all # for backwards compatibility

View File

@ -1,3 +1,4 @@
import analytics
import json.decoder
import warnings
import requests
@ -6,8 +7,10 @@ from distutils.version import StrictVersion
from socket import gaierror
from urllib3.exceptions import MaxRetryError
analytics_url = 'https://api.gradio.app/'
PKG_VERSION_URL = "https://api.gradio.app/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
def version_check():
@ -28,7 +31,31 @@ def version_check():
warnings.warn("package URL does not contain version info.")
except:
warnings.warn("unable to connect with package URL to collect version info.")
def initiated_analytics(data):
try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
data=data, timeout=3)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
def launch_analytics(data):
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data, timeout=3)
except (requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
def integration_analytics(data):
try:
requests.post(analytics_url + 'gradio-integration-analytics/',
data=data, timeout=3)
except (
requests.ConnectionError, requests.exceptions.ReadTimeout):
pass # do not push analytics if no network
def error_analytics(type):