Mirror of https://github.com/gradio-app/gradio.git (synced 2025-01-24 10:54:04 +08:00)
Commit 44175958cd (parent e2e1c90229): 0.7.0
@ -8,6 +8,9 @@ from abc import ABC, abstractmethod
|
||||
from gradio import preprocessing_utils, validation_data
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
import datetime
|
||||
import csv
|
||||
import pandas as pd
|
||||
|
||||
# Where to find the static resources associated with each template.
|
||||
BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'templates/input/{}.html'
|
||||
@ -63,6 +66,12 @@ class AbstractInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
|
||||
"""
|
||||
pass
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True, flatten=False, scale=1, shift=0,
|
||||
@ -95,6 +104,16 @@ class Sketchpad(AbstractInput):
|
||||
array = array * self.scale + self.shift
|
||||
array = array.astype(self.dtype)
|
||||
return array
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
inp = msg['data']['input']
|
||||
im = preprocessing_utils.encoding_to_image(inp)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
@ -119,6 +138,16 @@ class Webcam(AbstractInput):
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
inp = msg['data']['input']
|
||||
im = preprocessing_utils.encoding_to_image(inp)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
@ -133,7 +162,15 @@ class Textbox(AbstractInput):
|
||||
By default, no pre-processing is applied to text.
|
||||
"""
|
||||
return inp
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method for text saves it to a .txt file
|
||||
"""
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.txt'
|
||||
inp = msg['data']['input']
|
||||
with open(f'{dir}/{filename}','w') as f:
|
||||
f.write(inp)
|
||||
return filename
|
||||
|
||||
class ImageUpload(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, shape=(224, 224, 3), image_mode='RGB',
|
||||
@ -171,6 +208,16 @@ class ImageUpload(AbstractInput):
|
||||
array = im.reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
inp = msg['data']['input']
|
||||
im = preprocessing_utils.encoding_to_image(inp)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
class CSV(AbstractInput):
|
||||
|
||||
@ -184,6 +231,12 @@ class CSV(AbstractInput):
|
||||
"""
|
||||
return inp
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method for csv
|
||||
"""
|
||||
inp = msg['data']['inp']
|
||||
return json.loads(inp)
|
||||
|
||||
# Automatically adds all subclasses of AbstractInput into a dictionary (keyed by class name) for easy referencing.
|
||||
registry = {cls.__name__.lower(): cls for cls in AbstractInput.__subclasses__()}
|
||||
|
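The registry above is what lets the Interface class (later in this commit) accept plain strings such as "sketchpad" and resolve them to input classes. A minimal lookup sketch, assuming the gradio.inputs module from this diff is importable; the key names follow from lowercasing the class names shown above:

# Sketch only: mirrors how Interface.__init__ does
# gradio.inputs.registry[inputs.lower()](preprocessing_fns).
import gradio.inputs

print(sorted(gradio.inputs.registry.keys()))
# expected keys: 'csv', 'imageupload', 'sketchpad', 'textbox', 'webcam'

input_interface = gradio.inputs.registry["sketchpad"](None)  # preprocessing_fn=None
assert isinstance(input_interface, gradio.inputs.AbstractInput)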
@ -1,27 +1,25 @@
|
||||
'''
|
||||
"""
|
||||
This is the core file in the `gradio` package, and defines the Interface class, including methods for constructing the
|
||||
interface using the input and output types.
|
||||
'''
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import websockets
|
||||
import nest_asyncio
|
||||
import tempfile
|
||||
import traceback
|
||||
import webbrowser
|
||||
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
from gradio import networking, strings
|
||||
import tempfile
|
||||
import threading
|
||||
import traceback
|
||||
import urllib
|
||||
import json
|
||||
from distutils.version import StrictVersion
|
||||
import pkg_resources
|
||||
import requests
|
||||
import termcolor
|
||||
import random
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
LOCALHOST_IP = '127.0.0.1'
|
||||
SHARE_LINK_FORMAT = 'https://{}.gradio.app/'
|
||||
LOCALHOST_IP = "127.0.0.1"
|
||||
INITIAL_WEBSOCKET_PORT = 9200
|
||||
TRY_NUM_PORTS = 100
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
|
||||
|
||||
class Interface:
|
||||
@ -31,12 +29,24 @@ class Interface:
|
||||
"""
|
||||
|
||||
# Dictionary in which each key is a valid `model_type` argument to the constructor, and each value is a description of that model type.
|
||||
VALID_MODEL_TYPES = {'sklearn': 'sklearn model', 'keras': 'Keras model', 'pyfunc': 'python function',
|
||||
'pytorch': 'PyTorch model'}
|
||||
STATUS_TYPES = {'OFF': 'off', 'RUNNING': 'running'}
|
||||
VALID_MODEL_TYPES = {
|
||||
"sklearn": "sklearn model",
|
||||
"keras": "Keras model",
|
||||
"pyfunc": "python function",
|
||||
"pytorch": "PyTorch model",
|
||||
}
|
||||
STATUS_TYPES = {"OFF": "off", "RUNNING": "running"}
|
||||
|
||||
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
|
||||
verbose=True):
|
||||
def __init__(
|
||||
self,
|
||||
inputs,
|
||||
outputs,
|
||||
model,
|
||||
model_type=None,
|
||||
preprocessing_fns=None,
|
||||
postprocessing_fns=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
:param outputs: a string or `AbstractOutput` representing the output interface.
|
||||
@ -47,53 +57,68 @@ class Interface:
|
||||
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
|
||||
"""
|
||||
if isinstance(inputs, str):
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](
|
||||
preprocessing_fns
|
||||
)
|
||||
elif isinstance(inputs, gradio.inputs.AbstractInput):
|
||||
self.input_interface = inputs
|
||||
else:
|
||||
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
|
||||
raise ValueError("Input interface must be of type `str` or `AbstractInput`")
|
||||
if isinstance(outputs, str):
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](
|
||||
postprocessing_fns
|
||||
)
|
||||
elif isinstance(outputs, gradio.outputs.AbstractOutput):
|
||||
self.output_interface = outputs
|
||||
else:
|
||||
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
|
||||
raise ValueError(
|
||||
"Output interface must be of type `str` or `AbstractOutput`"
|
||||
)
|
||||
self.model_obj = model
|
||||
if model_type is None:
|
||||
model_type = self._infer_model_type(model)
|
||||
if verbose:
|
||||
print("Model type not explicitly identified, inferred to be: {}".format(
|
||||
self.VALID_MODEL_TYPES[model_type]))
|
||||
elif not(model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
print(
|
||||
"Model type not explicitly identified, inferred to be: {}".format(
|
||||
self.VALID_MODEL_TYPES[model_type]
|
||||
)
|
||||
)
|
||||
elif not (model_type.lower() in self.VALID_MODEL_TYPES):
|
||||
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
|
||||
self.model_type = model_type
|
||||
if self.model_type == "keras":
|
||||
import tensorflow as tf
|
||||
self.graph = tf.get_default_graph()
|
||||
self.verbose = verbose
|
||||
self.status = self.STATUS_TYPES['OFF']
|
||||
self.status = self.STATUS_TYPES["OFF"]
|
||||
self.validate_flag = False
|
||||
self.simple_server = None
|
||||
self.ngrok_api_ports = None
|
||||
self.hash = random.getrandbits(32)
|
||||
|
||||
@staticmethod
|
||||
def _infer_model_type(model):
|
||||
""" Helper method that attempts to identify the type of trained ML model."""
|
||||
try:
|
||||
import sklearn
|
||||
|
||||
if isinstance(model, sklearn.base.BaseEstimator):
|
||||
return 'sklearn'
|
||||
return "sklearn"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
|
||||
if isinstance(model, tf.keras.Model):
|
||||
return 'keras'
|
||||
return "keras"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import keras
|
||||
|
||||
if isinstance(model, keras.Model):
|
||||
return 'keras'
|
||||
return "keras"
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -102,57 +127,26 @@ class Interface:
|
||||
|
||||
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
|
||||
|
||||
async def communicate(self, websocket, path):
|
||||
"""
|
||||
Method that defines how this interface should communicate with the websocket. (1) When an input is received by
|
||||
the websocket, it is passed into the input interface and preprocessed. (2) Then the model is called to make a
|
||||
prediction. (3) Finally, the prediction is postprocessed to get something to be displayed by the output.
|
||||
:param websocket: a Websocket server used to communicate with the interface frontend
|
||||
:param path: not used, but required for compliance with websocket library
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
msg = json.loads(await websocket.recv())
|
||||
if msg['action'] == 'input':
|
||||
processed_input = self.input_interface.preprocess(msg['data'])
|
||||
prediction = self.predict(processed_input)
|
||||
processed_output = self.output_interface.postprocess(prediction)
|
||||
output = {
|
||||
'action': 'output',
|
||||
'data': processed_output,
|
||||
}
|
||||
await websocket.send(json.dumps(output))
|
||||
if msg['action'] == 'flag':
|
||||
f = open('gradio-flagged.txt','a+')
|
||||
f.write(str(msg['data']))
|
||||
f.close()
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
|
||||
def predict(self, preprocessed_input):
|
||||
"""
|
||||
Method that calls the relevant method of the model object to make a prediction.
|
||||
:param preprocessed_input: the preprocessed input returned by the input interface
|
||||
"""
|
||||
if self.model_type=='sklearn':
|
||||
if self.model_type == "sklearn":
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type=='keras':
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type=='pyfunc':
|
||||
elif self.model_type == "keras":
|
||||
with self.graph.as_default():
|
||||
return self.model_obj.predict(preprocessed_input)
|
||||
elif self.model_type == "pyfunc":
|
||||
return self.model_obj(preprocessed_input)
|
||||
elif self.model_type=='pytorch':
|
||||
elif self.model_type == "pytorch":
|
||||
import torch
|
||||
print(preprocessed_input.dtype)
|
||||
value = torch.from_numpy(preprocessed_input)
|
||||
print(value.dtype)
|
||||
value = torch.autograd.Variable(value)
|
||||
prediction = self.model_obj(value)
|
||||
return prediction.data.numpy()
|
||||
else:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
ValueError("model_type must be one of: {}".format(self.VALID_MODEL_TYPES))
|
||||
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
@ -164,18 +158,28 @@ class Interface:
|
||||
if n == 0:
|
||||
self.validate_flag = True
|
||||
if self.verbose:
|
||||
print("No validation samples for this interface... skipping validation.")
|
||||
print(
|
||||
"No validation samples for this interface... skipping validation."
|
||||
)
|
||||
return
|
||||
for m, msg in enumerate(validation_inputs):
|
||||
if self.verbose:
|
||||
print(f"Validating samples: {m+1}/{n} [" + "="*(m+1) + "."*(n-m-1) + "]", end='\r')
|
||||
print(
|
||||
f"Validating samples: {m+1}/{n} ["
|
||||
+ "=" * (m + 1)
|
||||
+ "." * (n - m - 1)
|
||||
+ "]",
|
||||
end="\r",
|
||||
)
|
||||
try:
|
||||
processed_input = self.input_interface.preprocess(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print("\n----------")
|
||||
print("Validation failed, likely due to incompatible pre-processing and model input. See below:\n")
|
||||
print(
|
||||
"Validation failed, likely due to incompatible pre-processing and model input. See below:\n"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
break
|
||||
try:
|
||||
@ -183,8 +187,10 @@ class Interface:
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print("\n----------")
|
||||
print("Validation failed, likely due to incompatible model output and post-processing."
|
||||
"See below:\n")
|
||||
print(
|
||||
"Validation failed, likely due to incompatible model output and post-processing."
|
||||
"See below:\n"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
break
|
||||
else: # This means if a break was not explicitly called
|
||||
@ -206,7 +212,7 @@ class Interface:
|
||||
self.validate()
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == self.STATUS_TYPES['RUNNING']:
|
||||
if self.status == self.STATUS_TYPES["RUNNING"]:
|
||||
if self.verbose:
|
||||
print("Closing existing server...")
|
||||
if self.simple_server is not None:
|
||||
@ -217,68 +223,61 @@ class Interface:
|
||||
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd = networking.start_simple_server(output_directory)
|
||||
path_to_local_server = 'http://localhost:{}/'.format(server_port)
|
||||
networking.build_template(output_directory, self.input_interface, self.output_interface)
|
||||
server_port, httpd = networking.start_simple_server(self, output_directory)
|
||||
path_to_local_server = "http://localhost:{}/".format(server_port)
|
||||
networking.build_template(
|
||||
output_directory, self.input_interface, self.output_interface
|
||||
)
|
||||
|
||||
# Set up a port to serve a websocket that sets up the communication between the front-end and model.
|
||||
websocket_port = networking.get_first_available_port(
|
||||
INITIAL_WEBSOCKET_PORT, INITIAL_WEBSOCKET_PORT + TRY_NUM_PORTS)
|
||||
start_server = websockets.serve(self.communicate, LOCALHOST_IP, websocket_port)
|
||||
networking.set_socket_port_in_js(output_directory, websocket_port) # sets the websocket port in the JS file.
|
||||
networking.set_interface_types_in_config_file(output_directory,
|
||||
self.input_interface.__class__.__name__.lower(),
|
||||
self.output_interface.__class__.__name__.lower())
|
||||
self.status = self.STATUS_TYPES['RUNNING']
|
||||
networking.set_interface_types_in_config_file(
|
||||
output_directory,
|
||||
self.input_interface.__class__.__name__.lower(),
|
||||
self.output_interface.__class__.__name__.lower(),
|
||||
)
|
||||
self.status = self.STATUS_TYPES["RUNNING"]
|
||||
self.simple_server = httpd
|
||||
|
||||
is_colab = False
|
||||
try: # Check if running interactively using ipython.
|
||||
from_ipynb = get_ipython()
|
||||
if 'google.colab' in str(from_ipynb):
|
||||
if "google.colab" in str(from_ipynb):
|
||||
is_colab = True
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print(f"IMPORTANT: You are using gradio version {current_pkg_version}, however version {latest_pkg_version} "
|
||||
f"is available, please upgrade.")
|
||||
print('--------')
|
||||
if self.verbose:
|
||||
print(strings.en["BETA_MESSAGE"])
|
||||
if not is_colab:
|
||||
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
|
||||
if share:
|
||||
try:
|
||||
path_to_ngrok_server, ngrok_api_ports = networking.setup_ngrok(
|
||||
server_port, websocket_port, output_directory, self.ngrok_api_ports)
|
||||
self.ngrok_api_ports = ngrok_api_ports
|
||||
share_url = networking.setup_tunnel(server_port)
|
||||
except RuntimeError:
|
||||
path_to_ngrok_server = None
|
||||
share_url = None
|
||||
if self.verbose:
|
||||
print(strings.en["NGROK_NO_INTERNET"])
|
||||
else:
|
||||
if is_colab: # For a colab notebook, create a public link even if share is False.
|
||||
path_to_ngrok_server, ngrok_api_ports = networking.setup_ngrok(
|
||||
server_port, websocket_port, output_directory, self.ngrok_api_ports)
|
||||
self.ngrok_api_ports = ngrok_api_ports
|
||||
if (
|
||||
is_colab
|
||||
): # For a colab notebook, create a public link even if share is False.
|
||||
share_url = networking.setup_tunnel(server_port)
|
||||
if self.verbose:
|
||||
print(strings.en["COLAB_NO_LOCAL"])
|
||||
else: # If it's not a colab notebook and share=False, print a message telling them about the share option.
|
||||
if self.verbose:
|
||||
print(strings.en["PUBLIC_SHARE_TRUE"])
|
||||
path_to_ngrok_server = None
|
||||
share_url = None
|
||||
|
||||
if path_to_ngrok_server is not None:
|
||||
url = urllib.parse.urlparse(path_to_ngrok_server)
|
||||
subdomain = url.hostname.split('.')[0]
|
||||
path_to_ngrok_interface_page = SHARE_LINK_FORMAT.format(subdomain)
|
||||
if share_url is not None:
|
||||
if self.verbose:
|
||||
print(strings.en["MODEL_PUBLICLY_AVAILABLE_URL"].format(path_to_ngrok_interface_page))
|
||||
|
||||
# Keep the server running in the background.
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
try:
|
||||
_ = get_ipython()
|
||||
except NameError: # Runtime errors are thrown in jupyter notebooks because of async.
|
||||
t = threading.Thread(target=asyncio.get_event_loop().run_forever, daemon=True)
|
||||
t.start()
|
||||
print(strings.en["MODEL_PUBLICLY_AVAILABLE_URL"].format(share_url))
|
||||
networking.set_share_url_in_config_file(output_directory, share_url)
|
||||
|
||||
if inline is None:
|
||||
try: # Check if running interactively using ipython.
|
||||
@ -295,12 +294,17 @@ class Interface:
|
||||
inbrowser = False
|
||||
|
||||
if inbrowser and not is_colab:
|
||||
webbrowser.open(path_to_local_server) # Open a browser tab with the interface.
|
||||
webbrowser.open(
|
||||
path_to_local_server
|
||||
) # Open a browser tab with the interface.
|
||||
if inline:
|
||||
from IPython.display import IFrame
|
||||
if is_colab: # Embed the remote interface page if on google colab; otherwise, embed the local page.
|
||||
display(IFrame(path_to_ngrok_interface_page, width=1000, height=500))
|
||||
|
||||
if (
|
||||
is_colab
|
||||
): # Embed the remote interface page if on google colab; otherwise, embed the local page.
|
||||
display(IFrame(share_url, width=1000, height=500))
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
||||
return httpd, path_to_local_server, path_to_ngrok_server
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
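The constructor and launch() changes above suggest the following end-to-end usage. This is a hedged sketch: the model function is hypothetical, Interface is assumed to be exported as gradio.Interface, and only parts of launch() (the share, inline, and inbrowser handling and its return value) appear in this diff, so the exact launch() signature is an assumption.

# Hedged usage sketch for the Interface class shown above.
import gradio

def predict_fn(preprocessed_input):
    # Hypothetical stand-in model; with model_type="pyfunc", predict() calls
    # self.model_obj(preprocessed_input) directly.
    return preprocessed_input

iface = gradio.Interface(
    inputs="textbox",      # resolved via gradio.inputs.registry
    outputs="textbox",     # resolved via gradio.outputs.registry
    model=predict_fn,
    model_type="pyfunc",   # one of VALID_MODEL_TYPES; inferred when omitted
    verbose=True,
)

# Per this commit, share=True uses networking.setup_tunnel() instead of ngrok.
httpd, path_to_local_server, share_url = iface.launch(share=True)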
@ -1,46 +1,45 @@
|
||||
'''
|
||||
"""
|
||||
Defines helper methods useful for setting up ports, launching servers, and handling `ngrok`
|
||||
'''
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import requests
|
||||
import zipfile
|
||||
import io
|
||||
import sys
|
||||
import os
|
||||
import socket
|
||||
from psutil import process_iter, AccessDenied, NoSuchProcess
|
||||
from signal import SIGTERM # or SIGKILL
|
||||
import threading
|
||||
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
|
||||
import stat
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
from distutils import dir_util
|
||||
from gradio import inputs, outputs
|
||||
import time
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
from gradio.tunneling import create_tunnel
|
||||
import urllib.request
|
||||
from shutil import copyfile
|
||||
|
||||
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
|
||||
LOCALHOST_NAME = 'localhost'
|
||||
NGROK_TUNNEL_API_URL = "http://{}/api/tunnels"
|
||||
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename('gradio', 'templates/base_template.html')
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename('gradio', 'static/')
|
||||
STATIC_PATH_TEMP = 'static/'
|
||||
TEMPLATE_TEMP = 'index.html'
|
||||
BASE_JS_FILE = 'static/js/all_io.js'
|
||||
CONFIG_FILE = 'static/config.json'
|
||||
INITIAL_PORT_VALUE = (
|
||||
7860
|
||||
) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = (
|
||||
100
|
||||
) # Number of ports to try before giving up and throwing an exception.
|
||||
LOCALHOST_NAME = "localhost"
|
||||
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
|
||||
|
||||
BASE_TEMPLATE = pkg_resources.resource_filename(
|
||||
"gradio", "templates/base_template.html"
|
||||
)
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
|
||||
STATIC_PATH_TEMP = "static/"
|
||||
TEMPLATE_TEMP = "index.html"
|
||||
BASE_JS_FILE = "static/js/all_io.js"
|
||||
CONFIG_FILE = "static/config.json"
|
||||
|
||||
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
|
||||
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
|
||||
|
||||
FLAGGING_DIRECTORY = 'gradio-flagged/{}'
|
||||
FLAGGING_FILENAME = 'gradio-flagged.txt'
|
||||
|
||||
NGROK_ZIP_URLS = {
|
||||
"linux": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip",
|
||||
"darwin": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-darwin-amd64.zip",
|
||||
"win32": "https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-windows-amd64.zip",
|
||||
}
|
||||
|
||||
def build_template(temp_dir, input_interface, output_interface):
|
||||
"""
|
||||
@ -50,16 +49,27 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
:param output_interface: an AbstractOutput object which is used to get the output template
|
||||
"""
|
||||
input_template_path = pkg_resources.resource_filename(
|
||||
'gradio', inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()))
|
||||
"gradio",
|
||||
inputs.BASE_INPUT_INTERFACE_TEMPLATE_PATH.format(input_interface.get_name()),
|
||||
)
|
||||
output_template_path = pkg_resources.resource_filename(
|
||||
'gradio', outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()))
|
||||
"gradio",
|
||||
outputs.BASE_OUTPUT_INTERFACE_TEMPLATE_PATH.format(output_interface.get_name()),
|
||||
)
|
||||
input_page = open(input_template_path)
|
||||
output_page = open(output_template_path)
|
||||
input_soup = BeautifulSoup(render_string_or_list_with_tags(
|
||||
input_page.read(), input_interface.get_template_context()), features="html.parser")
|
||||
input_soup = BeautifulSoup(
|
||||
render_string_or_list_with_tags(
|
||||
input_page.read(), input_interface.get_template_context()
|
||||
),
|
||||
features="html.parser",
|
||||
)
|
||||
output_soup = BeautifulSoup(
|
||||
render_string_or_list_with_tags(
|
||||
output_page.read(), output_interface.get_template_context()), features="html.parser")
|
||||
output_page.read(), output_interface.get_template_context()
|
||||
),
|
||||
features="html.parser",
|
||||
)
|
||||
|
||||
all_io_page = open(BASE_TEMPLATE)
|
||||
all_io_soup = BeautifulSoup(all_io_page.read(), features="html.parser")
|
||||
@ -73,12 +83,24 @@ def build_template(temp_dir, input_interface, output_interface):
|
||||
f.write(str(all_io_soup))
|
||||
|
||||
copy_files(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||
render_template_with_tags(os.path.join(
|
||||
temp_dir, inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name())),
|
||||
input_interface.get_js_context())
|
||||
render_template_with_tags(os.path.join(
|
||||
temp_dir, outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name())),
|
||||
output_interface.get_js_context())
|
||||
# Move association file to root of temporary directory.
|
||||
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
|
||||
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
|
||||
|
||||
render_template_with_tags(
|
||||
os.path.join(
|
||||
temp_dir,
|
||||
inputs.BASE_INPUT_INTERFACE_JS_PATH.format(input_interface.get_name()),
|
||||
),
|
||||
input_interface.get_js_context(),
|
||||
)
|
||||
render_template_with_tags(
|
||||
os.path.join(
|
||||
temp_dir,
|
||||
outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(output_interface.get_name()),
|
||||
),
|
||||
output_interface.get_js_context(),
|
||||
)
|
||||
|
||||
|
||||
def copy_files(src_dir, dest_dir):
|
||||
@ -100,7 +122,7 @@ def render_template_with_tags(template_path, context):
|
||||
with open(template_path) as fin:
|
||||
old_lines = fin.readlines()
|
||||
new_lines = render_string_or_list_with_tags(old_lines, context)
|
||||
with open(template_path, 'w') as fout:
|
||||
with open(template_path, "w") as fout:
|
||||
for line in new_lines:
|
||||
fout.write(line)
|
||||
|
||||
@ -109,36 +131,37 @@ def render_string_or_list_with_tags(old_lines, context):
|
||||
# Handle string case
|
||||
if isinstance(old_lines, str):
|
||||
for key, value in context.items():
|
||||
old_lines = old_lines.replace(r'{{' + key + r'}}', str(value))
|
||||
old_lines = old_lines.replace(r"{{" + key + r"}}", str(value))
|
||||
return old_lines
|
||||
|
||||
# Handle list case
|
||||
new_lines = []
|
||||
for line in old_lines:
|
||||
for key, value in context.items():
|
||||
line = line.replace(r'{{' + key + r'}}', str(value))
|
||||
line = line.replace(r"{{" + key + r"}}", str(value))
|
||||
new_lines.append(line)
|
||||
return new_lines
|
||||
|
||||
|
||||
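As a quick illustration of the {{tag}} substitution that render_string_or_list_with_tags performs (and that fills static/config.json later in this file), here is a standalone restatement of the string branch above; the template string is made up for the example:

# Standalone restatement of the string branch of render_string_or_list_with_tags.
template = '{"share_url": "{{share_url}}", "input_interface_type": "{{input_interface_type}}"}'
context = {"share_url": "None", "input_interface_type": "sketchpad"}

rendered = template
for key, value in context.items():
    rendered = rendered.replace(r"{{" + key + r"}}", str(value))

print(rendered)
# {"share_url": "None", "input_interface_type": "sketchpad"}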
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
|
||||
def set_ngrok_url_in_js(temp_dir, ngrok_socket_url):
|
||||
ngrok_socket_url = ngrok_socket_url.replace('http', 'ws')
|
||||
js_file = os.path.join(temp_dir, BASE_JS_FILE)
|
||||
render_template_with_tags(js_file, {'ngrok_socket_url': ngrok_socket_url})
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
render_template_with_tags(config_file, {'ngrok_socket_url': ngrok_socket_url})
|
||||
|
||||
|
||||
def set_socket_port_in_js(temp_dir, socket_port):
|
||||
js_file = os.path.join(temp_dir, BASE_JS_FILE)
|
||||
render_template_with_tags(js_file, {'socket_port': str(socket_port)})
|
||||
|
||||
|
||||
def set_interface_types_in_config_file(temp_dir, input_interface, output_interface):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
render_template_with_tags(config_file, {'input_interface_type': input_interface,
|
||||
'output_interface_type': output_interface})
|
||||
render_template_with_tags(
|
||||
config_file,
|
||||
{
|
||||
"input_interface_type": input_interface,
|
||||
"output_interface_type": output_interface,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def set_share_url_in_config_file(temp_dir, share_url):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
render_template_with_tags(
|
||||
config_file,
|
||||
{
|
||||
"share_url": share_url,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_first_available_port(initial, final):
|
||||
@ -156,13 +179,22 @@ def get_first_available_port(initial, final):
|
||||
return port
|
||||
except OSError:
|
||||
pass
|
||||
raise OSError("All ports from {} to {} are in use. Please close a port.".format(initial, final))
|
||||
raise OSError(
|
||||
"All ports from {} to {} are in use. Please close a port.".format(
|
||||
initial, final
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
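Only the tail of get_first_available_port (the except OSError / raise lines) is visible in this hunk, so the following is a hedged reconstruction of a typical bind-test loop consistent with those lines, not the committed body:

import socket

LOCALHOST_NAME = "localhost"  # defined near the top of this file

def get_first_available_port_sketch(initial, final):
    # Try to bind each port in [initial, final); the first successful bind wins.
    for port in range(initial, final):
        try:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.bind((LOCALHOST_NAME, port))
            s.close()
            return port
        except OSError:
            pass
    raise OSError(
        "All ports from {} to {} are in use. Please close a port.".format(initial, final)
    )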
def serve_files_in_background(port, directory_to_serve=None):
|
||||
def serve_files_in_background(interface, port, directory_to_serve=None):
|
||||
class HTTPHandler(SimpleHTTPRequestHandler):
|
||||
"""This handler uses server.base_path instead of always using os.getcwd()"""
|
||||
|
||||
def _set_headers(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
def translate_path(self, path):
|
||||
path = SimpleHTTPRequestHandler.translate_path(self, path)
|
||||
relpath = os.path.relpath(path, os.getcwd())
|
||||
@ -172,6 +204,36 @@ def serve_files_in_background(port, directory_to_serve=None):
|
||||
def log_message(self, format, *args):
|
||||
return
|
||||
|
||||
def do_POST(self):
|
||||
# Read body of the request.
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
|
||||
if self.path == "/api/predict/":
|
||||
# Make the prediction.
|
||||
msg = json.loads(data_string)
|
||||
processed_input = interface.input_interface.preprocess(msg["data"])
|
||||
prediction = interface.predict(processed_input)
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
output = {"action": "output", "data": processed_output}
|
||||
|
||||
# Prepare return json dictionary.
|
||||
self.wfile.write(json.dumps(output).encode())
|
||||
|
||||
elif self.path == "/api/flag/":
|
||||
msg = json.loads(data_string)
|
||||
flag_dir = FLAGGING_DIRECTORY.format(interface.hash)
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
dict = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg),
|
||||
'output': interface.output_interface.rebuild_flagged(flag_dir, msg),
|
||||
'message': msg['data']['message']}
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(dict))
|
||||
f.write("\n")
|
||||
|
||||
else:
|
||||
self.send_response(404)
|
||||
|
||||
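Client-side, the two endpoints handled above behave roughly as follows. This is a hedged sketch: the port (7860, i.e. INITIAL_PORT_VALUE) and the example payloads are assumptions, and it uses the third-party requests library for brevity.

import json
import requests

base = "http://localhost:7860"

# /api/predict/: body is {"data": <raw input>}; the handler preprocesses it,
# runs interface.predict(), and returns {"action": "output", "data": <output>}.
resp = requests.post(base + "/api/predict/", data=json.dumps({"data": "hello"}))
print(resp.json()["data"])

# /api/flag/: body carries the last input/output plus an optional message; the
# handler rebuilds them via rebuild_flagged() and appends a line to
# gradio-flagged/<hash>/gradio-flagged.txt.
flag_msg = {"data": {"input": "hello", "output": "HELLO", "message": "wrong case"}}
requests.post(base + "/api/flag/", data=json.dumps(flag_msg))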
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
||||
@ -196,9 +258,11 @@ def serve_files_in_background(port, directory_to_serve=None):
|
||||
return httpd
|
||||
|
||||
|
||||
def start_simple_server(directory_to_serve=None):
|
||||
port = get_first_available_port(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
|
||||
httpd = serve_files_in_background(port, directory_to_serve)
|
||||
def start_simple_server(interface, directory_to_serve=None):
|
||||
port = get_first_available_port(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
|
||||
)
|
||||
httpd = serve_files_in_background(interface, port, directory_to_serve)
|
||||
return port, httpd
|
||||
|
||||
|
||||
@ -207,68 +271,23 @@ def close_server(server):
|
||||
server.server_close()
|
||||
|
||||
|
||||
def download_ngrok():
|
||||
def url_request(url):
|
||||
try:
|
||||
zip_file_url = NGROK_ZIP_URLS[sys.platform]
|
||||
except KeyError:
|
||||
print("Sorry, we don't currently support your operating system, please leave us a note on GitHub, and "
|
||||
"we'll look into it!")
|
||||
return
|
||||
r = requests.get(zip_file_url)
|
||||
z = zipfile.ZipFile(io.BytesIO(r.content))
|
||||
z.extractall()
|
||||
if sys.platform == 'darwin' or sys.platform == 'linux':
|
||||
st = os.stat('ngrok')
|
||||
os.chmod('ngrok', st.st_mode | stat.S_IEXEC)
|
||||
req = urllib.request.Request(
|
||||
url=url, headers={"content-type": "application/json"}
|
||||
)
|
||||
res = urllib.request.urlopen(req, timeout=10)
|
||||
return res
|
||||
except Exception as e:
|
||||
raise RuntimeError(str(e))
|
||||
|
||||
|
||||
def create_ngrok_tunnel(local_port, log_file):
|
||||
if not(os.path.isfile('ngrok.exe') or os.path.isfile('ngrok')):
|
||||
download_ngrok()
|
||||
if sys.platform == 'win32':
|
||||
subprocess.Popen(['ngrok', 'http', str(local_port), '--log', log_file, '--log-format', 'json'])
|
||||
else:
|
||||
subprocess.Popen(['./ngrok', 'http', str(local_port), '--log', log_file, '--log-format', 'json'])
|
||||
time.sleep(1.5) # Let ngrok write to the log file TODO(abidlabs): a better way to do this.
|
||||
session = requests.Session()
|
||||
retry = Retry(connect=3, backoff_factor=0.5)
|
||||
adapter = HTTPAdapter(max_retries=retry)
|
||||
session.mount('http://', adapter)
|
||||
session.mount('https://', adapter)
|
||||
|
||||
api_url = None
|
||||
with open(log_file) as f:
|
||||
for line in f:
|
||||
log = json.loads(line)
|
||||
if log["msg"] == "starting web service":
|
||||
api_url = log["addr"]
|
||||
api_port = urlparse(api_url).port
|
||||
break
|
||||
|
||||
if api_url is None:
|
||||
raise RuntimeError("Tunnel information not available in log file")
|
||||
|
||||
r = session.get(NGROK_TUNNEL_API_URL.format(api_url))
|
||||
for tunnel in r.json()['tunnels']:
|
||||
if '{}:'.format(LOCALHOST_NAME) + str(local_port) in tunnel['config']['addr'] and tunnel['proto'] == 'https':
|
||||
return tunnel['public_url'], api_port
|
||||
raise RuntimeError("Not able to retrieve ngrok public URL")
|
||||
|
||||
|
||||
def setup_ngrok(server_port, websocket_port, output_directory, existing_ports):
|
||||
if not(existing_ports is None):
|
||||
kill_processes(existing_ports)
|
||||
site_ngrok_url, port1 = create_ngrok_tunnel(server_port, os.path.join(output_directory, 'ngrok1.log'))
|
||||
socket_ngrok_url, port2 = create_ngrok_tunnel(websocket_port, os.path.join(output_directory, 'ngrok2.log'))
|
||||
set_ngrok_url_in_js(output_directory, socket_ngrok_url)
|
||||
return site_ngrok_url, [port1, port2]
|
||||
|
||||
|
||||
def kill_processes(process_ids): #TODO(abidlabs): remove this, we shouldn't need to kill
|
||||
for proc in process_iter():
|
||||
def setup_tunnel(local_server_port):
|
||||
response = url_request(GRADIO_API_SERVER)
|
||||
if response and response.code == 200:
|
||||
try:
|
||||
for conns in proc.connections(kind='inet'):
|
||||
if conns.laddr.port in process_ids:
|
||||
proc.send_signal(SIGTERM) # or SIGKILL
|
||||
except (AccessDenied, NoSuchProcess):
|
||||
pass
|
||||
payload = json.loads(response.read().decode("utf-8"))[0]
|
||||
return create_tunnel(payload, LOCALHOST_NAME, local_server_port)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(str(e))
|
||||
|
@ -54,6 +54,13 @@ class AbstractOutput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rebuild_flagged(self, inp):
|
||||
"""
|
||||
All interfaces should define a method that rebuilds the flagged output when it's passed back (i.e. rebuilds image from base64)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Label(AbstractOutput):
|
||||
LABEL_KEY = 'label'
|
||||
@ -107,6 +114,12 @@ class Label(AbstractOutput):
|
||||
raise ValueError("Unable to post-process model prediction.")
|
||||
return json.dumps(response)
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method for label
|
||||
"""
|
||||
out = msg['data']['output']
|
||||
return json.loads(out)
|
||||
|
||||
class Textbox(AbstractOutput):
|
||||
|
||||
@ -118,6 +131,12 @@ class Textbox(AbstractOutput):
|
||||
"""
|
||||
return prediction
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method for textbox output
|
||||
"""
|
||||
out = msg['data']['output']
|
||||
return json.loads(out)
|
||||
|
||||
class Image(AbstractOutput):
|
||||
|
||||
@ -129,5 +148,16 @@ class Image(AbstractOutput):
|
||||
"""
|
||||
return prediction
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
Default rebuild method to decode a base64 image
|
||||
"""
|
||||
out = msg['data']['output']
|
||||
im = preprocessing_utils.encoding_to_image(out)
|
||||
timestamp = datetime.datetime.now()
|
||||
filename = f'output_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
|
||||
registry = {cls.__name__.lower(): cls for cls in AbstractOutput.__subclasses__()}
|
||||
|
11
build/lib/gradio/static/apple-app-site-association
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"applinks": {
|
||||
"apps": [],
|
||||
"details": [
|
||||
{
|
||||
"appID": "RHW8FBGSTX.app.gradio.Gradio",
|
||||
"paths": ["*"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
{
|
||||
"input_interface_type": "{{input_interface_type}}",
|
||||
"output_interface_type": "{{output_interface_type}}",
|
||||
"ngrok_socket_url": "{{ngrok_socket_url}}"
|
||||
"share_url": "{{share_url}}"
|
||||
}
|
||||
|
@ -36,22 +36,28 @@
|
||||
.panel_buttons {
|
||||
display: flex;
|
||||
}
|
||||
.panel_buttons > input, .panel_buttons > button {
|
||||
background-color: #F6F6F6;
|
||||
.submit, .clear, .flag {
|
||||
background-color: #F6F6F6 !important;
|
||||
flex-grow: 1;
|
||||
padding: 8px;
|
||||
padding: 8px !important;
|
||||
box-sizing: border-box;
|
||||
text-transform: uppercase;
|
||||
font-weight: bold;
|
||||
border: 0 none;
|
||||
border: 0 none !important;
|
||||
}
|
||||
.submit {
|
||||
background-color: #EEA45D !important;
|
||||
color: white !important;
|
||||
}
|
||||
.submit, .flag_message {
|
||||
flex-grow: 2 !important;
|
||||
margin-right: 8px;
|
||||
}
|
||||
.clear {
|
||||
.flag_message {
|
||||
padding: 8px !important;
|
||||
background-color: #F6F6F6 !important;
|
||||
}
|
||||
.clear, .flag {
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
@ -83,14 +89,6 @@
|
||||
font-weight: bold;
|
||||
font-size: 14px;
|
||||
}
|
||||
.interface_button.primary {
|
||||
color: white;
|
||||
background-color: #EEA45D;
|
||||
}
|
||||
.interface_button.secondary {
|
||||
color: black;
|
||||
background-color: #F6F6F6;
|
||||
}
|
||||
.overlay {
|
||||
position: absolute;
|
||||
height: 100vh;
|
||||
@ -101,3 +99,6 @@
|
||||
top: 0;
|
||||
left: 0;
|
||||
}
|
||||
.flag.flagged {
|
||||
background-color: pink !important;
|
||||
}
|
||||
|
@ -1,32 +1,41 @@
|
||||
.confidence_intervals {
|
||||
font-size: 16px;
|
||||
display: flex;
|
||||
font-size: 20px;
|
||||
}
|
||||
.confidences {
|
||||
flex-grow: 1;
|
||||
display: flex;
|
||||
flex-flow: column;
|
||||
align-items: baseline;
|
||||
font-family: monospace;
|
||||
}
|
||||
.confidence {
|
||||
padding: 3px;
|
||||
background-color: #888888;
|
||||
color: white;
|
||||
text-align: right;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
.level, .label {
|
||||
display: inline-block;
|
||||
.labels {
|
||||
max-width: 120px;
|
||||
margin-right: 4px;
|
||||
}
|
||||
.label, .confidence {
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
height: 27px;
|
||||
margin-bottom: 4px;
|
||||
padding: 2px;
|
||||
}
|
||||
.label {
|
||||
width: 60px;
|
||||
}
|
||||
.confidence_intervals .level {
|
||||
font-size: 14px;
|
||||
margin-left: 8px;
|
||||
margin-right: 8px;
|
||||
background-color: #AAA;
|
||||
padding: 2px 4px;
|
||||
text-overflow: ellipsis;
|
||||
text-align: right;
|
||||
font-family: monospace;
|
||||
color: white;
|
||||
font-weight: bold;
|
||||
}
|
||||
.confidence_intervals > * {
|
||||
vertical-align: bottom;
|
||||
}
|
||||
.flag.flagged {
|
||||
background-color: pink !important;
|
||||
.confidence {
|
||||
text-overflow: clip;
|
||||
padding-left: 6px;
|
||||
padding-right: 6px;
|
||||
}
|
||||
.output_class {
|
||||
font-weight: bold;
|
||||
|
@ -14,21 +14,52 @@ button, input[type="submit"], input[type="reset"], input[type="text"], input[typ
|
||||
-webkit-appearance: none;
|
||||
border-radius: 0;
|
||||
}
|
||||
nav, #panels {
|
||||
nav, #panels, #share_row {
|
||||
margin-left: 60px;
|
||||
margin-right: 60px;
|
||||
}
|
||||
nav {
|
||||
text-align: center;
|
||||
padding: 16px 0 8px;
|
||||
padding: 16px 0 4px;
|
||||
}
|
||||
nav img {
|
||||
margin-right: auto;
|
||||
height: 32px;
|
||||
}
|
||||
#share_row {
|
||||
justify-content: center;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
#share_form {
|
||||
flex-grow: 1;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
#share_row button, #share_row input[type=text] {
|
||||
padding: 6px 4px;
|
||||
margin-left: 5px;
|
||||
}
|
||||
#share_row input[type=text] {
|
||||
background-color: #F6F6F6;
|
||||
}
|
||||
#share_email {
|
||||
flex-grow: 1;
|
||||
max-width: 400px;
|
||||
}
|
||||
#share_row, #share_complete, #share_form {
|
||||
display: none;
|
||||
}
|
||||
#panels {
|
||||
display: flex;
|
||||
flex-flow: row;
|
||||
flex-wrap: wrap;
|
||||
justify-content: center;
|
||||
}
|
||||
button.primary {
|
||||
color: white;
|
||||
background-color: #EEA45D;
|
||||
}
|
||||
button.secondary {
|
||||
color: black;
|
||||
background-color: #F6F6F6;
|
||||
}
|
||||
|
BIN
build/lib/gradio/static/img/logo_loading.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 188 KiB |
BIN
build/lib/gradio/static/img/logo_only.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.4 KiB |
@ -1,36 +1,43 @@
|
||||
var NGROK_URL = "{{ngrok_socket_url}}"
|
||||
var SOCKET_PORT = "{{socket_port}}"
|
||||
|
||||
var origin = window.location.origin;
|
||||
if (origin.includes("ngrok") || origin.includes("gradio.app")){ //TODO(abidlabs): better way to distinguish localhost?
|
||||
var ws = new WebSocket(NGROK_URL)
|
||||
} else {
|
||||
var ws = new WebSocket("ws://127.0.0.1:" + SOCKET_PORT + "/")
|
||||
}
|
||||
ws.onclose = function(event) {
|
||||
console.log("WebSocket is closed now.");
|
||||
}
|
||||
|
||||
var io_master = {
|
||||
input: function(interface_id, data) {
|
||||
var ws_data = {
|
||||
'action': 'input',
|
||||
this.last_input = data;
|
||||
this.last_output = null;
|
||||
var post_data = {
|
||||
'data': data
|
||||
};
|
||||
console.log(ws_data)
|
||||
ws.send(JSON.stringify(ws_data), function(e) {
|
||||
console.log(e)
|
||||
})
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/predict/",
|
||||
data: JSON.stringify(post_data),
|
||||
success: function(output){
|
||||
if (output['action'] == 'output') {
|
||||
io_master.output(output['data']);
|
||||
}
|
||||
},
|
||||
error: function(XMLHttpRequest, textStatus, errorThrown) {
|
||||
console.log(XMLHttpRequest);
|
||||
console.log(textStatus);
|
||||
console.log(errorThrown);
|
||||
}
|
||||
});
|
||||
},
|
||||
output: function(data) {
|
||||
console.log(data)
|
||||
this.last_output = data;
|
||||
this.output_interface.output(data);
|
||||
},
|
||||
flag: function(message) {
|
||||
var post_data = {
|
||||
'data': {
|
||||
'input' : this.last_input,
|
||||
'output' : this.last_output,
|
||||
'message' : message
|
||||
}
|
||||
}
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/flag/",
|
||||
data: JSON.stringify(post_data),
|
||||
success: function(output){
|
||||
console.log("Flagging successful")
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
ws.onmessage = function (event) {
|
||||
var output = JSON.parse(event.data)
|
||||
if (output['action'] == 'output') {
|
||||
io_master.output(output['data']);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -68,6 +68,9 @@ const image_input = {
|
||||
let io = get_interface(e.target);
|
||||
io.overlay_target.addClass("hide");
|
||||
if ($(e.target).hasClass('tui_save')) {
|
||||
// if (io.tui_editor.ui.submenu == "crop") {
|
||||
// io.tui_editor._cropAction().crop());
|
||||
// }
|
||||
io.set_image_data(io.tui_editor.toDataURL(), /*update_editor=*/false);
|
||||
}
|
||||
})
|
||||
|
@ -7,6 +7,7 @@ const image_output = {
|
||||
this.target.find(".output_image").attr('src', data).show();
|
||||
},
|
||||
submit: function() {
|
||||
this.target.find(".output_image").attr('src', 'static/img/logo_loading.gif').show();
|
||||
},
|
||||
clear: function() {
|
||||
this.target.find(".output_image").attr('src', "").hide();
|
||||
|
@ -1,28 +1,33 @@
|
||||
const label_output = {
|
||||
html: `
|
||||
<div class="output_class"></div>
|
||||
<div class="confidence_intervals"></div>
|
||||
<div class="confidence_intervals">
|
||||
<div class="labels"></div>
|
||||
<div class="confidences"></div>
|
||||
</div>
|
||||
`,
|
||||
init: function() {},
|
||||
output: function(data) {
|
||||
data = JSON.parse(data)
|
||||
this.target.find(".output_class").html(data["label"])
|
||||
this.target.find(".confidence_intervals").empty()
|
||||
this.target.find(".confidence_intervals > div").empty()
|
||||
if (data.confidences) {
|
||||
for (var i = 0; i < data.confidences.length; i++)
|
||||
{
|
||||
let c = data.confidences[i]
|
||||
let confidence = c["confidence"]
|
||||
this.target.find(".confidence_intervals").append(`<div class="confidence"><div class=
|
||||
"label">${c["label"]}</div><div class="level" style="flex-grow:
|
||||
${confidence}">${Math.round(confidence * 100)}%</div></div>`)
|
||||
let label = c["label"]
|
||||
let confidence = Math.round(c["confidence"] * 100) + "%";
|
||||
this.target.find(".labels").append(`<div class="label" title="${label}">${label}</div>`);
|
||||
this.target.find(".confidences").append(`
|
||||
<div class="confidence" style="min-width: calc(${confidence} - 12px);" title="${confidence}">${confidence}</div>`);
|
||||
}
|
||||
}
|
||||
},
|
||||
submit: function() {
|
||||
this.target.find(".output_class").html("<img src='static/img/logo_loading.gif'>")
|
||||
},
|
||||
clear: function() {
|
||||
this.target.find(".output_class").empty();
|
||||
this.target.find(".confidence_intervals").empty();
|
||||
this.target.find(".confidence_intervals > div").empty();
|
||||
}
|
||||
}
|
||||
|
@ -23,11 +23,13 @@ function get_interface(target) {
|
||||
attr("interface_id")];
|
||||
}
|
||||
|
||||
var config;
|
||||
$.getJSON("static/config.json", function(data) {
|
||||
config = data;
|
||||
input_interface = Object.create(input_to_object_map[
|
||||
data["input_interface_type"]]);
|
||||
config["input_interface_type"]]);
|
||||
output_interface = Object.create(output_to_object_map[
|
||||
data["output_interface_type"]]);
|
||||
config["output_interface_type"]]);
|
||||
$("#input_interface").html(input_interface.html);
|
||||
input_interface.target = $("#input_interface");
|
||||
set_interface_id(input_interface, 1)
|
||||
@ -39,21 +41,28 @@ $.getJSON("static/config.json", function(data) {
|
||||
$(".submit").click(function() {
|
||||
input_interface.submit();
|
||||
output_interface.submit();
|
||||
$(".flag").removeClass("flagged");
|
||||
})
|
||||
$(".clear").click(function() {
|
||||
input_interface.clear();
|
||||
output_interface.clear();
|
||||
output_interface.clear();
|
||||
$(".flag").removeClass("flagged");
|
||||
$(".flag_message").empty()
|
||||
io_master.last_input = null;
|
||||
io_master.last_output = null;
|
||||
})
|
||||
input_interface.io_master = io_master;
|
||||
io_master.input_interface = input_interface;
|
||||
output_interface.io_master = io_master;
|
||||
io_master.output_interface = output_interface;
|
||||
if (config["share_url"] != "None") {
|
||||
$("#share_row").css('display', 'flex');
|
||||
}
|
||||
});
|
||||
|
||||
$('body').on('click', '.flag', function(e) {
|
||||
if ($(".flag").hasClass("flagged")) {
|
||||
$(".flag").removeClass("flagged").attr("value", "flag");
|
||||
} else {
|
||||
$(".flag").addClass("flagged").attr("value", "flagged");
|
||||
if (io_master.last_output) {
|
||||
$(".flag").addClass("flagged");
|
||||
io_master.flag($(".flag_message").val());
|
||||
}
|
||||
})
|
||||
|
41
build/lib/gradio/static/js/share.js
Normal file
@ -0,0 +1,41 @@
|
||||
$("#share").click(function() {
|
||||
$("#share").hide()
|
||||
$("#share_form").css('display', 'flex')
|
||||
})
|
||||
|
||||
$("#send_link").click(function(evt) {
|
||||
let name = $("#share_name").val()
|
||||
let email = $("#share_email").val()
|
||||
if (name && email) {
|
||||
$("#send_link").attr('disabled', true);
|
||||
$.ajax({
|
||||
"url" : "https://gradio.app/api/send-email/",
|
||||
"type": "POST",
|
||||
"crossDomain": true,
|
||||
"data": {
|
||||
"url": config["ngrok_socket_url"],
|
||||
"name": name,
|
||||
"email": email
|
||||
},
|
||||
"success": function() {
|
||||
$("#share_message").text("Shared successfully.");
|
||||
$("#share_more").text("Share more");
|
||||
},
|
||||
"error": function() {
|
||||
$("#share_message").text("Failed to share.");
|
||||
$("#share_more").text("Try again");
|
||||
},
|
||||
"complete": function() {
|
||||
$("#share_form").hide();
|
||||
$("#share_complete").show();
|
||||
$("#send_link").attr('disabled', false);
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
$("#share_more").click(function (evt) {
|
||||
$("#share_email").val("");
|
||||
$("#share_form").show();
|
||||
$("#share_complete").hide();
|
||||
})
|
@ -29,7 +29,6 @@ function resizeImage(base64Str, max_width, max_height, callback) {
|
||||
canvas.height = height;
|
||||
var ctx = canvas.getContext('2d');
|
||||
ctx.drawImage(img, 0, 0, width, height);
|
||||
console.log(canvas.toDataURL())
|
||||
callback.call(null, canvas.toDataURL());
|
||||
}
|
||||
}
|
||||
|
@ -8262,7 +8262,7 @@ return /******/ (function(modules) { // webpackBootstrap
|
||||
|
||||
function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; }
|
||||
|
||||
var DRAW_OPACITY = 0.7;
|
||||
var DRAW_OPACITY = 1;
|
||||
|
||||
/**
|
||||
* Draw ui class
|
||||
|
@ -5,5 +5,5 @@ en = {
|
||||
"restarting python interpreter.",
|
||||
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
||||
"PUBLIC_SHARE_TRUE": "To create a public link, set `share=True` in the argument to `launch()`.",
|
||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly for 8 hours at: {}"
|
||||
"MODEL_PUBLICLY_AVAILABLE_URL": "Model available publicly at: {} -- may take up to a minute to setup."
|
||||
}
|
||||
|
@ -21,8 +21,20 @@
|
||||
|
||||
<body>
|
||||
<nav>
|
||||
<img src="../static/img/logo_inline.png" />
|
||||
<a href="https://gradio.app"><img src="../static/img/logo_inline.png" /></a>
|
||||
</nav>
|
||||
<div id="share_row">
|
||||
<button id="share" class="primary">Share this Interface</button>
|
||||
<div id="share_form">
|
||||
<input type="text" id="share_name" placeholder="sender name (you)"></input>
|
||||
<input type="text" id="share_email" placeholder="emails (comma-separated if multiple)"></input>
|
||||
<button class="primary" id="send_link">Send Link</button>
|
||||
</div>
|
||||
<div id="share_complete">
|
||||
<span id="share_message"></span>
|
||||
<button class="secondary" id="share_more"></button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="panels">
|
||||
<div class="panel">
|
||||
<div class="panel_header">Input</div>
|
||||
@ -42,6 +54,7 @@
|
||||
</div>
|
||||
<div id="output_interface" class="interface"></div>
|
||||
<div class="panel_buttons">
|
||||
<input type="text" class="flag_message" placeholder="(Optional message for flagging)"/>
|
||||
<input type="button" class="flag" value="flag"/>
|
||||
</div>
|
||||
</div>
|
||||
@ -60,13 +73,14 @@
|
||||
<script src="../static/js/all_io.js"></script>
|
||||
<script src="../static/js/interfaces/input/csv.js"></script>
|
||||
<script src="../static/js/interfaces/input/image_upload.js"></script>
|
||||
<script src="../static/js/vendor/sketchpad.js"></script>
|
||||
<script src="../static/js/vendor/sketchpad.js"></script>
|
||||
<script src="../static/js/interfaces/input/sketchpad.js"></script>
|
||||
<script src="../static/js/interfaces/input/textbox.js"></script>
|
||||
<script src="../static/js/interfaces/input/csv.js"></script>
|
||||
<script src="../static/js/interfaces/output/image.js"></script>
|
||||
<script src="../static/js/interfaces/output/label.js"></script>
|
||||
<script src="../static/js/interfaces/output/textbox.js"></script>
|
||||
<script src="../static/js/share.js"></script>
|
||||
<script src="../static/js/load_interfaces.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
103
build/lib/gradio/tunneling.py
Normal file
@ -0,0 +1,103 @@
|
||||
"""
|
||||
This file provides remote port forwarding functionality using the paramiko package.
|
||||
Inspired by: https://github.com/paramiko/paramiko/blob/master/demos/rforward.py
|
||||
"""
|
||||
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from io import StringIO
|
||||
import warnings
|
||||
import paramiko
|
||||
|
||||
DEBUG_MODE = False
|
||||
|
||||
|
||||
def handler(chan, host, port):
|
||||
sock = socket.socket()
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
except Exception as e:
|
||||
verbose("Forwarding request to %s:%d failed: %r" % (host, port, e))
|
||||
return
|
||||
|
||||
verbose(
|
||||
"Connected! Tunnel open %r -> %r -> %r"
|
||||
% (chan.origin_addr, chan.getpeername(), (host, port))
|
||||
)
|
||||
while True:
|
||||
r, w, x = select.select([sock, chan], [], [])
|
||||
if sock in r:
|
||||
data = sock.recv(1024)
|
||||
if len(data) == 0:
|
||||
break
|
||||
chan.send(data)
|
||||
if chan in r:
|
||||
data = chan.recv(1024)
|
||||
if len(data) == 0:
|
||||
break
|
||||
sock.send(data)
|
||||
chan.close()
|
||||
sock.close()
|
||||
verbose("Tunnel closed from %r" % (chan.origin_addr,))
|
||||
|
||||
|
||||
def reverse_forward_tunnel(server_port, remote_host, remote_port, transport):
|
||||
transport.request_port_forward("", server_port)
|
||||
while True:
|
||||
chan = transport.accept(1000)
|
||||
if chan is None:
|
||||
continue
|
||||
thr = threading.Thread(target=handler, args=(chan, remote_host, remote_port))
|
||||
thr.setDaemon(True)
|
||||
thr.start()
|
||||
|
||||
|
||||
def verbose(s):
|
||||
if DEBUG_MODE:
|
||||
print(s)
|
||||
|
||||
|
||||
def create_tunnel(payload, local_server, local_server_port):
|
||||
client = paramiko.SSHClient()
|
||||
# client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
|
||||
client.set_missing_host_key_policy(paramiko.WarningPolicy())
|
||||
|
||||
verbose(
|
||||
"Connecting to ssh host %s:%d ..." % (payload["host"], int(payload["port"]))
|
||||
)
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
client.connect(
|
||||
hostname=payload["host"],
|
||||
port=int(payload["port"]),
|
||||
username=payload["user"],
|
||||
pkey=paramiko.RSAKey.from_private_key(StringIO(payload["key"])),
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"*** Failed to connect to %s:%d: %r"
|
||||
% (payload["host"], int(payload["port"]), e)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
verbose(
|
||||
"Now forwarding remote port %d to %s:%d ..."
|
||||
% (int(payload["remote_port"]), local_server, local_server_port)
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=reverse_forward_tunnel,
|
||||
args=(
|
||||
int(payload["remote_port"]),
|
||||
local_server,
|
||||
local_server_port,
|
||||
client.get_transport(),
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
|
||||
return payload["share_url"]
|
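For context, networking.setup_tunnel() (earlier in this commit) is what feeds create_tunnel(). A condensed sketch of that call path, paraphrased from the diff with error handling trimmed; the payload fields listed are those consumed by create_tunnel above:

import json
import urllib.request
from gradio.tunneling import create_tunnel

GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
LOCALHOST_NAME = "localhost"

def setup_tunnel_sketch(local_server_port):
    # Ask the gradio API for a tunnel payload (host, port, user, key,
    # remote_port, share_url), then open the reverse SSH tunnel with it.
    req = urllib.request.Request(
        url=GRADIO_API_SERVER, headers={"content-type": "application/json"}
    )
    response = urllib.request.urlopen(req, timeout=10)
    if response.code == 200:
        payload = json.loads(response.read().decode("utf-8"))[0]
        return create_tunnel(payload, LOCALHOST_NAME, local_server_port)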
Binary file not shown.
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 1.0
|
||||
Name: gradio
|
||||
Version: 0.5.0
|
||||
Version: 0.7.0
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
Home-page: https://github.com/abidlabs/gradio
|
||||
Author: Abubakar Abid
|
||||
|
@ -9,12 +9,14 @@ gradio/networking.py
|
||||
gradio/outputs.py
|
||||
gradio/preprocessing_utils.py
|
||||
gradio/strings.py
|
||||
gradio/tunneling.py
|
||||
gradio/validation_data.py
|
||||
gradio.egg-info/PKG-INFO
|
||||
gradio.egg-info/SOURCES.txt
|
||||
gradio.egg-info/dependency_links.txt
|
||||
gradio.egg-info/requires.txt
|
||||
gradio.egg-info/top_level.txt
|
||||
gradio/static/apple-app-site-association
|
||||
gradio/static/config.json
|
||||
gradio/static/css/.DS_Store
|
||||
gradio/static/css/font-awesome.min.css
|
||||
@ -32,7 +34,9 @@ gradio/static/css/vendor/tui-color-picker.css
|
||||
gradio/static/css/vendor/tui-image-editor.css
|
||||
gradio/static/img/logo.png
|
||||
gradio/static/img/logo_inline.png
|
||||
gradio/static/img/logo_loading.gif
|
||||
gradio/static/img/logo_mini.png
|
||||
gradio/static/img/logo_only.png
|
||||
gradio/static/img/mic.png
|
||||
gradio/static/img/table.png
|
||||
gradio/static/img/webcam.png
|
||||
@ -42,6 +46,7 @@ gradio/static/img/vendor/icon-c.svg
|
||||
gradio/static/img/vendor/icon-d.svg
|
||||
gradio/static/js/all_io.js
|
||||
gradio/static/js/load_interfaces.js
|
||||
gradio/static/js/share.js
|
||||
gradio/static/js/utils.js
|
||||
gradio/static/js/interfaces/input/csv.js
|
||||
gradio/static/js/interfaces/input/image_upload.js
|
||||
|