mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
fix default interpretation to support different types for input
This commit is contained in:
commit
102b151d42
4
.gitignore
vendored
4
.gitignore
vendored
@ -21,4 +21,6 @@ dist/*
|
||||
*.h5
|
||||
docs.json
|
||||
*.bak
|
||||
demo/tmp.zip
|
||||
demo/tmp.zip
|
||||
demo/flagged
|
||||
test.txt
|
@ -292,6 +292,7 @@ class Image(InputComponent):
|
||||
|
||||
def preprocess(self, x):
|
||||
im = processing_utils.decode_base64_to_image(x)
|
||||
fmt = im.format
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
im = im.convert(self.image_mode)
|
||||
@ -305,7 +306,7 @@ class Image(InputComponent):
|
||||
elif self.type == "numpy":
|
||||
return np.array(im)
|
||||
elif self.type == "file":
|
||||
file_obj = tempfile.NamedTemporaryFile()
|
||||
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
|
||||
im.save(file_obj.name)
|
||||
return file_obj
|
||||
else:
|
||||
@ -449,7 +450,7 @@ class Dataframe(InputComponent):
|
||||
else:
|
||||
return pd.DataFrame(x)
|
||||
if self.col_count == 1:
|
||||
x = x[0]
|
||||
x = [row[0] for row in x]
|
||||
if self.type == "numpy":
|
||||
return np.array(x)
|
||||
elif self.type == "array":
|
||||
|
@ -5,18 +5,10 @@ interface using the input and output types.
|
||||
|
||||
import tempfile
|
||||
import webbrowser
|
||||
|
||||
from gradio.inputs import InputComponent
|
||||
from gradio.inputs import Image
|
||||
from gradio.inputs import Textbox
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils, processing_utils
|
||||
from distutils.version import StrictVersion
|
||||
from skimage.segmentation import slic
|
||||
from skimage.util import img_as_float
|
||||
from gradio import processing_utils
|
||||
import PIL
|
||||
import pkg_resources
|
||||
from gradio import networking, strings, utils
|
||||
import gradio.interpretation
|
||||
import requests
|
||||
import random
|
||||
import time
|
||||
@ -26,9 +18,7 @@ import sys
|
||||
import weakref
|
||||
import analytics
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
try:
|
||||
@ -53,8 +43,9 @@ class Interface:
|
||||
|
||||
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
|
||||
live=False, show_input=True, show_output=True,
|
||||
capture_session=False, explain_by=None, title=None, description=None,
|
||||
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
capture_session=False, interpretation=None, title=None,
|
||||
description=None, thumbnail=None, server_port=None,
|
||||
server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
flagging_dir="flagged", analytics_enabled=True):
|
||||
|
||||
@ -67,6 +58,7 @@ class Interface:
|
||||
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
|
||||
live (bool): whether the interface should automatically reload on change.
|
||||
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
|
||||
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
|
||||
title (str): a title for the interface; if provided, appears above the input and output components.
|
||||
description (str): a description for the interface; if provided, appears above the input and output components.
|
||||
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
|
||||
@ -108,6 +100,7 @@ class Interface:
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
|
||||
|
||||
self.output_interfaces *= len(fn)
|
||||
self.predict = fn
|
||||
self.verbose = verbose
|
||||
@ -117,7 +110,7 @@ class Interface:
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.explain_by = explain_by
|
||||
self.interpretation = interpretation
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
self.title = title
|
||||
@ -186,7 +179,7 @@ class Interface:
|
||||
"thumbnail": self.thumbnail,
|
||||
"allow_screenshot": self.allow_screenshot,
|
||||
"allow_flagging": self.allow_flagging,
|
||||
"allow_interpretation": self.explain_by is not None
|
||||
"allow_interpretation": self.interpretation is not None
|
||||
}
|
||||
try:
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
@ -199,21 +192,17 @@ class Interface:
|
||||
iface[1]["label"] = ret_name
|
||||
except ValueError:
|
||||
pass
|
||||
if self.examples is not None:
|
||||
processed_examples = []
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
return config
|
||||
|
||||
def process(self, raw_input, predict_fn=None):
|
||||
"""
|
||||
:param raw_input: a list of raw inputs to process and apply the
|
||||
prediction(s) on.
|
||||
:param predict_fn: which function to process. If not provided, all of the model functions are used.
|
||||
:return:
|
||||
processed output: a list of processed outputs to return as the
|
||||
prediction(s).
|
||||
duration: a list of time deltas measuring inference time for each
|
||||
prediction fn.
|
||||
"""
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
def run_prediction(self, processed_input, return_duration=False):
|
||||
predictions = []
|
||||
durations = []
|
||||
for predict_fn in self.predict:
|
||||
@ -243,6 +232,27 @@ class Interface:
|
||||
prediction = [prediction]
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
|
||||
if return_duration:
|
||||
return predictions, durations
|
||||
else:
|
||||
return predictions
|
||||
|
||||
|
||||
def process(self, raw_input, predict_fn=None):
|
||||
"""
|
||||
:param raw_input: a list of raw inputs to process and apply the
|
||||
prediction(s) on.
|
||||
:param predict_fn: which function to process. If not provided, all of the model functions are used.
|
||||
:return:
|
||||
processed output: a list of processed outputs to return as the
|
||||
prediction(s).
|
||||
duration: a list of time deltas measuring inference time for each
|
||||
prediction fn.
|
||||
"""
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
predictions, durations = self.run_prediction(processed_input, return_duration=True)
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output, durations
|
||||
@ -285,33 +295,22 @@ class Interface:
|
||||
share (bool): whether to create a publicly shareable link from your computer for the interface.
|
||||
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
|
||||
Returns
|
||||
httpd (str): HTTPServer object
|
||||
app (flask.Flask): Flask app object
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
"""
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd, thread = networking.start_simple_server(
|
||||
self, output_directory, self.server_name, server_port=self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
config = self.get_config_file()
|
||||
networking.set_config(config)
|
||||
networking.set_meta_tags(self.title, self.description, self.thumbnail)
|
||||
|
||||
server_port, app, thread = networking.start_server(
|
||||
self, self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
self.server = app
|
||||
|
||||
utils.version_check()
|
||||
is_colab = utils.colab_check()
|
||||
if not is_colab:
|
||||
if not networking.url_ok(path_to_local_server):
|
||||
@ -381,20 +380,7 @@ class Interface:
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
||||
config = self.get_config_file()
|
||||
config["share_url"] = share_url
|
||||
|
||||
processed_examples = []
|
||||
if self.examples is not None:
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
|
||||
networking.set_config(config, output_directory)
|
||||
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
|
||||
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
|
||||
|
||||
if debug:
|
||||
while True:
|
||||
@ -402,14 +388,15 @@ class Interface:
|
||||
time.sleep(0.1)
|
||||
|
||||
launch_method = 'browser' if inbrowser else 'inline'
|
||||
data = {'launch_method': launch_method,
|
||||
|
||||
if self.analytics_enabled:
|
||||
data = {
|
||||
'launch_method': launch_method,
|
||||
'is_google_colab': is_colab,
|
||||
'is_sharing_on': share,
|
||||
'share_url': share_url,
|
||||
'ip_address': ip_address
|
||||
}
|
||||
|
||||
if self.analytics_enabled:
|
||||
}
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-launched-analytics/',
|
||||
data=data)
|
||||
@ -420,96 +407,8 @@ class Interface:
|
||||
if not is_in_interactive_mode:
|
||||
self.run_until_interrupted(thread, path_to_local_server)
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
def tokenize_text(self, text):
|
||||
leave_one_out_tokens = []
|
||||
tokens = text.split()
|
||||
for idx, _ in enumerate(tokens):
|
||||
new_token_array = tokens.copy()
|
||||
del new_token_array[idx]
|
||||
leave_one_out_tokens.append(new_token_array)
|
||||
return tokens, leave_one_out_tokens
|
||||
|
||||
def tokenize_image(self, image):
|
||||
image = np.array(processing_utils.decode_base64_to_image(image))
|
||||
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||
leave_one_out_tokens = []
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
white_screen[segments_slic == segVal] = 255
|
||||
leave_one_out_tokens.append((mask, white_screen))
|
||||
return leave_one_out_tokens
|
||||
|
||||
def score_text(self, tokens, leave_one_out_tokens, text):
|
||||
original_label = ""
|
||||
original_confidence = 0
|
||||
tokens = text.split()
|
||||
|
||||
input_text = " ".join(tokens)
|
||||
original_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
scores = []
|
||||
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||
input_text = " ".join(input_text)
|
||||
raw_output = self.process([input_text])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
scores.append(original_confidence - output[original_label])
|
||||
|
||||
scores_by_char = []
|
||||
for idx, token in enumerate(tokens):
|
||||
if idx != 0:
|
||||
scores_by_char.append((" ", 0))
|
||||
for char in token:
|
||||
scores_by_char.append((char, scores[idx]))
|
||||
return scores_by_char
|
||||
|
||||
def score_image(self, leave_one_out_tokens, image):
|
||||
original_output = self.process([image])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in original_output[0][0]['confidences']}
|
||||
original_label = original_output[0][0]["label"]
|
||||
original_confidence = output[original_label]
|
||||
|
||||
image_interface = self.input_interfaces[0]
|
||||
shape = processing_utils.decode_base64_to_image(image).size
|
||||
output_scores = np.full((shape[1], shape[0]), 0.0)
|
||||
|
||||
for mask, input_image in leave_one_out_tokens:
|
||||
input_image_base64 = processing_utils.encode_array_to_base64(
|
||||
input_image)
|
||||
raw_output = self.process([input_image_base64])
|
||||
output = {result["label"] : result["confidence"]
|
||||
for result in raw_output[0][0]['confidences']}
|
||||
score = original_confidence - output[original_label]
|
||||
output_scores += score * mask
|
||||
max_val = np.max(np.abs(output_scores))
|
||||
if max_val > 0:
|
||||
output_scores = output_scores / max_val
|
||||
return output_scores.tolist()
|
||||
|
||||
def simple_explanation(self, x):
|
||||
if isinstance(self.input_interfaces[0], Textbox):
|
||||
tokens, leave_one_out_tokens = self.tokenize_text(x[0])
|
||||
return [self.score_text(tokens, leave_one_out_tokens, x[0])]
|
||||
elif isinstance(self.input_interfaces[0], Image):
|
||||
leave_one_out_tokens = self.tokenize_image(x[0])
|
||||
return [self.score_image(leave_one_out_tokens, x[0])]
|
||||
else:
|
||||
print("Not valid input type")
|
||||
|
||||
def explain(self, x):
|
||||
if self.explain_by == "default":
|
||||
return self.simple_explanation(x)
|
||||
else:
|
||||
preprocessed_x = [input_interface(x_i) for x_i, input_interface in zip(x, self.input_interfaces)]
|
||||
return self.explain_by(*preprocessed_x)
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
|
103
build/lib/gradio/interpretation.py
Normal file
103
build/lib/gradio/interpretation.py
Normal file
@ -0,0 +1,103 @@
|
||||
from gradio.inputs import Image, Textbox
|
||||
from gradio.outputs import Label
|
||||
from gradio import processing_utils
|
||||
from skimage.segmentation import slic
|
||||
import numpy as np
|
||||
|
||||
expected_types = {
|
||||
Image: "numpy",
|
||||
Textbox: "str"
|
||||
}
|
||||
|
||||
def default(separator=" ", n_segments=20):
|
||||
"""
|
||||
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
|
||||
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
|
||||
inputs and outputs, uses the first component.
|
||||
"""
|
||||
def tokenize_text(text):
|
||||
leave_one_out_tokens = []
|
||||
tokens = text.split(separator)
|
||||
for idx, _ in enumerate(tokens):
|
||||
new_token_array = tokens.copy()
|
||||
del new_token_array[idx]
|
||||
leave_one_out_tokens.append(new_token_array)
|
||||
return leave_one_out_tokens
|
||||
|
||||
def tokenize_image(image):
|
||||
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||
leave_one_out_tokens = []
|
||||
replace_color = np.mean(image, axis=(0, 1))
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
white_screen[segments_slic == segVal] = replace_color
|
||||
leave_one_out_tokens.append((mask, white_screen))
|
||||
return leave_one_out_tokens
|
||||
|
||||
def score_text(interface, leave_one_out_tokens, text):
|
||||
tokens = text.split(separator)
|
||||
original_output = interface.run_prediction([text])
|
||||
|
||||
scores_by_words = []
|
||||
for idx, input_text in enumerate(leave_one_out_tokens):
|
||||
perturbed_text = separator.join(input_text)
|
||||
perturbed_output = interface.run_prediction([perturbed_text])
|
||||
score = quantify_difference_in_label(interface, original_output, perturbed_output)
|
||||
scores_by_words.append(score)
|
||||
|
||||
scores_by_char = []
|
||||
for idx, token in enumerate(tokens):
|
||||
if idx != 0:
|
||||
scores_by_char.append((" ", 0))
|
||||
for char in token:
|
||||
scores_by_char.append((char, scores_by_words[idx]))
|
||||
|
||||
return scores_by_char
|
||||
|
||||
def score_image(interface, leave_one_out_tokens, image):
|
||||
output_scores = np.zeros((image.shape[0], image.shape[1]))
|
||||
original_output = interface.run_prediction([image])
|
||||
|
||||
for mask, perturbed_image in leave_one_out_tokens:
|
||||
perturbed_output = interface.run_prediction([perturbed_image])
|
||||
score = quantify_difference_in_label(interface, original_output, perturbed_output)
|
||||
output_scores += score * mask
|
||||
|
||||
max_val, min_val = np.max(output_scores), np.min(output_scores)
|
||||
if max_val > 0:
|
||||
output_scores = (output_scores - min_val) / (max_val - min_val)
|
||||
return output_scores.tolist()
|
||||
|
||||
def quantify_difference_in_label(interface, original_output, perturbed_output):
|
||||
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
|
||||
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
|
||||
original_label = post_original_output["label"]
|
||||
perturbed_label = post_perturbed_output["label"]
|
||||
|
||||
# Handle different return types of Label interface
|
||||
if "confidences" in post_original_output:
|
||||
original_confidence = original_output[0][original_label]
|
||||
perturbed_confidence = perturbed_output[0][original_label]
|
||||
score = original_confidence - perturbed_confidence
|
||||
else:
|
||||
try: # try computing numerical difference
|
||||
score = float(original_label) - float(perturbed_label)
|
||||
except ValueError: # otherwise, look at strict difference in label
|
||||
score = int(not(perturbed_label == original_label))
|
||||
return score
|
||||
|
||||
def default_interpretation(interface, x):
|
||||
if isinstance(interface.input_interfaces[0], Textbox) \
|
||||
and isinstance(interface.output_interfaces[0], Label):
|
||||
leave_one_out_tokens = tokenize_text(x[0])
|
||||
return [score_text(interface, leave_one_out_tokens, x[0])]
|
||||
if isinstance(interface.input_interfaces[0], Image) \
|
||||
and isinstance(interface.output_interfaces[0], Label):
|
||||
leave_one_out_tokens = tokenize_image(x[0])
|
||||
return [score_image(interface, leave_one_out_tokens, x[0])]
|
||||
else:
|
||||
print("Not valid input or output types for 'default' interpretation")
|
||||
|
||||
return default_interpretation
|
||||
|
@ -5,10 +5,11 @@ Defines helper methods useful for setting up ports, launching servers, and handl
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
|
||||
from flask import Flask, request, jsonify, abort, send_file, render_template
|
||||
from multiprocessing import Process
|
||||
import pkg_resources
|
||||
from distutils import dir_util
|
||||
from gradio import inputs, outputs
|
||||
import gradio as gr
|
||||
import time
|
||||
import json
|
||||
from gradio.tunneling import create_tunnel
|
||||
@ -17,7 +18,7 @@ from shutil import copyfile
|
||||
import requests
|
||||
import sys
|
||||
import csv
|
||||
|
||||
import copy
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
@ -29,77 +30,23 @@ GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
|
||||
|
||||
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
|
||||
STATIC_PATH_TEMP = "static/"
|
||||
TEMPLATE_TEMP = "index.html"
|
||||
BASE_JS_FILE = "static/js/all_io.js"
|
||||
CONFIG_FILE = "static/config.json"
|
||||
|
||||
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
|
||||
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
|
||||
app = Flask(__name__,
|
||||
template_folder=STATIC_TEMPLATE_LIB,
|
||||
static_folder=STATIC_PATH_LIB)
|
||||
app.app_globals = {}
|
||||
|
||||
|
||||
def build_template(temp_dir):
|
||||
"""
|
||||
Create HTML file with supporting JS and CSS files in a given directory.
|
||||
:param temp_dir: string with path to temp directory in which the html file should be built
|
||||
"""
|
||||
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
|
||||
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
|
||||
temp_dir, STATIC_PATH_TEMP))
|
||||
|
||||
# Move association file to root of temporary directory.
|
||||
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
|
||||
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
|
||||
|
||||
|
||||
def render_template_with_tags(template_path, context):
|
||||
"""
|
||||
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
|
||||
in double curly braces) with corresponding values.
|
||||
:param template_path: a string with the path to the template file
|
||||
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
|
||||
"""
|
||||
with open(template_path) as fin:
|
||||
old_lines = fin.readlines()
|
||||
new_lines = render_string_or_list_with_tags(old_lines, context)
|
||||
with open(template_path, "w") as fout:
|
||||
for line in new_lines:
|
||||
fout.write(line)
|
||||
|
||||
|
||||
def render_string_or_list_with_tags(old_lines, context):
|
||||
# Handle string case
|
||||
if isinstance(old_lines, str):
|
||||
for key, value in context.items():
|
||||
old_lines = old_lines.replace(r"{{" + key + r"}}", str(value))
|
||||
return old_lines
|
||||
|
||||
# Handle list case
|
||||
new_lines = []
|
||||
for line in old_lines:
|
||||
for key, value in context.items():
|
||||
line = line.replace(r"{{" + key + r"}}", str(value))
|
||||
new_lines.append(line)
|
||||
return new_lines
|
||||
|
||||
|
||||
def set_meta_tags(temp_dir, title, description, thumbnail):
|
||||
title = "Gradio" if title is None else title
|
||||
description = "Easy-to-use UI for your machine learning model" if description is None else description
|
||||
thumbnail = "https://gradio.app/static/img/logo_only.png" if thumbnail is None else thumbnail
|
||||
|
||||
index_file = os.path.join(temp_dir, TEMPLATE_TEMP)
|
||||
render_template_with_tags(index_file, {
|
||||
def set_meta_tags(title, description, thumbnail):
|
||||
app.app_globals.update({
|
||||
"title": title,
|
||||
"description": description,
|
||||
"thumbnail": thumbnail
|
||||
})
|
||||
|
||||
|
||||
def set_config(config, temp_dir):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
with open(config_file, "w") as output:
|
||||
json.dump(config, output)
|
||||
def set_config(config):
|
||||
app.app_globals["config"] = config
|
||||
|
||||
|
||||
def get_first_available_port(initial, final):
|
||||
@ -124,134 +71,108 @@ def get_first_available_port(initial, final):
|
||||
)
|
||||
|
||||
|
||||
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
|
||||
class HTTPHandler(SimpleHTTPRequestHandler):
|
||||
"""This handler uses server.base_path instead of always using os.getcwd()"""
|
||||
def _set_headers(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
def end_headers(self):
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', 'GET, POST')
|
||||
return super(HTTPHandler, self).end_headers()
|
||||
|
||||
def translate_path(self, path):
|
||||
path = SimpleHTTPRequestHandler.translate_path(self, path)
|
||||
relpath = os.path.relpath(path, os.getcwd())
|
||||
fullpath = os.path.join(self.server.base_path, relpath)
|
||||
return fullpath
|
||||
|
||||
def log_message(self, format, *args):
|
||||
return
|
||||
|
||||
def do_POST(self):
|
||||
# Read body of the request.
|
||||
if self.path == "/api/predict/":
|
||||
# Make the prediction.
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
prediction, durations = interface.process(raw_input)
|
||||
|
||||
output = {"data": prediction, "durations": durations}
|
||||
self.wfile.write(json.dumps(output).encode())
|
||||
|
||||
elif self.path == "/api/flag/":
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
os.makedirs(interface.flagging_dir, exist_ok=True)
|
||||
output = {'inputs': [interface.input_interfaces[
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['input_data'][i]) for i
|
||||
in range(len(interface.input_interfaces))],
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['output_data'][i])
|
||||
for i
|
||||
in range(len(interface.output_interfaces))]}
|
||||
|
||||
log_fp = "{}/log.csv".format(interface.flagging_dir)
|
||||
|
||||
is_new = not os.path.exists(log_fp)
|
||||
|
||||
with open(log_fp, "a") as csvfile:
|
||||
headers = ["input_{}".format(i) for i in range(len(
|
||||
output["inputs"]))] + ["output_{}".format(i) for i in
|
||||
range(len(output["outputs"]))]
|
||||
writer = csv.DictWriter(csvfile, delimiter=',',
|
||||
lineterminator='\n',
|
||||
fieldnames=headers)
|
||||
if is_new:
|
||||
writer.writeheader()
|
||||
|
||||
writer.writerow(
|
||||
dict(zip(headers, output["inputs"] +
|
||||
output["outputs"]))
|
||||
)
|
||||
elif self.path == "/api/interpret/":
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
interpretation = interface.explain(msg["data"])
|
||||
self.wfile.write(json.dumps(interpretation).encode())
|
||||
else:
|
||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||
@app.route("/", methods=["GET"])
|
||||
def main():
|
||||
return render_template("index.html",
|
||||
title=app.app_globals["title"],
|
||||
description=app.app_globals["description"],
|
||||
thumbnail=app.app_globals["thumbnail"],
|
||||
)
|
||||
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.startswith("/file/"):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
with open(self.path[6:], "rb") as f:
|
||||
self.wfile.write(f.read())
|
||||
else:
|
||||
super().do_GET()
|
||||
@app.route("/config/", methods=["GET"])
|
||||
def config():
|
||||
return jsonify(app.app_globals["config"])
|
||||
|
||||
|
||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||
def enable_sharing(path):
|
||||
if path == "None":
|
||||
path = None
|
||||
app.app_globals["config"]["share_url"] = path
|
||||
return jsonify(success=True)
|
||||
|
||||
|
||||
@app.route("/api/predict/", methods=["POST"])
|
||||
def predict():
|
||||
raw_input = request.json["data"]
|
||||
prediction, durations = app.interface.process(raw_input)
|
||||
output = {"data": prediction, "durations": durations}
|
||||
return jsonify(output)
|
||||
|
||||
|
||||
@app.route("/api/flag/", methods=["POST"])
|
||||
def flag():
|
||||
os.makedirs(app.interface.flagging_dir, exist_ok=True)
|
||||
output = {'inputs': [app.interface.input_interfaces[
|
||||
i].rebuild(
|
||||
app.interface.flagging_dir, request.json['data']['input_data'][i]) for i
|
||||
in range(len(app.interface.input_interfaces))],
|
||||
'outputs': [app.interface.output_interfaces[
|
||||
i].rebuild(
|
||||
app.interface.flagging_dir, request.json['data']['output_data'][i])
|
||||
for i
|
||||
in range(len(app.interface.output_interfaces))]}
|
||||
|
||||
log_fp = "{}/log.csv".format(app.interface.flagging_dir)
|
||||
|
||||
is_new = not os.path.exists(log_fp)
|
||||
|
||||
with open(log_fp, "a") as csvfile:
|
||||
headers = ["input_{}".format(i) for i in range(len(
|
||||
output["inputs"]))] + ["output_{}".format(i) for i in
|
||||
range(len(output["outputs"]))]
|
||||
writer = csv.DictWriter(csvfile, delimiter=',',
|
||||
lineterminator='\n',
|
||||
fieldnames=headers)
|
||||
if is_new:
|
||||
writer.writeheader()
|
||||
|
||||
writer.writerow(
|
||||
dict(zip(headers, output["inputs"] +
|
||||
output["outputs"]))
|
||||
)
|
||||
return jsonify(success=True)
|
||||
|
||||
|
||||
@app.route("/api/interpret/", methods=["POST"])
|
||||
def interpret():
|
||||
raw_input = request.json["data"]
|
||||
if app.interface.interpretation == "default":
|
||||
interpreter = gr.interpretation.default()
|
||||
processed_input = []
|
||||
for i, x in enumerate(raw_input):
|
||||
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
||||
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
|
||||
processed_input.append(input_interface.preprocess(x))
|
||||
else:
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
interpreter = app.interface.interpretation
|
||||
interpretation = interpreter(app.interface, processed_input)
|
||||
return jsonify(interpretation)
|
||||
|
||||
|
||||
@app.route("/file/<path:path>", methods=["GET"])
|
||||
def file(path):
|
||||
return send_file(os.path.join(os.getcwd(), path))
|
||||
|
||||
|
||||
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
||||
def __init__(self, base_path, server_address, RequestHandlerClass=HTTPHandler):
|
||||
self.base_path = base_path
|
||||
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
|
||||
|
||||
class QuittableHTTPThread(threading.Thread):
|
||||
def __init__(self, httpd):
|
||||
super().__init__(daemon=False)
|
||||
self.httpd = httpd
|
||||
self.keep_running =True
|
||||
|
||||
def run(self):
|
||||
while self.keep_running:
|
||||
self.httpd.handle_request()
|
||||
|
||||
httpd = HTTPServer(directory_to_serve, (server_name, port))
|
||||
thread = QuittableHTTPThread(httpd=httpd)
|
||||
thread.start()
|
||||
|
||||
return httpd, thread
|
||||
|
||||
|
||||
def start_simple_server(interface, directory_to_serve=None, server_name=None, server_port=None):
|
||||
def start_server(interface, server_port=None):
|
||||
if server_port is None:
|
||||
server_port = INITIAL_PORT_VALUE
|
||||
port = get_first_available_port(
|
||||
server_port, server_port + TRY_NUM_PORTS
|
||||
)
|
||||
httpd, thread = serve_files_in_background(interface, port, directory_to_serve, server_name)
|
||||
return port, httpd, thread
|
||||
app.interface = interface
|
||||
process = Process(target=app.run, kwargs={"port": port})
|
||||
process.start()
|
||||
return port, app, process
|
||||
|
||||
|
||||
def close_server(server):
|
||||
server.server_close()
|
||||
|
||||
def close_server(process):
|
||||
process.terminate()
|
||||
process.join()
|
||||
|
||||
def url_request(url):
|
||||
try:
|
||||
|
0
build/lib/gradio/notebook.py
Normal file
0
build/lib/gradio/notebook.py
Normal file
@ -55,9 +55,9 @@ class Textbox(OutputComponent):
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "str" or self.type == "auto":
|
||||
return y
|
||||
elif self.type == "number":
|
||||
return str(y)
|
||||
elif self.type == "number":
|
||||
return y
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")
|
||||
|
||||
@ -68,8 +68,6 @@ class Label(OutputComponent):
|
||||
Output type: Union[Dict[str, float], str, int, float]
|
||||
'''
|
||||
|
||||
LABEL_KEY = "label"
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, type="auto", label=None):
|
||||
@ -85,7 +83,7 @@ class Label(OutputComponent):
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
return {"label": str(y)}
|
||||
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
|
||||
sorted_pred = sorted(
|
||||
y.items(),
|
||||
@ -95,11 +93,11 @@ class Label(OutputComponent):
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[:self.num_top_classes]
|
||||
return {
|
||||
self.LABEL_KEY: sorted_pred[0][0],
|
||||
self.CONFIDENCES_KEY: [
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{
|
||||
self.LABEL_KEY: pred[0],
|
||||
self.CONFIDENCE_KEY: pred[1]
|
||||
"label": pred[0],
|
||||
"confidence": pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
|
@ -84,9 +84,6 @@ input.submit {
|
||||
input.submit:hover {
|
||||
background-color: #f39c12;
|
||||
}
|
||||
.flag {
|
||||
visibility: hidden;
|
||||
}
|
||||
.flagged {
|
||||
background-color: pink !important;
|
||||
}
|
||||
@ -111,9 +108,6 @@ input.submit:hover {
|
||||
.invisible {
|
||||
display: none !important;
|
||||
}
|
||||
.screenshot {
|
||||
visibility: hidden;
|
||||
}
|
||||
.screenshot_logo {
|
||||
display: none;
|
||||
flex-grow: 1;
|
||||
|
@ -80,6 +80,8 @@ var io_master_template = {
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/flag/",
|
||||
data: JSON.stringify(post_data),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json; charset=utf-8',
|
||||
});
|
||||
},
|
||||
interpret: function() {
|
||||
@ -92,9 +94,10 @@ var io_master_template = {
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/interpret/",
|
||||
data: JSON.stringify(post_data),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json; charset=utf-8',
|
||||
success: function(data) {
|
||||
for (let [idx, interpretation] of data.entries()) {
|
||||
console.log(idx)
|
||||
io.input_interfaces[idx].show_interpretation(interpretation);
|
||||
}
|
||||
io.target.find(".loading_in_progress").hide();
|
||||
|
@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
|
||||
io_master.last_output = null;
|
||||
});
|
||||
|
||||
if (config["allow_screenshot"]) {
|
||||
target.find(".screenshot").css("visibility", "visible");
|
||||
}
|
||||
if (config["allow_flagging"]) {
|
||||
target.find(".flag").css("visibility", "visible");
|
||||
}
|
||||
if (config["allow_interpretation"]) {
|
||||
target.find(".interpret").css("visibility", "visible");
|
||||
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
|
||||
target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
|
||||
} else {
|
||||
if (!config["allow_screenshot"]) {
|
||||
target.find(".screenshot").hide();
|
||||
}
|
||||
if (!config["allow_flagging"]) {
|
||||
target.find(".flag").hide();
|
||||
}
|
||||
if (!config["allow_interpretation"]) {
|
||||
target.find(".interpret").hide();
|
||||
}
|
||||
}
|
||||
if (config["examples"]) {
|
||||
target.find(".examples").removeClass("invisible");
|
||||
|
@ -117,7 +117,7 @@
|
||||
<script src="/static/js/interfaces/output/file.js"></script>
|
||||
<script src="/static/js/gradio.js"></script>
|
||||
<script>
|
||||
$.getJSON("static/config.json", function(config) {
|
||||
$.getJSON("/config/", function(config) {
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target", "/file/");
|
||||
});
|
||||
const copyToClipboard = str => {
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,7 +1,22 @@
|
||||
import requests
|
||||
import pkg_resources
|
||||
from distutils.version import StrictVersion
|
||||
from IPython import get_ipython
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
def version_check():
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
|
||||
def error_analytics(type):
|
||||
"""
|
||||
|
@ -2,7 +2,6 @@
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def sentence_builder(quantity, animal, place, activity_list, morning):
|
||||
return f"""The {quantity} {animal}s went to the {place} where they {" and ".join(activity_list)} until the {"morning" if morning else "night"}"""
|
||||
|
||||
|
@ -3,10 +3,11 @@ README.md
|
||||
setup.py
|
||||
gradio/__init__.py
|
||||
gradio/component.py
|
||||
gradio/explain.py
|
||||
gradio/inputs.py
|
||||
gradio/interface.py
|
||||
gradio/interpretation.py
|
||||
gradio/networking.py
|
||||
gradio/notebook.py
|
||||
gradio/outputs.py
|
||||
gradio/processing_utils.py
|
||||
gradio/strings.py
|
||||
@ -119,6 +120,5 @@ gradio/static/js/vendor/webcam.min.js
|
||||
gradio/static/js/vendor/white-theme.js
|
||||
gradio/templates/index.html
|
||||
test/test_inputs.py
|
||||
test/test_interface.py
|
||||
test/test_networking.py
|
||||
test/test_interfaces.py
|
||||
test/test_outputs.py
|
@ -1,5 +1,6 @@
|
||||
numpy
|
||||
requests
|
||||
flask
|
||||
paramiko
|
||||
scipy
|
||||
IPython
|
||||
|
@ -292,6 +292,7 @@ class Image(InputComponent):
|
||||
|
||||
def preprocess(self, x):
|
||||
im = processing_utils.decode_base64_to_image(x)
|
||||
fmt = im.format
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
im = im.convert(self.image_mode)
|
||||
@ -305,7 +306,7 @@ class Image(InputComponent):
|
||||
elif self.type == "numpy":
|
||||
return np.array(im)
|
||||
elif self.type == "file":
|
||||
file_obj = tempfile.NamedTemporaryFile()
|
||||
file_obj = tempfile.NamedTemporaryFile(suffix="."+fmt)
|
||||
im.save(file_obj.name)
|
||||
return file_obj
|
||||
else:
|
||||
@ -449,7 +450,7 @@ class Dataframe(InputComponent):
|
||||
else:
|
||||
return pd.DataFrame(x)
|
||||
if self.col_count == 1:
|
||||
x = x[0]
|
||||
x = [row[0] for row in x]
|
||||
if self.type == "numpy":
|
||||
return np.array(x)
|
||||
elif self.type == "array":
|
||||
|
@ -9,8 +9,6 @@ from gradio.inputs import InputComponent
|
||||
from gradio.outputs import OutputComponent
|
||||
from gradio import networking, strings, utils
|
||||
import gradio.interpretation
|
||||
from distutils.version import StrictVersion
|
||||
import pkg_resources
|
||||
import requests
|
||||
import random
|
||||
import time
|
||||
@ -21,8 +19,6 @@ import weakref
|
||||
import analytics
|
||||
import os
|
||||
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
try:
|
||||
@ -62,6 +58,7 @@ class Interface:
|
||||
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
|
||||
live (bool): whether the interface should automatically reload on change.
|
||||
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
|
||||
interpretation (Union[Callable, str]): function that provides interpretation explaining prediction output. Pass "default" to use built-in interpreter.
|
||||
title (str): a title for the interface; if provided, appears above the input and output components.
|
||||
description (str): a description for the interface; if provided, appears above the input and output components.
|
||||
thumbnail (str): path to image or src to use as display picture for models listed in gradio.app/hub
|
||||
@ -103,10 +100,6 @@ class Interface:
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
|
||||
if interpretation == "default":
|
||||
self.interpretation = gradio.interpretation.default()
|
||||
else:
|
||||
self.interpretation = interpretation
|
||||
|
||||
self.output_interfaces *= len(fn)
|
||||
self.predict = fn
|
||||
@ -117,6 +110,7 @@ class Interface:
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.interpretation = interpretation
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
self.title = title
|
||||
@ -198,6 +192,14 @@ class Interface:
|
||||
iface[1]["label"] = ret_name
|
||||
except ValueError:
|
||||
pass
|
||||
if self.examples is not None:
|
||||
processed_examples = []
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
return config
|
||||
|
||||
def run_prediction(self, processed_input, return_duration=False):
|
||||
@ -293,33 +295,22 @@ class Interface:
|
||||
share (bool): whether to create a publicly shareable link from your computer for the interface.
|
||||
debug (bool): if True, and the interface was launched from Google Colab, prints the errors in the cell output.
|
||||
Returns
|
||||
httpd (str): HTTPServer object
|
||||
app (flask.Flask): Flask app object
|
||||
path_to_local_server (str): Locally accessible link
|
||||
share_url (str): Publicly accessible link (if share=True)
|
||||
"""
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd, thread = networking.start_simple_server(
|
||||
self, output_directory, self.server_name, server_port=self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
config = self.get_config_file()
|
||||
networking.set_config(config)
|
||||
networking.set_meta_tags(self.title, self.description, self.thumbnail)
|
||||
|
||||
server_port, app, thread = networking.start_server(
|
||||
self, self.server_port)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
self.server = app
|
||||
|
||||
utils.version_check()
|
||||
is_colab = utils.colab_check()
|
||||
if not is_colab:
|
||||
if not networking.url_ok(path_to_local_server):
|
||||
@ -389,20 +380,7 @@ class Interface:
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
||||
config = self.get_config_file()
|
||||
config["share_url"] = share_url
|
||||
|
||||
processed_examples = []
|
||||
if self.examples is not None:
|
||||
for example_set in self.examples:
|
||||
processed_set = []
|
||||
for iface, example in zip(self.input_interfaces, example_set):
|
||||
processed_set.append(iface.process_example(example))
|
||||
processed_examples.append(processed_set)
|
||||
config["examples"] = processed_examples
|
||||
|
||||
networking.set_config(config, output_directory)
|
||||
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
|
||||
r = requests.get(path_to_local_server + "enable_sharing/" + (share_url or "None"))
|
||||
|
||||
if debug:
|
||||
while True:
|
||||
@ -410,14 +388,15 @@ class Interface:
|
||||
time.sleep(0.1)
|
||||
|
||||
launch_method = 'browser' if inbrowser else 'inline'
|
||||
data = {'launch_method': launch_method,
|
||||
|
||||
if self.analytics_enabled:
|
||||
data = {
|
||||
'launch_method': launch_method,
|
||||
'is_google_colab': is_colab,
|
||||
'is_sharing_on': share,
|
||||
'share_url': share_url,
|
||||
'ip_address': ip_address
|
||||
}
|
||||
|
||||
if self.analytics_enabled:
|
||||
}
|
||||
try:
|
||||
requests.post(analytics_url + 'gradio-launched-analytics/',
|
||||
data=data)
|
||||
@ -428,7 +407,8 @@ class Interface:
|
||||
if not is_in_interactive_mode:
|
||||
self.run_until_interrupted(thread, path_to_local_server)
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
return app, path_to_local_server, share_url
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
|
@ -4,7 +4,12 @@ from gradio import processing_utils
|
||||
from skimage.segmentation import slic
|
||||
import numpy as np
|
||||
|
||||
def default(separator=" ", n_segments=20, replace_color=None):
|
||||
expected_types = {
|
||||
Image: "numpy",
|
||||
Textbox: "str"
|
||||
}
|
||||
|
||||
def default(separator=" ", n_segments=20):
|
||||
"""
|
||||
Basic "default" interpretation method that uses "leave-one-out" to explain predictions for
|
||||
the following inputs: Image, Text, and the following outputs: Label. In case of multiple
|
||||
@ -22,8 +27,7 @@ def default(separator=" ", n_segments=20, replace_color=None):
|
||||
def tokenize_image(image):
|
||||
segments_slic = slic(image, n_segments=20, compactness=10, sigma=1)
|
||||
leave_one_out_tokens = []
|
||||
if replace_color is None:
|
||||
replace_color = np.mean(image, axis=(0, 1))
|
||||
replace_color = np.mean(image, axis=(0, 1))
|
||||
for (i, segVal) in enumerate(np.unique(segments_slic)):
|
||||
mask = segments_slic == segVal
|
||||
white_screen = np.copy(image)
|
||||
@ -68,11 +72,11 @@ def default(separator=" ", n_segments=20, replace_color=None):
|
||||
def quantify_difference_in_label(interface, original_output, perturbed_output):
|
||||
post_original_output = interface.output_interfaces[0].postprocess(original_output[0])
|
||||
post_perturbed_output = interface.output_interfaces[0].postprocess(perturbed_output[0])
|
||||
original_label = post_original_output[Label.LABEL_KEY]
|
||||
perturbed_label = post_perturbed_output[Label.LABEL_KEY]
|
||||
original_label = post_original_output["label"]
|
||||
perturbed_label = post_perturbed_output["label"]
|
||||
|
||||
# Handle different return types of Label interface
|
||||
if Label.CONFIDENCES_KEY in post_original_output:
|
||||
if "confidences" in post_original_output:
|
||||
original_confidence = original_output[0][original_label]
|
||||
perturbed_confidence = perturbed_output[0][original_label]
|
||||
score = original_confidence - perturbed_confidence
|
||||
|
@ -5,10 +5,11 @@ Defines helper methods useful for setting up ports, launching servers, and handl
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
|
||||
from flask import Flask, request, jsonify, abort, send_file, render_template
|
||||
from multiprocessing import Process
|
||||
import pkg_resources
|
||||
from distutils import dir_util
|
||||
from gradio import inputs, outputs
|
||||
import gradio as gr
|
||||
import time
|
||||
import json
|
||||
from gradio.tunneling import create_tunnel
|
||||
@ -17,7 +18,7 @@ from shutil import copyfile
|
||||
import requests
|
||||
import sys
|
||||
import csv
|
||||
|
||||
import copy
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
@ -29,77 +30,23 @@ GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
|
||||
|
||||
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
|
||||
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
|
||||
STATIC_PATH_TEMP = "static/"
|
||||
TEMPLATE_TEMP = "index.html"
|
||||
BASE_JS_FILE = "static/js/all_io.js"
|
||||
CONFIG_FILE = "static/config.json"
|
||||
|
||||
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
|
||||
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
|
||||
app = Flask(__name__,
|
||||
template_folder=STATIC_TEMPLATE_LIB,
|
||||
static_folder=STATIC_PATH_LIB)
|
||||
app.app_globals = {}
|
||||
|
||||
|
||||
def build_template(temp_dir):
|
||||
"""
|
||||
Create HTML file with supporting JS and CSS files in a given directory.
|
||||
:param temp_dir: string with path to temp directory in which the html file should be built
|
||||
"""
|
||||
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
|
||||
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
|
||||
temp_dir, STATIC_PATH_TEMP))
|
||||
|
||||
# Move association file to root of temporary directory.
|
||||
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
|
||||
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
|
||||
|
||||
|
||||
def render_template_with_tags(template_path, context):
|
||||
"""
|
||||
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
|
||||
in double curly braces) with corresponding values.
|
||||
:param template_path: a string with the path to the template file
|
||||
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
|
||||
"""
|
||||
with open(template_path) as fin:
|
||||
old_lines = fin.readlines()
|
||||
new_lines = render_string_or_list_with_tags(old_lines, context)
|
||||
with open(template_path, "w") as fout:
|
||||
for line in new_lines:
|
||||
fout.write(line)
|
||||
|
||||
|
||||
def render_string_or_list_with_tags(old_lines, context):
|
||||
# Handle string case
|
||||
if isinstance(old_lines, str):
|
||||
for key, value in context.items():
|
||||
old_lines = old_lines.replace(r"{{" + key + r"}}", str(value))
|
||||
return old_lines
|
||||
|
||||
# Handle list case
|
||||
new_lines = []
|
||||
for line in old_lines:
|
||||
for key, value in context.items():
|
||||
line = line.replace(r"{{" + key + r"}}", str(value))
|
||||
new_lines.append(line)
|
||||
return new_lines
|
||||
|
||||
|
||||
def set_meta_tags(temp_dir, title, description, thumbnail):
|
||||
title = "Gradio" if title is None else title
|
||||
description = "Easy-to-use UI for your machine learning model" if description is None else description
|
||||
thumbnail = "https://gradio.app/static/img/logo_only.png" if thumbnail is None else thumbnail
|
||||
|
||||
index_file = os.path.join(temp_dir, TEMPLATE_TEMP)
|
||||
render_template_with_tags(index_file, {
|
||||
def set_meta_tags(title, description, thumbnail):
|
||||
app.app_globals.update({
|
||||
"title": title,
|
||||
"description": description,
|
||||
"thumbnail": thumbnail
|
||||
})
|
||||
|
||||
|
||||
def set_config(config, temp_dir):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
with open(config_file, "w") as output:
|
||||
json.dump(config, output)
|
||||
def set_config(config):
|
||||
app.app_globals["config"] = config
|
||||
|
||||
|
||||
def get_first_available_port(initial, final):
|
||||
@ -124,136 +71,108 @@ def get_first_available_port(initial, final):
|
||||
)
|
||||
|
||||
|
||||
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
|
||||
class HTTPHandler(SimpleHTTPRequestHandler):
|
||||
"""This handler uses server.base_path instead of always using os.getcwd()"""
|
||||
def _set_headers(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
@app.route("/", methods=["GET"])
|
||||
def main():
|
||||
return render_template("index.html",
|
||||
title=app.app_globals["title"],
|
||||
description=app.app_globals["description"],
|
||||
thumbnail=app.app_globals["thumbnail"],
|
||||
)
|
||||
|
||||
def end_headers(self):
|
||||
self.send_header('Access-Control-Allow-Origin', '*')
|
||||
self.send_header('Access-Control-Allow-Methods', 'GET, POST')
|
||||
return super(HTTPHandler, self).end_headers()
|
||||
|
||||
def translate_path(self, path):
|
||||
path = SimpleHTTPRequestHandler.translate_path(self, path)
|
||||
relpath = os.path.relpath(path, os.getcwd())
|
||||
fullpath = os.path.join(self.server.base_path, relpath)
|
||||
return fullpath
|
||||
@app.route("/config/", methods=["GET"])
|
||||
def config():
|
||||
return jsonify(app.app_globals["config"])
|
||||
|
||||
def log_message(self, format, *args):
|
||||
return
|
||||
|
||||
def do_POST(self):
|
||||
# Read body of the request.
|
||||
if self.path == "/api/predict/":
|
||||
# Make the prediction.
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
prediction, durations = interface.process(raw_input)
|
||||
output = {"data": prediction, "durations": durations}
|
||||
self.wfile.write(json.dumps(output).encode())
|
||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||
def enable_sharing(path):
|
||||
if path == "None":
|
||||
path = None
|
||||
app.app_globals["config"]["share_url"] = path
|
||||
return jsonify(success=True)
|
||||
|
||||
|
||||
elif self.path == "/api/flag/":
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
os.makedirs(interface.flagging_dir, exist_ok=True)
|
||||
output = {'inputs': [interface.input_interfaces[
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['input_data'][i]) for i
|
||||
in range(len(interface.input_interfaces))],
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild(
|
||||
interface.flagging_dir, msg['data']['output_data'][i])
|
||||
for i
|
||||
in range(len(interface.output_interfaces))]}
|
||||
@app.route("/api/predict/", methods=["POST"])
|
||||
def predict():
|
||||
raw_input = request.json["data"]
|
||||
prediction, durations = app.interface.process(raw_input)
|
||||
output = {"data": prediction, "durations": durations}
|
||||
return jsonify(output)
|
||||
|
||||
log_fp = "{}/log.csv".format(interface.flagging_dir)
|
||||
|
||||
is_new = not os.path.exists(log_fp)
|
||||
@app.route("/api/flag/", methods=["POST"])
|
||||
def flag():
|
||||
os.makedirs(app.interface.flagging_dir, exist_ok=True)
|
||||
output = {'inputs': [app.interface.input_interfaces[
|
||||
i].rebuild(
|
||||
app.interface.flagging_dir, request.json['data']['input_data'][i]) for i
|
||||
in range(len(app.interface.input_interfaces))],
|
||||
'outputs': [app.interface.output_interfaces[
|
||||
i].rebuild(
|
||||
app.interface.flagging_dir, request.json['data']['output_data'][i])
|
||||
for i
|
||||
in range(len(app.interface.output_interfaces))]}
|
||||
|
||||
with open(log_fp, "a") as csvfile:
|
||||
headers = ["input_{}".format(i) for i in range(len(
|
||||
output["inputs"]))] + ["output_{}".format(i) for i in
|
||||
range(len(output["outputs"]))]
|
||||
writer = csv.DictWriter(csvfile, delimiter=',',
|
||||
lineterminator='\n',
|
||||
fieldnames=headers)
|
||||
if is_new:
|
||||
writer.writeheader()
|
||||
log_fp = "{}/log.csv".format(app.interface.flagging_dir)
|
||||
|
||||
writer.writerow(
|
||||
dict(zip(headers, output["inputs"] +
|
||||
output["outputs"]))
|
||||
)
|
||||
elif self.path == "/api/interpret/":
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
|
||||
interpretation = interface.interpretation(interface, processed_input)
|
||||
self.wfile.write(json.dumps(interpretation).encode())
|
||||
else:
|
||||
self.send_error(404, 'Path not found: {}'.format(self.path))
|
||||
is_new = not os.path.exists(log_fp)
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.startswith("/file/"):
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
with open(self.path[6:], "rb") as f:
|
||||
self.wfile.write(f.read())
|
||||
else:
|
||||
super().do_GET()
|
||||
with open(log_fp, "a") as csvfile:
|
||||
headers = ["input_{}".format(i) for i in range(len(
|
||||
output["inputs"]))] + ["output_{}".format(i) for i in
|
||||
range(len(output["outputs"]))]
|
||||
writer = csv.DictWriter(csvfile, delimiter=',',
|
||||
lineterminator='\n',
|
||||
fieldnames=headers)
|
||||
if is_new:
|
||||
writer.writeheader()
|
||||
|
||||
writer.writerow(
|
||||
dict(zip(headers, output["inputs"] +
|
||||
output["outputs"]))
|
||||
)
|
||||
return jsonify(success=True)
|
||||
|
||||
|
||||
@app.route("/api/interpret/", methods=["POST"])
|
||||
def interpret():
|
||||
raw_input = request.json["data"]
|
||||
if app.interface.interpretation == "default":
|
||||
interpreter = gr.interpretation.default()
|
||||
processed_input = []
|
||||
for i, x in enumerate(raw_input):
|
||||
input_interface = copy.deepcopy(app.interface.input_interfaces[i])
|
||||
input_interface.type = gr.interpretation.expected_types[type(input_interface)]
|
||||
processed_input.append(input_interface.preprocess(x))
|
||||
else:
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
interpreter = app.interface.interpretation
|
||||
interpretation = interpreter(app.interface, processed_input)
|
||||
return jsonify(interpretation)
|
||||
|
||||
|
||||
@app.route("/file/<path:path>", methods=["GET"])
|
||||
def file(path):
|
||||
return send_file(os.path.join(os.getcwd(), path))
|
||||
|
||||
|
||||
class HTTPServer(BaseHTTPServer):
|
||||
"""The main server, you pass in base_path which is the path you want to serve requests from"""
|
||||
|
||||
def __init__(self, base_path, server_address, RequestHandlerClass=HTTPHandler):
|
||||
self.base_path = base_path
|
||||
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
|
||||
|
||||
class QuittableHTTPThread(threading.Thread):
|
||||
def __init__(self, httpd):
|
||||
super().__init__(daemon=False)
|
||||
self.httpd = httpd
|
||||
self.keep_running =True
|
||||
|
||||
def run(self):
|
||||
while self.keep_running:
|
||||
self.httpd.handle_request()
|
||||
|
||||
httpd = HTTPServer(directory_to_serve, (server_name, port))
|
||||
thread = QuittableHTTPThread(httpd=httpd)
|
||||
thread.start()
|
||||
|
||||
return httpd, thread
|
||||
|
||||
|
||||
def start_simple_server(interface, directory_to_serve=None, server_name=None, server_port=None):
|
||||
def start_server(interface, server_port=None):
|
||||
if server_port is None:
|
||||
server_port = INITIAL_PORT_VALUE
|
||||
port = get_first_available_port(
|
||||
server_port, server_port + TRY_NUM_PORTS
|
||||
)
|
||||
httpd, thread = serve_files_in_background(interface, port, directory_to_serve, server_name)
|
||||
return port, httpd, thread
|
||||
app.interface = interface
|
||||
process = Process(target=app.run, kwargs={"port": port})
|
||||
process.start()
|
||||
return port, app, process
|
||||
|
||||
|
||||
def close_server(server):
|
||||
server.server_close()
|
||||
|
||||
def close_server(process):
|
||||
process.terminate()
|
||||
process.join()
|
||||
|
||||
def url_request(url):
|
||||
try:
|
||||
|
0
gradio/notebook.py
Normal file
0
gradio/notebook.py
Normal file
@ -54,10 +54,10 @@ class Textbox(OutputComponent):
|
||||
}
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "str":
|
||||
return y
|
||||
elif self.type == "number" or self.type == "auto":
|
||||
if self.type == "str" or self.type == "auto":
|
||||
return str(y)
|
||||
elif self.type == "number":
|
||||
return y
|
||||
else:
|
||||
raise ValueError("Unknown type: " + self.type + ". Please choose from: 'str', 'number'")
|
||||
|
||||
@ -68,8 +68,6 @@ class Label(OutputComponent):
|
||||
Output type: Union[Dict[str, float], str, int, float]
|
||||
'''
|
||||
|
||||
LABEL_KEY = "label"
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, type="auto", label=None):
|
||||
@ -85,7 +83,7 @@ class Label(OutputComponent):
|
||||
|
||||
def postprocess(self, y):
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or isinstance(y, Number))):
|
||||
return {self.LABEL_KEY: str(y)}
|
||||
return {"label": str(y)}
|
||||
elif self.type == "confidences" or (self.type == "auto" and isinstance(y, dict)):
|
||||
sorted_pred = sorted(
|
||||
y.items(),
|
||||
@ -95,11 +93,11 @@ class Label(OutputComponent):
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[:self.num_top_classes]
|
||||
return {
|
||||
self.LABEL_KEY: sorted_pred[0][0],
|
||||
self.CONFIDENCES_KEY: [
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
{
|
||||
self.LABEL_KEY: pred[0],
|
||||
self.CONFIDENCE_KEY: pred[1]
|
||||
"label": pred[0],
|
||||
"confidence": pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
|
@ -84,9 +84,6 @@ input.submit {
|
||||
input.submit:hover {
|
||||
background-color: #f39c12;
|
||||
}
|
||||
.flag {
|
||||
visibility: hidden;
|
||||
}
|
||||
.flagged {
|
||||
background-color: pink !important;
|
||||
}
|
||||
@ -111,9 +108,6 @@ input.submit:hover {
|
||||
.invisible {
|
||||
display: none !important;
|
||||
}
|
||||
.screenshot {
|
||||
visibility: hidden;
|
||||
}
|
||||
.screenshot_logo {
|
||||
display: none;
|
||||
flex-grow: 1;
|
||||
|
@ -80,6 +80,8 @@ var io_master_template = {
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/flag/",
|
||||
data: JSON.stringify(post_data),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json; charset=utf-8',
|
||||
});
|
||||
},
|
||||
interpret: function() {
|
||||
@ -92,9 +94,10 @@ var io_master_template = {
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/interpret/",
|
||||
data: JSON.stringify(post_data),
|
||||
dataType: 'json',
|
||||
contentType: 'application/json; charset=utf-8',
|
||||
success: function(data) {
|
||||
for (let [idx, interpretation] of data.entries()) {
|
||||
console.log(idx)
|
||||
io.input_interfaces[idx].show_interpretation(interpretation);
|
||||
}
|
||||
io.target.find(".loading_in_progress").hide();
|
||||
|
@ -166,14 +166,18 @@ function gradio(config, fn, target, example_file_path) {
|
||||
io_master.last_output = null;
|
||||
});
|
||||
|
||||
if (config["allow_screenshot"]) {
|
||||
target.find(".screenshot").css("visibility", "visible");
|
||||
}
|
||||
if (config["allow_flagging"]) {
|
||||
target.find(".flag").css("visibility", "visible");
|
||||
}
|
||||
if (config["allow_interpretation"]) {
|
||||
target.find(".interpret").css("visibility", "visible");
|
||||
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
|
||||
target.find(".screenshot, .flag, .interpret").css("visibility", "hidden");
|
||||
} else {
|
||||
if (!config["allow_screenshot"]) {
|
||||
target.find(".screenshot").hide();
|
||||
}
|
||||
if (!config["allow_flagging"]) {
|
||||
target.find(".flag").hide();
|
||||
}
|
||||
if (!config["allow_interpretation"]) {
|
||||
target.find(".interpret").hide();
|
||||
}
|
||||
}
|
||||
if (config["examples"]) {
|
||||
target.find(".examples").removeClass("invisible");
|
||||
|
@ -117,7 +117,7 @@
|
||||
<script src="/static/js/interfaces/output/file.js"></script>
|
||||
<script src="/static/js/gradio.js"></script>
|
||||
<script>
|
||||
$.getJSON("static/config.json", function(config) {
|
||||
$.getJSON("/config/", function(config) {
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target", "/file/");
|
||||
});
|
||||
const copyToClipboard = str => {
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,7 +1,22 @@
|
||||
import requests
|
||||
import pkg_resources
|
||||
from distutils.version import StrictVersion
|
||||
from IPython import get_ipython
|
||||
analytics_url = 'https://api.gradio.app/'
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
def version_check():
|
||||
try:
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
|
||||
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
|
||||
print("IMPORTANT: You are using gradio version {}, "
|
||||
"however version {} "
|
||||
"is available, please upgrade.".format(
|
||||
current_pkg_version, latest_pkg_version))
|
||||
print('--------')
|
||||
except: # TODO(abidlabs): don't catch all exceptions
|
||||
pass
|
||||
|
||||
def error_analytics(type):
|
||||
"""
|
||||
|
1
setup.py
1
setup.py
@ -16,6 +16,7 @@ setup(
|
||||
install_requires=[
|
||||
'numpy',
|
||||
'requests',
|
||||
'flask',
|
||||
'paramiko',
|
||||
'scipy',
|
||||
'IPython',
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,36 +0,0 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
import gradio as gr
|
||||
import gradio.inputs
|
||||
import gradio.outputs
|
||||
|
||||
|
||||
class TestInterface(unittest.TestCase):
|
||||
def test_input_output_mapping(self):
|
||||
io = gr.Interface(inputs='sketchpad', outputs='text', fn=lambda x: x,
|
||||
analytics_enabled=False)
|
||||
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Image)
|
||||
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
|
||||
|
||||
def test_input_interface_is_instance(self):
|
||||
inp = gradio.inputs.Image()
|
||||
io = gr.Interface(inputs=inp, outputs='text', fn=lambda x: x,
|
||||
analytics_enabled=False)
|
||||
self.assertEqual(io.input_interfaces[0], inp)
|
||||
|
||||
def test_output_interface_is_instance(self):
|
||||
out = gradio.outputs.Label()
|
||||
io = gr.Interface(inputs='sketchpad', outputs=out, fn=lambda x: x,
|
||||
analytics_enabled=False)
|
||||
self.assertEqual(io.output_interfaces[0], out)
|
||||
|
||||
def test_prediction(self):
|
||||
def model(x):
|
||||
return 2*x
|
||||
io = gr.Interface(inputs='textbox', outputs='text', fn=model,
|
||||
analytics_enabled=False)
|
||||
self.assertEqual(io.predict[0](11), 22)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
4
test/test_interfaces.py
Normal file
4
test/test_interfaces.py
Normal file
@ -0,0 +1,4 @@
|
||||
import unittest
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -1,63 +0,0 @@
|
||||
import unittest
|
||||
from gradio import networking
|
||||
from gradio import inputs
|
||||
from gradio import outputs
|
||||
import socket
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
class TestGetAvailablePort(unittest.TestCase):
|
||||
def test_get_first_available_port_by_blocking_port(self):
|
||||
initial = 7000
|
||||
final = 8000
|
||||
port_found = False
|
||||
for port in range(initial, final):
|
||||
try:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.close()
|
||||
port_found = True
|
||||
break
|
||||
except OSError:
|
||||
pass
|
||||
if port_found:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
|
||||
new_port = networking.get_first_available_port(initial, final)
|
||||
s.close()
|
||||
self.assertFalse(port == new_port)
|
||||
|
||||
|
||||
# class TestSetSampleData(unittest.TestCase):
|
||||
# def test_set_sample_data(self):
|
||||
# test_array = ["test1", "test2", "test3"]
|
||||
# temp_dir = tempfile.mkdtemp()
|
||||
# inp = inputs.Sketchpad()
|
||||
# out = outputs.Label()
|
||||
# networking.build_template(temp_dir, inp, out)
|
||||
# networking.set_sample_data_in_config_file(temp_dir, test_array)
|
||||
# # We need to come up with a better way so that the config file isn't invalid json unless
|
||||
# # the following parameters are set... (TODO: abidlabs)
|
||||
# networking.set_always_flagged_in_config_file(temp_dir, False)
|
||||
# networking.set_disabled_in_config_file(temp_dir, False)
|
||||
# config_file = os.path.join(temp_dir, 'static/config.json')
|
||||
# with open(config_file) as json_file:
|
||||
# data = json.load(json_file)
|
||||
# self.assertTrue(test_array == data["sample_inputs"])
|
||||
|
||||
# class TestCopyFiles(unittest.TestCase):
|
||||
# def test_copy_files(self):
|
||||
# filename = "a.txt"
|
||||
# with tempfile.TemporaryDirectory() as temp_src:
|
||||
# with open(os.path.join(temp_src, "a.txt"), "w+") as f:
|
||||
# f.write('Hi')
|
||||
# with tempfile.TemporaryDirectory() as temp_dest:
|
||||
# self.assertFalse(os.path.exists(os.path.join(temp_dest, filename)))
|
||||
# networking.copy_files(temp_src, temp_dest)
|
||||
# self.assertTrue(os.path.exists(os.path.join(temp_dest, filename)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user