From 97fc1d1dfd6b0aa0fa4f02a42c04db1968158179 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 16 Dec 2021 08:24:24 -0600 Subject: [PATCH] cleaned up interface.py file --- gradio/external.py | 2 - gradio/interface.py | 170 ++++++++++++++++++++------------------------ gradio/utils.py | 29 +++++++- 3 files changed, 105 insertions(+), 96 deletions(-) diff --git a/gradio/external.py b/gradio/external.py index 2a5f5dfd50..71f6d988b1 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -160,7 +160,6 @@ def get_huggingface_interface(model_name, api_key, alias): 'inputs': pipeline['inputs'], 'outputs': pipeline['outputs'], 'title': model_name, - 'api_mode': True, } return interface_info @@ -212,7 +211,6 @@ def get_spaces_interface(model_name, api_key, alias): fn.__name__ = alias if (alias is not None) else model_name interface_info["fn"] = fn - interface_info["api_mode"] = True return interface_info diff --git a/gradio/interface.py b/gradio/interface.py index 831f6daef5..ba5f3ac6d9 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -6,12 +6,9 @@ interface using the input and output types. import gradio from gradio.inputs import get_input_instance from gradio.outputs import get_output_instance -from gradio import networking, strings, utils +from gradio import networking, strings, utils, encryptor, queue from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value from gradio.external import load_interface -from gradio import encryptor -from gradio import queue -import analytics import copy import csv import getpass @@ -29,10 +26,8 @@ import warnings import webbrowser import weakref -analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" -analytics_url = 'https://api.gradio.app/' -ip_address = networking.get_local_ip_address() +ip_address = networking.get_local_ip_address() JSON_PATH = os.path.join(os.path.dirname(gradio.__file__), "launches.json") @@ -65,29 +60,30 @@ class Interface: # create a dictionary of kwargs without overwriting the original interface_info dict because it is mutable # and that can cause some issues since the internal prediction function may rely on the original interface_info dict kwargs = dict(interface_info, **kwargs) - return cls(**kwargs) + interface = cls(**kwargs) + interface.api_mode = True # set api mode to true so that the interface will not preprocess/postprocess + return interface - def __init__(self, fn, inputs=None, outputs=None, verbose=False, examples=None, - examples_per_page=10, live=False, - layout="unaligned", show_input=True, show_output=True, + def __init__(self, fn, inputs=None, outputs=None, verbose=None, examples=None, + examples_per_page=10, live=False, layout="unaligned", show_input=True, show_output=True, capture_session=None, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True, title=None, description=None, article=None, thumbnail=None, css=None, server_port=None, server_name=None, height=500, width=900, allow_screenshot=True, allow_flagging=None, flagging_options=None, encrypt=False, - show_tips=False, flagging_dir="flagged", analytics_enabled=None, enable_queue=False, api_mode=False): + show_tips=None, flagging_dir="flagged", analytics_enabled=None, enable_queue=None, api_mode=None): """ Parameters: fn (Callable): the function to wrap an interface around. inputs (Union[str, List[Union[str, InputComponent]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. outputs (Union[str, List[Union[str, OutputComponent]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. - verbose (bool): whether to print detailed information during launch. + verbose (bool): DEPRECATED. Whether to print detailed information during launch. examples (Union[List[List[Any]], str]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs. examples_per_page (int): If examples are provided, how many to display per page. live (bool): whether the interface should automatically reload on change. layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically. - capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x) - interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter. - num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. + capture_session (bool): DEPRECATED. If True, captures the default graph and session (needed for Tensorflow 1.x) + interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. + num_shap (float): a multiplier that determines how many examples are computed for shap-based interpretation. Increasing this value will increase shap runtime, but improve results. Only applies if interpretation is "shap". title (str): a title for the interface; if provided, appears above the input and output components. description (str): a description for the interface; if provided, appears above the input and output components. article (str): an expanded article explaining the interface; if provided, appears below the input and output components. Accepts Markdown and HTML content. @@ -99,21 +95,19 @@ class Interface: flagging_options (List[str]): if not None, provides options a user must select when flagging. encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch flagging_dir (str): what to name the dir where flagged data is stored. - show_tips (bool): if True, will occasionally show tips about new Gradio features - enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. - api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo) + show_tips (bool): DEPRECATED. if True, will occasionally show tips about new Gradio features + enable_queue (bool): DEPRECATED. if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. + api_mode (bool): DEPRECATED. If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo) """ if not isinstance(fn, list): fn = [fn] - if isinstance(inputs, list): - self.input_components = [get_input_instance(i) for i in inputs] - else: - self.input_components = [get_input_instance(inputs)] - if isinstance(outputs, list): - self.output_components = [get_output_instance(i) for i in outputs] - else: - self.output_components = [get_output_instance(outputs)] + if not isinstance(inputs, list): + inputs = [inputs] + if not isinstance(outputs, list): + outputs = [outputs] + self.input_components = [get_input_instance(i) for i in inputs] + self.output_components = [get_output_instance(o) for o in outputs] if repeat_outputs_per_model: self.output_components *= len(fn) @@ -128,7 +122,10 @@ class Interface: self.predict_durations = [[0, 0]] * len(fn) self.function_names = [func.__name__ for func in fn] self.__name__ = ", ".join(self.function_names) - self.verbose = verbose + + if verbose is not None: + warnings.warn("The `verbose` parameter in the `Interface` is deprecated and has no effect.") + self.status = "OFF" self.live = live self.layout = layout @@ -174,7 +171,7 @@ class Interface: self.server_port = server_port if server_name is not None or server_port is not None: warnings.warn("The server_name and server_port parameters in the `Interface` class will be deprecated. Please provide them in the `launch()` method instead.") - + self.simple_server = None self.allow_screenshot = allow_screenshot # For allow_flagging and analytics_enabled: (1) first check for parameter, (2) check for environment variable, (3) default to True @@ -188,24 +185,20 @@ class Interface: self.share = None self.share_url = None self.local_url = None - self.show_tips = show_tips + + if show_tips is not None: + warnings.warn("The `show_tips` parameter in the `Interface` is deprecated. Please use the `show_tips` parameter in `launch()` instead") + self.requires_permissions = any( [component.requires_permissions for component in self.input_components]) + self.enable_queue = enable_queue - self.api_mode = api_mode + if self.enable_queue is not None: + warnings.warn("The `enable_queue` parameter in the `Interface` will be deprecated. Please use the `enable_queue` parameter in `launch()` instead") - data = {'fn': fn, - 'inputs': inputs, - 'outputs': outputs, - 'live': live, - 'capture_session': capture_session, - 'ip_address': ip_address, - 'interpretation': interpretation, - 'allow_flagging': allow_flagging, - 'allow_screenshot': allow_screenshot, - 'custom_css': self.css is not None, - 'theme': self.theme - } + if api_mode is not None: + warnings.warn("The `api_mode` parameter in the `Interface` is deprecated.") + self.api_mode = False if self.capture_session: try: @@ -220,12 +213,21 @@ class Interface: if self.allow_flagging: os.makedirs(self.flagging_dir, exist_ok=True) + data = {'fn': fn, + 'inputs': inputs, + 'outputs': outputs, + 'live': live, + 'capture_session': capture_session, + 'ip_address': ip_address, + 'interpretation': interpretation, + 'allow_flagging': allow_flagging, + 'allow_screenshot': allow_screenshot, + 'custom_css': self.css is not None, + 'theme': self.theme + } + if self.analytics_enabled: - try: - requests.post(analytics_url + 'gradio-initiated-analytics/', - data=data, timeout=3) - except (requests.ConnectionError, requests.exceptions.ReadTimeout): - pass # do not push analytics if no network + utils.initiated_analytics(data) # Alert user if a more recent version of the library exists utils.version_check() @@ -551,7 +553,7 @@ class Interface: def launch(self, inline=None, inbrowser=None, share=False, debug=False, auth=None, auth_message=None, private_endpoint=None, prevent_thread_lock=False, show_error=True, server_name=None, - server_port=None): + server_port=None, show_tips=False, enable_queue=False): """ Launches the webserver that serves the UI for the interface. Parameters: @@ -566,11 +568,13 @@ class Interface: show_error (bool): If True, any errors in the interface will be printed in the browser console log server_port (int): will start gradio app on this port (if available) server_name (str): to make app accessible on local network, set this to "0.0.0.0". + show_error (bool): show prediction errors in console + show_tips (bool): if True, will occasionally show tips about new Gradio features + enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. Returns: app (flask.Flask): Flask app object path_to_local_server (str): Locally accessible link share_url (str): Publicly accessible link (if share=True) - show_error (bool): show prediction errors in console """ # Set up local flask server @@ -580,6 +584,7 @@ class Interface: auth = [auth] self.auth = auth self.auth_message = auth_message + self.show_tips = show_tips # Request key for encryption if self.encrypt: @@ -588,6 +593,8 @@ class Interface: server_name = server_name or self.server_name or networking.LOCALHOST_NAME server_port = server_port or self.server_port or networking.INITIAL_PORT_VALUE + if self.enable_queue is None: + self.enable_queue = enable_queue # Launch local flask server server_port, path_to_local_server, app, thread = networking.start_server( @@ -631,7 +638,8 @@ class Interface: else: print(strings.en["SHARE_LINK_MESSAGE"]) except RuntimeError: - send_error_analytics(self.analytics_enabled) + if self.analytics_enabled: + utils.error_analytics("RuntimeError") share_url = None else: print(strings.en["PUBLIC_SHARE_TRUE"]) @@ -661,8 +669,20 @@ class Interface: except ImportError: pass # IPython is not available so does not print inline. - send_launch_analytics(analytics_enabled=self.analytics_enabled, inbrowser=inbrowser, is_colab=is_colab, - share=share, share_url=share_url) + data = { + 'launch_method': 'browser' if inbrowser else 'inline', + 'is_google_colab': is_colab, + 'is_sharing_on': share, + 'share_url': share_url, + 'ip_address': ip_address, + 'enable_queue': self.enable_queue, + 'show_tips': self.show_tips, + 'api_mode': self.api_mode, + 'server_name': server_name, + 'server_port': server_port, + } + if self.analytics_enabled: + utils.launch_analytics(data) show_tip(self) @@ -728,23 +748,15 @@ class Interface: else: mlflow.log_param("Gradio Interface Local Link", self.local_url) - if self.analytics_enabled: - if analytics_integration: + if self.analytics_enabled and analytics_integration: data = {'integration': analytics_integration} - try: - requests.post(analytics_url + - 'gradio-integration-analytics/', - data=data, timeout=3) - except ( - requests.ConnectionError, requests.exceptions.ReadTimeout): - pass # do not push analytics if no network + utils.integration_analytics(data) def show_tip(io): # Only show tip every other use. - if not(io.show_tips) or random.random() < 0.5: - return - print(random.choice(strings.en.TIPS)) + if io.show_tips and random.random() < 0.5: + print(random.choice(strings.en.TIPS)) def launch_counter(): @@ -765,35 +777,7 @@ def launch_counter(): pass -def send_error_analytics(analytics_enabled): - data = {'error': 'RuntimeError in launch method'} - if analytics_enabled: - try: - requests.post(analytics_url + 'gradio-error-analytics/', - data=data, timeout=3) - except (requests.ConnectionError, requests.exceptions.ReadTimeout): - pass # do not push analytics if no network - - -def send_launch_analytics(analytics_enabled, inbrowser, is_colab, share, share_url): - launch_method = 'browser' if inbrowser else 'inline' - if analytics_enabled: - data = { - 'launch_method': launch_method, - 'is_google_colab': is_colab, - 'is_sharing_on': share, - 'share_url': share_url, - 'ip_address': ip_address - } - try: - requests.post(analytics_url + 'gradio-launched-analytics/', - data=data, timeout=3) - except (requests.ConnectionError, requests.exceptions.ReadTimeout): - pass # do not push analytics if no network - - def close_all(): for io in Interface.get_instances(): io.close() -reset_all = close_all # for backwards compatibility diff --git a/gradio/utils.py b/gradio/utils.py index 52fb08481d..5d2c7b7937 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -1,3 +1,4 @@ +import analytics import json.decoder import warnings import requests @@ -6,8 +7,10 @@ from distutils.version import StrictVersion from socket import gaierror from urllib3.exceptions import MaxRetryError + analytics_url = 'https://api.gradio.app/' PKG_VERSION_URL = "https://api.gradio.app/pkg-version" +analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" def version_check(): @@ -28,7 +31,31 @@ def version_check(): warnings.warn("package URL does not contain version info.") except: warnings.warn("unable to connect with package URL to collect version info.") - + + +def initiated_analytics(data): + try: + requests.post(analytics_url + 'gradio-initiated-analytics/', + data=data, timeout=3) + except (requests.ConnectionError, requests.exceptions.ReadTimeout): + pass # do not push analytics if no network + + +def launch_analytics(data): + try: + requests.post(analytics_url + 'gradio-launched-analytics/', + data=data, timeout=3) + except (requests.ConnectionError, requests.exceptions.ReadTimeout): + pass # do not push analytics if no network + + +def integration_analytics(data): + try: + requests.post(analytics_url + 'gradio-integration-analytics/', + data=data, timeout=3) + except ( + requests.ConnectionError, requests.exceptions.ReadTimeout): + pass # do not push analytics if no network def error_analytics(type):