fix default interpretation to support different types for input

This commit is contained in:
Ali Abid 2020-09-21 13:53:40 -07:00
commit 102b151d42
34 changed files with 792 additions and 795 deletions

2
.gitignore vendored
View File

@ -22,3 +22,5 @@ dist/*
docs.json
*.bak
demo/tmp.zip
demo/flagged
test.txt

View File

@ -292,6 +292,7 @@ class Image(InputComponent):
def preprocess(self, x):
im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
@ -305,7 +306,7 @@ class Image(InputComponent):
elif self.type == "numpy":
return np.array(im)
elif self.type == "file":
file_obj = tempfile.NamedTemporaryFile()
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
im.save(file_obj.name)
return file_obj
else:
@ -449,7 +450,7 @@ class Dataframe(InputComponent):
else:
return pd.DataFrame(x)
if self.col_count == 1:
x = x[0]
x = [row[0] for row in x]
if self.type == "numpy":
return np.array(x)
elif self.type == "array":

View File

@ -5,18 +5,10 @@ interface using the input and output types.
import tempfile
import webbrowser
from gradio.inputs import InputComponent
from gradio.inputs import Image
from gradio.inputs import Textbox
from gradio.outputs import OutputComponent
from gradio import networking, strings, utils, processing_utils
from distutils.version import StrictVersion
from skimage.segmentation import slic
from skimage.util import img_as_float
from gradio import processing_utils
import PIL
import pkg_resources
from gradio import networking, strings, utils
import gradio.interpretation
import requests
import random
import time
@ -26,9 +18,7 @@ import sys
import weakref
import analytics
import os
import numpy as np
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
try:
@ -53,8 +43,9 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
capture_session=False, explain_by=None, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
capture_session=False, interpretation=None, title=None,
description=None, thumbnail=None, server_port=None,
server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged", analytics_enabled=True):
@ -67,6 +58,7 @@ class Interface:
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
live (bool): whether the interface should automatically reload on change.
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components.
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
@ -108,6 +100,7 @@ class Interface:
if not isinstance(fn, list):
fn = [fn]
self.output_interfaces *= len(fn)
self.predict = fn
self.verbose = verbose
@ -117,7 +110,7 @@ class Interface:
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
self.explain_by = explain_by
self.interpretation = interpretation
self.session = None
self.server_name = server_name
self.title = title
@ -186,7 +179,7 @@ class Interface:
"thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot,
"allow_flagging": self.allow_flagging,
"allow_interpretation": self.explain_by is not None
"allow_interpretation": self.interpretation is not None
}
try:
param_names = inspect.getfullargspec(self.predict[0])[0]
@ -199,21 +192,17 @@ class Interface:
iface[1]["label"] = ret_name
except ValueError:
pass
if self.examples is not None:
processed_examples = []
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
return config
def process(self, raw_input, predict_fn=None):
"""
:param raw_input: a list of raw inputs to process and apply the
prediction(s) on.
:param predict_fn: which function to process. If not provided, all of the model functions are used.
:return:
processed output: a list of processed outputs to return as the
prediction(s).
duration: a list of time deltas measuring inference time for each
prediction fn.
"""
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
def run_prediction(self, processed_input, return_duration=False):
predictions = []
durations = []
for predict_fn in self.predict:
@ -243,6 +232,27 @@ class Interface:
prediction = [prediction]
durations.append(duration)
predictions.extend(prediction)
if return_duration:
return predictions, durations
else:
return predictions
def process(self, raw_input, predict_fn=None):
"""
:param raw_input: a list of raw inputs to process and apply the
prediction(s) on.
:param predict_fn: which function to process. If not provided, all of the model functions are used.
:return:
processed output: a list of processed outputs to return as the
prediction(s).
duration: a list of time deltas measuring inference time for each
prediction fn.
"""
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
predictions, durations = self.run_prediction(processed_input, return_duration=True)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations
@ -285,33 +295,22 @@ class Interface:
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
Returns
httpd (str): HTTPServer object
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd, thread = networking.start_simple_server(
self, output_directory, self.server_name, server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
config = self.get_config_file()
networking.set_config(config)
networking.set_meta_tags(self.title, self.description, self.thumbnail)
server_port, app, thread = networking.start_server(
self, self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
self.server_port = server_port
self.status = "RUNNING"
self.simple_server = httpd
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
self.server = app
utils.version_check()
is_colab = utils.colab_check()
if not is_colab:
if not networking.url_ok(path_to_local_server):
@ -381,20 +380,7 @@ class Interface:
else:
display(IFrame(path_to_local_server, width=1000, height=500))
config = self.get_config_file()
config["share_url"] = share_url
processed_examples = []
if self.examples is not None:
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
networking.set_config(config, output_directory)
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
if debug:
while True:
@ -402,14 +388,15 @@ class Interface:
time.sleep(0.1)
launch_method = 'browser' if inbrowser else 'inline'
data = {'launch_method': launch_method,
if self.analytics_enabled:
data = {
'launch_method': launch_method,
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': ip_address
}
if self.analytics_enabled:
}
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data)
@ -420,96 +407,8 @@ class Interface:
if not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
return httpd, path_to_local_server, share_url
def tokenize_text(self, text):
leave_one_out_tokens = []
tokens = text.split()
for idx, _ in enumerate(tokens):
new_token_array = tokens.copy()
del new_token_array[idx]
leave_one_out_tokens.append(new_token_array)
return tokens, leave_one_out_tokens
def tokenize_image(self, image):
image = np.array(processing_utils.decode_base64_to_image(image))
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
leave_one_out_tokens = []
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segVal
white_screen = np.copy(image)
white_screen[segments_slic == segVal] = 255
leave_one_out_tokens.append((mask, white_screen))
return leave_one_out_tokens
def score_text(self, tokens, leave_one_out_tokens, text):
original_label = ""
original_confidence = 0
tokens = text.split()
input_text = " ".join(tokens)
original_output = self.process([input_text])
output = {result["label"] : result["confidence"]
for result in original_output[0][0]['confidences']}
original_label = original_output[0][0]["label"]
original_confidence = output[original_label]
scores = []
for idx, input_text in enumerate(leave_one_out_tokens):
input_text = " ".join(input_text)
raw_output = self.process([input_text])
output = {result["label"] : result["confidence"]
for result in raw_output[0][0]['confidences']}
scores.append(original_confidence - output[original_label])
scores_by_char = []
for idx, token in enumerate(tokens):
if idx != 0:
scores_by_char.append((" ", 0))
for char in token:
scores_by_char.append((char, scores[idx]))
return scores_by_char
def score_image(self, leave_one_out_tokens, image):
original_output = self.process([image])
output = {result["label"] : result["confidence"]
for result in original_output[0][0]['confidences']}
original_label = original_output[0][0]["label"]
original_confidence = output[original_label]
image_interface = self.input_interfaces[0]
shape = processing_utils.decode_base64_to_image(image).size
output_scores = np.full((shape[1], shape[0]), 0.0)
for mask, input_image in leave_one_out_tokens:
input_image_base64 = processing_utils.encode_array_to_base64(
input_image)
raw_output = self.process([input_image_base64])
output = {result["label"] : result["confidence"]
for result in raw_output[0][0]['confidences']}
score = original_confidence - output[original_label]
output_scores += score * mask
max_val = np.max(np.abs(output_scores))
if max_val > 0:
output_scores = output_scores / max_val
return output_scores.tolist()
def simple_explanation(self, x):
if isinstance(self.input_interfaces[0], Textbox):
tokens, leave_one_out_tokens = self.tokenize_text(x[0])
return [self.score_text(tokens, leave_one_out_tokens, x[0])]
elif isinstance(self.input_interfaces[0], Image):
leave_one_out_tokens = self.tokenize_image(x[0])
return [self.score_image(leave_one_out_tokens, x[0])]
else:
print("Not valid input type")
def explain(self, x):
if self.explain_by == "default":
return self.simple_explanation(x)
else:
preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
return self.explain_by(*preprocessed_x)
return app, path_to_local_server, share_url
def reset_all():
for io in Interface.get_instances():

View File

@ -0,0 +1,103 @@
from gradio.inputs import Image, Textbox
from gradio.outputs import Label
from gradio import processing_utils
from skimage.segmentation import slic
import numpy as np
expected_types = {
Image: "numpy",
Textbox: "str"
}
def default(separator=" ", n_segments=20):
"""
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
inputs and outputs, uses the first component.
"""
def tokenize_text(text):
leave_one_out_tokens = []
tokens = text.split(separator)
for idx, _ in enumerate(tokens):
new_token_array = tokens.copy()
del new_token_array[idx]
leave_one_out_tokens.append(new_token_array)
return leave_one_out_tokens
def tokenize_image(image):
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
leave_one_out_tokens = []
replace_color = np.mean(image, axis=(0, 1))
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segVal
white_screen = np.copy(image)
white_screen[segments_slic == segVal] = replace_color
leave_one_out_tokens.append((mask, white_screen))
return leave_one_out_tokens
def score_text(interface, leave_one_out_tokens, text):
tokens = text.split(separator)
original_output = interface.run_prediction([text])
scores_by_words = []
for idx, input_text in enumerate(leave_one_out_tokens):
perturbed_text = separator.join(input_text)
perturbed_output = interface.run_prediction([perturbed_text])
score = quantify_difference_in_label(interface, original_output, perturbed_output)
scores_by_words.append(score)
scores_by_char = []
for idx, token in enumerate(tokens):
if idx != 0:
scores_by_char.append((" ", 0))
for char in token:
scores_by_char.append((char, scores_by_words[idx]))
return scores_by_char
def score_image(interface, leave_one_out_tokens, image):
output_scores = np.zeros((image.shape[0], image.shape[1]))
original_output = interface.run_prediction([image])
for mask, perturbed_image in leave_one_out_tokens:
perturbed_output = interface.run_prediction([perturbed_image])
score = quantify_difference_in_label(interface, original_output, perturbed_output)
output_scores += score * mask
max_val, min_val = np.max(output_scores), np.min(output_scores)
if max_val > 0:
output_scores = (output_scores - min_val) / (max_val - min_val)
return output_scores.tolist()
def quantify_difference_in_label(interface, original_output, perturbed_output):
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if "confidences" in post_original_output:
original_confidence = original_output[0][original_label]
perturbed_confidence = perturbed_output[0][original_label]
score = original_confidence - perturbed_confidence
else:
try: # try computing numerical difference
score = float(original_label) - float(perturbed_label)
except ValueError: # otherwise, look at strict difference in label
score = int(not(perturbed_label == original_label))
return score
def default_interpretation(interface, x):
if isinstance(interface.input_interfaces[0], Textbox) \
and isinstance(interface.output_interfaces[0], Label):
leave_one_out_tokens = tokenize_text(x[0])
return [score_text(interface, leave_one_out_tokens, x[0])]
if isinstance(interface.input_interfaces[0], Image) \
and isinstance(interface.output_interfaces[0], Label):
leave_one_out_tokens = tokenize_image(x[0])
return [score_image(interface, leave_one_out_tokens, x[0])]
else:
print("Not valid input or output types for 'default' interpretation")
return default_interpretation

View File

@ -5,10 +5,11 @@ Defines helper methods useful for setting up ports, launching servers, and handl
import os
import socket
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
from flask import Flask, request, jsonify, abort, send_file, render_template
from multiprocessing import Process
import pkg_resources
from distutils import dir_util
from gradio import inputs, outputs
import gradio as gr
import time
import json
from gradio.tunneling import create_tunnel
@ -17,7 +18,7 @@ from shutil import copyfile
import requests
import sys
import csv
import copy
INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -29,77 +30,23 @@ GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
STATIC_PATH_TEMP = "static/"
TEMPLATE_TEMP = "index.html"
BASE_JS_FILE = "static/js/all_io.js"
CONFIG_FILE = "static/config.json"
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
app = Flask(__name__,
template_folder=STATIC_TEMPLATE_LIB,
static_folder=STATIC_PATH_LIB)
app.app_globals = {}
def build_template(temp_dir):
"""
Create HTML file with supporting JS and CSS files in a given directory.
:param temp_dir: string with path to temp directory in which the html file should be built
"""
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
def render_template_with_tags(template_path, context):
"""
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
in double curly braces) with corresponding values.
:param template_path: a string with the path to the template file
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = render_string_or_list_with_tags(old_lines, context)
with open(template_path, "w") as fout:
for line in new_lines:
fout.write(line)
def render_string_or_list_with_tags(old_lines, context):
# Handle string case
if isinstance(old_lines, str):
for key, value in context.items():
old_lines = old_lines.replace(r"{{" + key + r"}}", str(value))
return old_lines
# Handle list case
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r"{{" + key + r"}}", str(value))
new_lines.append(line)
return new_lines
def set_meta_tags(temp_dir, title, description, thumbnail):
title = "Gradio" if title is None else title
description = "Easy-to-use UI for your machine learning model" if description is None else description
thumbnail = "https://gradio.app/static/img/logo_only.png" if thumbnail is None else thumbnail
index_file = os.path.join(temp_dir, TEMPLATE_TEMP)
render_template_with_tags(index_file, {
def set_meta_tags(title, description, thumbnail):
app.app_globals.update({
"title": title,
"description": description,
"thumbnail": thumbnail
})
def set_config(config, temp_dir):
config_file = os.path.join(temp_dir, CONFIG_FILE)
with open(config_file, "w") as output:
json.dump(config, output)
def set_config(config):
app.app_globals["config"] = config
def get_first_available_port(initial, final):
@ -124,134 +71,108 @@ def get_first_available_port(initial, final):
)
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()"""
def _set_headers(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
def end_headers(self):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'GET, POST')
return super(HTTPHandler, self).end_headers()
def translate_path(self, path):
path = SimpleHTTPRequestHandler.translate_path(self, path)
relpath = os.path.relpath(path, os.getcwd())
fullpath = os.path.join(self.server.base_path, relpath)
return fullpath
def log_message(self, format, *args):
return
def do_POST(self):
# Read body of the request.
if self.path == "/api/predict/":
# Make the prediction.
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
prediction, durations = interface.process(raw_input)
output = {"data": prediction, "durations": durations}
self.wfile.write(json.dumps(output).encode())
elif self.path == "/api/flag/":
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
os.makedirs(interface.flagging_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild(
interface.flagging_dir, msg['data']['input_data'][i]) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild(
interface.flagging_dir, msg['data']['output_data'][i])
for i
in range(len(interface.output_interfaces))]}
log_fp = "{}/log.csv".format(interface.flagging_dir)
is_new = not os.path.exists(log_fp)
with open(log_fp, "a") as csvfile:
headers = ["input_{}".format(i) for i in range(len(
output["inputs"]))] + ["output_{}".format(i) for i in
range(len(output["outputs"]))]
writer = csv.DictWriter(csvfile, delimiter=',',
lineterminator='\n',
fieldnames=headers)
if is_new:
writer.writeheader()
writer.writerow(
dict(zip(headers, output["inputs"] +
output["outputs"]))
)
elif self.path == "/api/interpret/":
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
interpretation = interface.explain(msg["data"])
self.wfile.write(json.dumps(interpretation).encode())
else:
self.send_error(404, 'Path not found: {}'.format(self.path))
@app.route("/", methods=["GET"])
def main():
return render_template("index.html",
title=app.app_globals["title"],
description=app.app_globals["description"],
thumbnail=app.app_globals["thumbnail"],
)
def do_GET(self):
if self.path.startswith("/file/"):
self.send_response(200)
self.end_headers()
with open(self.path[6:], "rb") as f:
self.wfile.write(f.read())
else:
super().do_GET()
@app.route("/config/", methods=["GET"])
def config():
return jsonify(app.app_globals["config"])
class HTTPServer(BaseHTTPServer):
"""The main server, you pass in base_path which is the path you want to serve requests from"""
def __init__(self, base_path, server_address, RequestHandlerClass=HTTPHandler):
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
class QuittableHTTPThread(threading.Thread):
def __init__(self, httpd):
super().__init__(daemon=False)
self.httpd = httpd
self.keep_running =True
def run(self):
while self.keep_running:
self.httpd.handle_request()
httpd = HTTPServer(directory_to_serve, (server_name, port))
thread = QuittableHTTPThread(httpd=httpd)
thread.start()
return httpd, thread
@app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path):
if path == "None":
path = None
app.app_globals["config"]["share_url"] = path
return jsonify(success=True)
def start_simple_server(interface, directory_to_serve=None, server_name=None, server_port=None):
@app.route("/api/predict/", methods=["POST"])
def predict():
raw_input = request.json["data"]
prediction, durations = app.interface.process(raw_input)
output = {"data": prediction, "durations": durations}
return jsonify(output)
@app.route("/api/flag/", methods=["POST"])
def flag():
os.makedirs(app.interface.flagging_dir, exist_ok=True)
output = {'inputs': [app.interface.input_interfaces[
i].rebuild(
app.interface.flagging_dir, request.json['data']['input_data'][i]) for i
in range(len(app.interface.input_interfaces))],
'outputs': [app.interface.output_interfaces[
i].rebuild(
app.interface.flagging_dir, request.json['data']['output_data'][i])
for i
in range(len(app.interface.output_interfaces))]}
log_fp = "{}/log.csv".format(app.interface.flagging_dir)
is_new = not os.path.exists(log_fp)
with open(log_fp, "a") as csvfile:
headers = ["input_{}".format(i) for i in range(len(
output["inputs"]))] + ["output_{}".format(i) for i in
range(len(output["outputs"]))]
writer = csv.DictWriter(csvfile, delimiter=',',
lineterminator='\n',
fieldnames=headers)
if is_new:
writer.writeheader()
writer.writerow(
dict(zip(headers, output["inputs"] +
output["outputs"]))
)
return jsonify(success=True)
@app.route("/api/interpret/", methods=["POST"])
def interpret():
raw_input = request.json["data"]
if app.interface.interpretation == "default":
interpreter = gr.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
processed_input.append(input_interface.preprocess(x))
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
interpreter = app.interface.interpretation
interpretation = interpreter(app.interface, processed_input)
return jsonify(interpretation)
@app.route("/file/<path:path>", methods=["GET"])
def file(path):
return send_file(os.path.join(os.getcwd(), path))
def start_server(interface, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
httpd, thread = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd, thread
app.interface = interface
process = Process(target=app.run, kwargs={"port": port})
process.start()
return port, app, process
def close_server(server):
server.server_close()
def close_server(process):
process.terminate()
process.join()
def url_request(url):
try:

View File

View File

@ -55,9 +55,9 @@ class Textbox(OutputComponent):
def postprocess(self, y):
if self.type == "str" or self.type == "auto":
return y
elif self.type == "number":
return str(y)
elif self.type == "number":
return y
else:
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")
@ -68,8 +68,6 @@ class Label(OutputComponent):
Output type: Union[Dict[str, float], str, int, float]
'''
LABEL_KEY = "label"
CONFIDENCE_KEY = "confidence"
CONFIDENCES_KEY = "confidences"
def __init__(self, num_top_classes=None, type="auto", label=None):
@ -85,7 +83,7 @@ class Label(OutputComponent):
def postprocess(self, y):
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
return {self.LABEL_KEY: str(y)}
return {"label": str(y)}
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
sorted_pred = sorted(
y.items(),
@ -95,11 +93,11 @@ class Label(OutputComponent):
if self.num_top_classes is not None:
sorted_pred = sorted_pred[:self.num_top_classes]
return {
self.LABEL_KEY: sorted_pred[0][0],
self.CONFIDENCES_KEY: [
"label": sorted_pred[0][0],
"confidences": [
{
self.LABEL_KEY: pred[0],
self.CONFIDENCE_KEY: pred[1]
"label": pred[0],
"confidence": pred[1]
} for pred in sorted_pred
]
}

View File

@ -84,9 +84,6 @@ input.submit {
input.submit:hover {
background-color: #f39c12;
}
.flag {
visibility: hidden;
}
.flagged {
background-color: pink !important;
}
@ -111,9 +108,6 @@ input.submit:hover {
.invisible {
display: none !important;
}
.screenshot {
visibility: hidden;
}
.screenshot_logo {
display: none;
flex-grow: 1;

View File

@ -80,6 +80,8 @@ var io_master_template = {
$.ajax({type: "POST",
url: "/api/flag/",
data: JSON.stringify(post_data),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
});
},
interpret: function() {
@ -92,9 +94,10 @@ var io_master_template = {
$.ajax({type: "POST",
url: "/api/interpret/",
data: JSON.stringify(post_data),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
success: function(data) {
for (let [idx, interpretation] of data.entries()) {
console.log(idx)
io.input_interfaces[idx].show_interpretation(interpretation);
}
io.target.find(".loading_in_progress").hide();

View File

@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
io_master.last_output = null;
});
if (config["allow_screenshot"]) {
target.find(".screenshot").css("visibility", "visible");
}
if (config["allow_flagging"]) {
target.find(".flag").css("visibility", "visible");
}
if (config["allow_interpretation"]) {
target.find(".interpret").css("visibility", "visible");
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
} else {
if (!config["allow_screenshot"]) {
target.find(".screenshot").hide();
}
if (!config["allow_flagging"]) {
target.find(".flag").hide();
}
if (!config["allow_interpretation"]) {
target.find(".interpret").hide();
}
}
if (config["examples"]) {
target.find(".examples").removeClass("invisible");

View File

@ -117,7 +117,7 @@
<script src="/static/js/interfaces/output/file.js"></script>
<script src="/static/js/gradio.js"></script>
<script>
$.getJSON("static/config.json", function(config) {
$.getJSON("/config/", function(config) {
io = gradio_url(config, "/api/predict/", "#interface_target", "/file/");
});
const copyToClipboard = str => {

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,22 @@
import requests
import pkg_resources
from distutils.version import StrictVersion
from IPython import get_ipython
analytics_url = 'https://api.gradio.app/'
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
def version_check():
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
def error_analytics(type):
"""

View File

@ -2,7 +2,6 @@
import gradio as gr
def sentence_builder(quantity, animal, place, activity_list, morning):
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""

View File

@ -3,10 +3,11 @@ README.md
setup.py
gradio/__init__.py
gradio/component.py
gradio/explain.py
gradio/inputs.py
gradio/interface.py
gradio/interpretation.py
gradio/networking.py
gradio/notebook.py
gradio/outputs.py
gradio/processing_utils.py
gradio/strings.py
@ -119,6 +120,5 @@ gradio/static/js/vendor/webcam.min.js
gradio/static/js/vendor/white-theme.js
gradio/templates/index.html
test/test_inputs.py
test/test_interface.py
test/test_networking.py
test/test_interfaces.py
test/test_outputs.py

View File

@ -1,5 +1,6 @@
numpy
requests
flask
paramiko
scipy
IPython

View File

@ -292,6 +292,7 @@ class Image(InputComponent):
def preprocess(self, x):
im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
@ -305,7 +306,7 @@ class Image(InputComponent):
elif self.type == "numpy":
return np.array(im)
elif self.type == "file":
file_obj = tempfile.NamedTemporaryFile()
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
im.save(file_obj.name)
return file_obj
else:
@ -449,7 +450,7 @@ class Dataframe(InputComponent):
else:
return pd.DataFrame(x)
if self.col_count == 1:
x = x[0]
x = [row[0] for row in x]
if self.type == "numpy":
return np.array(x)
elif self.type == "array":

View File

@ -9,8 +9,6 @@ from gradio.inputs import InputComponent
from gradio.outputs import OutputComponent
from gradio import networking, strings, utils
import gradio.interpretation
from distutils.version import StrictVersion
import pkg_resources
import requests
import random
import time
@ -21,8 +19,6 @@ import weakref
import analytics
import os
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
try:
@ -62,6 +58,7 @@ class Interface:
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
live (bool): whether the interface should automatically reload on change.
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components.
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
@ -103,10 +100,6 @@ class Interface:
if not isinstance(fn, list):
fn = [fn]
if interpretation == "default":
self.interpretation = gradio.interpretation.default()
else:
self.interpretation = interpretation
self.output_interfaces *= len(fn)
self.predict = fn
@ -117,6 +110,7 @@ class Interface:
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
self.interpretation = interpretation
self.session = None
self.server_name = server_name
self.title = title
@ -198,6 +192,14 @@ class Interface:
iface[1]["label"] = ret_name
except ValueError:
pass
if self.examples is not None:
processed_examples = []
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
return config
def run_prediction(self, processed_input, return_duration=False):
@ -293,33 +295,22 @@ class Interface:
share (bool): whether to create a publicly shareable link from your computer for the interface.
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
Returns
httpd (str): HTTPServer object
app (flask.Flask): Flask app object
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd, thread = networking.start_simple_server(
self, output_directory, self.server_name, server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
config = self.get_config_file()
networking.set_config(config)
networking.set_meta_tags(self.title, self.description, self.thumbnail)
server_port, app, thread = networking.start_server(
self, self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
self.server_port = server_port
self.status = "RUNNING"
self.simple_server = httpd
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
self.server = app
utils.version_check()
is_colab = utils.colab_check()
if not is_colab:
if not networking.url_ok(path_to_local_server):
@ -389,20 +380,7 @@ class Interface:
else:
display(IFrame(path_to_local_server, width=1000, height=500))
config = self.get_config_file()
config["share_url"] = share_url
processed_examples = []
if self.examples is not None:
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
networking.set_config(config, output_directory)
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
if debug:
while True:
@ -410,14 +388,15 @@ class Interface:
time.sleep(0.1)
launch_method = 'browser' if inbrowser else 'inline'
data = {'launch_method': launch_method,
if self.analytics_enabled:
data = {
'launch_method': launch_method,
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': ip_address
}
if self.analytics_enabled:
}
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data)
@ -428,7 +407,8 @@ class Interface:
if not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
return httpd, path_to_local_server, share_url
return app, path_to_local_server, share_url
def reset_all():
for io in Interface.get_instances():

View File

@ -4,7 +4,12 @@ from gradio import processing_utils
from skimage.segmentation import slic
import numpy as np
def default(separator=" ", n_segments=20, replace_color=None):
expected_types = {
Image: "numpy",
Textbox: "str"
}
def default(separator=" ", n_segments=20):
"""
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
@ -22,8 +27,7 @@ def default(separator=" ", n_segments=20, replace_color=None):
def tokenize_image(image):
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
leave_one_out_tokens = []
if replace_color is None:
replace_color = np.mean(image, axis=(0, 1))
replace_color = np.mean(image, axis=(0, 1))
for (i, segVal) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segVal
white_screen = np.copy(image)
@ -68,11 +72,11 @@ def default(separator=" ", n_segments=20, replace_color=None):
def quantify_difference_in_label(interface, original_output, perturbed_output):
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
original_label = post_original_output[Label.LABEL_KEY]
perturbed_label = post_perturbed_output[Label.LABEL_KEY]
original_label = post_original_output["label"]
perturbed_label = post_perturbed_output["label"]
# Handle different return types of Label interface
if Label.CONFIDENCES_KEY in post_original_output:
if "confidences" in post_original_output:
original_confidence = original_output[0][original_label]
perturbed_confidence = perturbed_output[0][original_label]
score = original_confidence - perturbed_confidence

View File

@ -5,10 +5,11 @@ Defines helper methods useful for setting up ports, launching servers, and handl
import os
import socket
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
from flask import Flask, request, jsonify, abort, send_file, render_template
from multiprocessing import Process
import pkg_resources
from distutils import dir_util
from gradio import inputs, outputs
import gradio as gr
import time
import json
from gradio.tunneling import create_tunnel
@ -17,7 +18,7 @@ from shutil import copyfile
import requests
import sys
import csv
import copy
INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -29,77 +30,23 @@ GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
STATIC_PATH_TEMP = "static/"
TEMPLATE_TEMP = "index.html"
BASE_JS_FILE = "static/js/all_io.js"
CONFIG_FILE = "static/config.json"
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
app = Flask(__name__,
template_folder=STATIC_TEMPLATE_LIB,
static_folder=STATIC_PATH_LIB)
app.app_globals = {}
def build_template(temp_dir):
"""
Create HTML file with supporting JS and CSS files in a given directory.
:param temp_dir: string with path to temp directory in which the html file should be built
"""
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
def render_template_with_tags(template_path, context):
"""
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
in double curly braces) with corresponding values.
:param template_path: a string with the path to the template file
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = render_string_or_list_with_tags(old_lines, context)
with open(template_path, "w") as fout:
for line in new_lines:
fout.write(line)
def render_string_or_list_with_tags(old_lines, context):
# Handle string case
if isinstance(old_lines, str):
for key, value in context.items():
old_lines = old_lines.replace(r"{{" + key + r"}}", str(value))
return old_lines
# Handle list case
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r"{{" + key + r"}}", str(value))
new_lines.append(line)
return new_lines
def set_meta_tags(temp_dir, title, description, thumbnail):
title = "Gradio" if title is None else title
description = "Easy-to-use UI for your machine learning model" if description is None else description
thumbnail = "https://gradio.app/static/img/logo_only.png" if thumbnail is None else thumbnail
index_file = os.path.join(temp_dir, TEMPLATE_TEMP)
render_template_with_tags(index_file, {
def set_meta_tags(title, description, thumbnail):
app.app_globals.update({
"title": title,
"description": description,
"thumbnail": thumbnail
})
def set_config(config, temp_dir):
config_file = os.path.join(temp_dir, CONFIG_FILE)
with open(config_file, "w") as output:
json.dump(config, output)
def set_config(config):
app.app_globals["config"] = config
def get_first_available_port(initial, final):
@ -124,136 +71,108 @@ def get_first_available_port(initial, final):
)
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()"""
def _set_headers(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
def end_headers(self):
self.send_header('Access-Control-Allow-Origin', '*')
self.send_header('Access-Control-Allow-Methods', 'GET, POST')
return super(HTTPHandler, self).end_headers()
def translate_path(self, path):
path = SimpleHTTPRequestHandler.translate_path(self, path)
relpath = os.path.relpath(path, os.getcwd())
fullpath = os.path.join(self.server.base_path, relpath)
return fullpath
def log_message(self, format, *args):
return
def do_POST(self):
# Read body of the request.
if self.path == "/api/predict/":
# Make the prediction.
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
prediction, durations = interface.process(raw_input)
output = {"data": prediction, "durations": durations}
self.wfile.write(json.dumps(output).encode())
elif self.path == "/api/flag/":
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
os.makedirs(interface.flagging_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild(
interface.flagging_dir, msg['data']['input_data'][i]) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild(
interface.flagging_dir, msg['data']['output_data'][i])
for i
in range(len(interface.output_interfaces))]}
log_fp = "{}/log.csv".format(interface.flagging_dir)
is_new = not os.path.exists(log_fp)
with open(log_fp, "a") as csvfile:
headers = ["input_{}".format(i) for i in range(len(
output["inputs"]))] + ["output_{}".format(i) for i in
range(len(output["outputs"]))]
writer = csv.DictWriter(csvfile, delimiter=',',
lineterminator='\n',
fieldnames=headers)
if is_new:
writer.writeheader()
writer.writerow(
dict(zip(headers, output["inputs"] +
output["outputs"]))
)
elif self.path == "/api/interpret/":
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(interface.input_interfaces)]
interpretation = interface.interpretation(interface, processed_input)
self.wfile.write(json.dumps(interpretation).encode())
else:
self.send_error(404, 'Path not found: {}'.format(self.path))
def do_GET(self):
if self.path.startswith("/file/"):
self.send_response(200)
self.end_headers()
with open(self.path[6:], "rb") as f:
self.wfile.write(f.read())
else:
super().do_GET()
@app.route("/", methods=["GET"])
def main():
return render_template("index.html",
title=app.app_globals["title"],
description=app.app_globals["description"],
thumbnail=app.app_globals["thumbnail"],
)
class HTTPServer(BaseHTTPServer):
"""The main server, you pass in base_path which is the path you want to serve requests from"""
def __init__(self, base_path, server_address, RequestHandlerClass=HTTPHandler):
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
class QuittableHTTPThread(threading.Thread):
def __init__(self, httpd):
super().__init__(daemon=False)
self.httpd = httpd
self.keep_running =True
def run(self):
while self.keep_running:
self.httpd.handle_request()
httpd = HTTPServer(directory_to_serve, (server_name, port))
thread = QuittableHTTPThread(httpd=httpd)
thread.start()
return httpd, thread
@app.route("/config/", methods=["GET"])
def config():
return jsonify(app.app_globals["config"])
def start_simple_server(interface, directory_to_serve=None, server_name=None, server_port=None):
@app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path):
if path == "None":
path = None
app.app_globals["config"]["share_url"] = path
return jsonify(success=True)
@app.route("/api/predict/", methods=["POST"])
def predict():
raw_input = request.json["data"]
prediction, durations = app.interface.process(raw_input)
output = {"data": prediction, "durations": durations}
return jsonify(output)
@app.route("/api/flag/", methods=["POST"])
def flag():
os.makedirs(app.interface.flagging_dir, exist_ok=True)
output = {'inputs': [app.interface.input_interfaces[
i].rebuild(
app.interface.flagging_dir, request.json['data']['input_data'][i]) for i
in range(len(app.interface.input_interfaces))],
'outputs': [app.interface.output_interfaces[
i].rebuild(
app.interface.flagging_dir, request.json['data']['output_data'][i])
for i
in range(len(app.interface.output_interfaces))]}
log_fp = "{}/log.csv".format(app.interface.flagging_dir)
is_new = not os.path.exists(log_fp)
with open(log_fp, "a") as csvfile:
headers = ["input_{}".format(i) for i in range(len(
output["inputs"]))] + ["output_{}".format(i) for i in
range(len(output["outputs"]))]
writer = csv.DictWriter(csvfile, delimiter=',',
lineterminator='\n',
fieldnames=headers)
if is_new:
writer.writeheader()
writer.writerow(
dict(zip(headers, output["inputs"] +
output["outputs"]))
)
return jsonify(success=True)
@app.route("/api/interpret/", methods=["POST"])
def interpret():
raw_input = request.json["data"]
if app.interface.interpretation == "default":
interpreter = gr.interpretation.default()
processed_input = []
for i, x in enumerate(raw_input):
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
processed_input.append(input_interface.preprocess(x))
else:
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
interpreter = app.interface.interpretation
interpretation = interpreter(app.interface, processed_input)
return jsonify(interpretation)
@app.route("/file/<path:path>", methods=["GET"])
def file(path):
return send_file(os.path.join(os.getcwd(), path))
def start_server(interface, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
httpd, thread = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd, thread
app.interface = interface
process = Process(target=app.run, kwargs={"port": port})
process.start()
return port, app, process
def close_server(server):
server.server_close()
def close_server(process):
process.terminate()
process.join()
def url_request(url):
try:

0
gradio/notebook.py Normal file
View File

View File

@ -54,10 +54,10 @@ class Textbox(OutputComponent):
}
def postprocess(self, y):
if self.type == "str":
return y
elif self.type == "number" or self.type == "auto":
if self.type == "str" or self.type == "auto":
return str(y)
elif self.type == "number":
return y
else:
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")
@ -68,8 +68,6 @@ class Label(OutputComponent):
Output type: Union[Dict[str, float], str, int, float]
'''
LABEL_KEY = "label"
CONFIDENCE_KEY = "confidence"
CONFIDENCES_KEY = "confidences"
def __init__(self, num_top_classes=None, type="auto", label=None):
@ -85,7 +83,7 @@ class Label(OutputComponent):
def postprocess(self, y):
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
return {self.LABEL_KEY: str(y)}
return {"label": str(y)}
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
sorted_pred = sorted(
y.items(),
@ -95,11 +93,11 @@ class Label(OutputComponent):
if self.num_top_classes is not None:
sorted_pred = sorted_pred[:self.num_top_classes]
return {
self.LABEL_KEY: sorted_pred[0][0],
self.CONFIDENCES_KEY: [
"label": sorted_pred[0][0],
"confidences": [
{
self.LABEL_KEY: pred[0],
self.CONFIDENCE_KEY: pred[1]
"label": pred[0],
"confidence": pred[1]
} for pred in sorted_pred
]
}

View File

@ -84,9 +84,6 @@ input.submit {
input.submit:hover {
background-color: #f39c12;
}
.flag {
visibility: hidden;
}
.flagged {
background-color: pink !important;
}
@ -111,9 +108,6 @@ input.submit:hover {
.invisible {
display: none !important;
}
.screenshot {
visibility: hidden;
}
.screenshot_logo {
display: none;
flex-grow: 1;

View File

@ -80,6 +80,8 @@ var io_master_template = {
$.ajax({type: "POST",
url: "/api/flag/",
data: JSON.stringify(post_data),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
});
},
interpret: function() {
@ -92,9 +94,10 @@ var io_master_template = {
$.ajax({type: "POST",
url: "/api/interpret/",
data: JSON.stringify(post_data),
dataType: 'json',
contentType: 'application/json; charset=utf-8',
success: function(data) {
for (let [idx, interpretation] of data.entries()) {
console.log(idx)
io.input_interfaces[idx].show_interpretation(interpretation);
}
io.target.find(".loading_in_progress").hide();

View File

@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
io_master.last_output = null;
});
if (config["allow_screenshot"]) {
target.find(".screenshot").css("visibility", "visible");
}
if (config["allow_flagging"]) {
target.find(".flag").css("visibility", "visible");
}
if (config["allow_interpretation"]) {
target.find(".interpret").css("visibility", "visible");
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
} else {
if (!config["allow_screenshot"]) {
target.find(".screenshot").hide();
}
if (!config["allow_flagging"]) {
target.find(".flag").hide();
}
if (!config["allow_interpretation"]) {
target.find(".interpret").hide();
}
}
if (config["examples"]) {
target.find(".examples").removeClass("invisible");

View File

@ -117,7 +117,7 @@
<script src="/static/js/interfaces/output/file.js"></script>
<script src="/static/js/gradio.js"></script>
<script>
$.getJSON("static/config.json", function(config) {
$.getJSON("/config/", function(config) {
io = gradio_url(config, "/api/predict/", "#interface_target", "/file/");
});
const copyToClipboard = str => {

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,22 @@
import requests
import pkg_resources
from distutils.version import StrictVersion
from IPython import get_ipython
analytics_url = 'https://api.gradio.app/'
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
def version_check():
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
def error_analytics(type):
"""

View File

@ -16,6 +16,7 @@ setup(
install_requires=[
'numpy',
'requests',
'flask',
'paramiko',
'scipy',
'IPython',

File diff suppressed because one or more lines are too long

View File

@ -1,36 +0,0 @@
import unittest
import numpy as np
import gradio as gr
import gradio.inputs
import gradio.outputs
class TestInterface(unittest.TestCase):
def test_input_output_mapping(self):
io = gr.Interface(inputs='sketchpad', outputs='text', fn=lambda x: x,
analytics_enabled=False)
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Image)
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
def test_input_interface_is_instance(self):
inp = gradio.inputs.Image()
io = gr.Interface(inputs=inp, outputs='text', fn=lambda x: x,
analytics_enabled=False)
self.assertEqual(io.input_interfaces[0], inp)
def test_output_interface_is_instance(self):
out = gradio.outputs.Label()
io = gr.Interface(inputs='sketchpad', outputs=out, fn=lambda x: x,
analytics_enabled=False)
self.assertEqual(io.output_interfaces[0], out)
def test_prediction(self):
def model(x):
return 2*x
io = gr.Interface(inputs='textbox', outputs='text', fn=model,
analytics_enabled=False)
self.assertEqual(io.predict[0](11), 22)
if __name__ == '__main__':
unittest.main()

4
test/test_interfaces.py Normal file
View File

@ -0,0 +1,4 @@
import unittest
if __name__ == '__main__':
unittest.main()

View File

@ -1,63 +0,0 @@
import unittest
from gradio import networking
from gradio import inputs
from gradio import outputs
import socket
import tempfile
import os
import json
class TestGetAvailablePort(unittest.TestCase):
def test_get_first_available_port_by_blocking_port(self):
initial = 7000
final = 8000
port_found = False
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
s.close()
port_found = True
break
except OSError:
pass
if port_found:
s = socket.socket() # create a socket object
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
new_port = networking.get_first_available_port(initial, final)
s.close()
self.assertFalse(port == new_port)
# class TestSetSampleData(unittest.TestCase):
# def test_set_sample_data(self):
# test_array = ["test1", "test2", "test3"]
# temp_dir = tempfile.mkdtemp()
# inp = inputs.Sketchpad()
# out = outputs.Label()
# networking.build_template(temp_dir, inp, out)
# networking.set_sample_data_in_config_file(temp_dir, test_array)
# # We need to come up with a better way so that the config file isn't invalid json unless
# # the following parameters are set... (TODO: abidlabs)
# networking.set_always_flagged_in_config_file(temp_dir, False)
# networking.set_disabled_in_config_file(temp_dir, False)
# config_file = os.path.join(temp_dir, 'static/config.json')
# with open(config_file) as json_file:
# data = json.load(json_file)
# self.assertTrue(test_array == data["sample_inputs"])
# class TestCopyFiles(unittest.TestCase):
# def test_copy_files(self):
# filename = "a.txt"
# with tempfile.TemporaryDirectory() as temp_src:
# with open(os.path.join(temp_src, "a.txt"), "w+") as f:
# f.write('Hi')
# with tempfile.TemporaryDirectory() as temp_dest:
# self.assertFalse(os.path.exists(os.path.join(temp_dest, filename)))
# networking.copy_files(temp_src, temp_dest)
# self.assertTrue(os.path.exists(os.path.join(temp_dest, filename)))
if __name__ == '__main__':
unittest.main()

File diff suppressed because one or more lines are too long