From 3ad5959c740c141edaa104e65cb48e484e40bc3d Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Tue, 19 Jan 2021 09:15:09 -0800 Subject: [PATCH] add auth --- gradio.egg-info/requires.txt | 1 + gradio/interface.py | 6 ++++-- gradio/networking.py | 10 ++++++++-- setup.py | 1 + 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/gradio.egg-info/requires.txt b/gradio.egg-info/requires.txt index 91b95b5c5c..0afa8bed8d 100644 --- a/gradio.egg-info/requires.txt +++ b/gradio.egg-info/requires.txt @@ -3,6 +3,7 @@ requests Flask>=1.1.1 Flask-Cors>=3.0.8 flask-cachebuster +Flask-BasicAuth paramiko scipy IPython diff --git a/gradio/interface.py b/gradio/interface.py index b88783865a..e2023d6a3b 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -44,7 +44,7 @@ class Interface: Interface.instances) def __init__(self, fn, inputs, outputs, verbose=False, examples=None, - examples_per_page=10, live=False, + examples_per_page=10, live=False, auth=None, layout="horizontal", show_input=True, show_output=True, capture_session=False, interpretation=None, title=None, description=None, article=None, thumbnail=None, @@ -61,6 +61,7 @@ class Interface: examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. examples_per_page (int): If examples are provided, how many to display per page. live (bool): whether the interface should automatically reload on change. + auth (Tuple[str, str]): If provided, username and password required to access interface. layout (str): Layout of input and output panels. "horizontal" arranges them as two columns of equal height, "unaligned" arranges them as two columns of unequal height, and "vertical" arranges them vertically. capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x) interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter. @@ -113,6 +114,7 @@ class Interface: self.verbose = verbose self.status = "OFF" self.live = live + self.auth = auth self.layout = layout self.show_input = show_input self.show_output = show_output @@ -384,7 +386,7 @@ class Interface: networking.set_meta_tags(self.title, self.description, self.thumbnail) server_port, app, thread = networking.start_server( - self, self.server_name, self.server_port) + self, self.server_name, self.server_port, self.auth) path_to_local_server = "http://{}:{}/".format(self.server_name, server_port) self.server_port = server_port self.status = "RUNNING" diff --git a/gradio/networking.py b/gradio/networking.py index 88c9d9e001..abd9ac928f 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -7,6 +7,7 @@ import socket import threading from flask import Flask, request, jsonify, abort, send_file, render_template from flask_cachebuster import CacheBuster +from flask_basicauth import BasicAuth from flask_cors import CORS import threading import pkg_resources @@ -256,12 +257,17 @@ def interpret(): def file(path): return send_file(os.path.join(app.cwd, path)) -def start_server(interface, server_name, server_port=None): +def start_server(interface, server_name, server_port=None, auth=None): if server_port is None: server_port = INITIAL_PORT_VALUE port = get_first_available_port( server_port, server_port + TRY_NUM_PORTS ) + if auth is not None: + app.config['BASIC_AUTH_USERNAME'] = auth[0] + app.config['BASIC_AUTH_PASSWORD'] = auth[1] + app.config['BASIC_AUTH_FORCE'] = True + basic_auth = BasicAuth(app) app.interface = interface app.cwd = os.getcwd() log = logging.getLogger('werkzeug') @@ -304,6 +310,6 @@ def setup_tunnel(local_server_port): def url_ok(url): try: r = requests.head(url) - return r.status_code == 200 + return r.status_code == 200 or r.status_code == 401 except ConnectionError: return False diff --git a/setup.py b/setup.py index 49e8d0652c..e557a98a71 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ setup( 'Flask>=1.1.1', 'Flask-Cors>=3.0.8', 'flask-cachebuster', + 'Flask-BasicAuth', 'paramiko', 'scipy', 'IPython',