This commit is contained in:
Ali Abid 2020-07-07 10:46:23 -07:00
parent c9edbb3eb8
commit b1b03eb953
13 changed files with 134 additions and 97 deletions

View File

@ -1 +1 @@
from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.

View File

@ -67,7 +67,7 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput):
def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
def __init__(self, shape=(28, 28), invert_colors=True,
flatten=False, scale=1/255, shift=0,
dtype='float64', sample_inputs=None, label=None):
self.image_width = shape[0]
@ -110,10 +110,10 @@ class Sketchpad(AbstractInput):
class Webcam(AbstractInput):
def __init__(self, shape=(224, 224), label=None):
self.image_width = shape[0]
self.image_height = shape[1]
self.num_channels = 3
def __init__(self, image_width=224, image_height=224, num_channels=3, label=None):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
super().__init__(label)
def get_validation_inputs(self):
@ -132,7 +132,8 @@ class Webcam(AbstractInput):
im = preprocessing_utils.decode_base64_to_image(inp)
im = im.convert('RGB')
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
return np.array(im)
array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels)
return array
class Textbox(AbstractInput):
@ -243,11 +244,16 @@ class Checkbox(AbstractInput):
class Image(AbstractInput):
def __init__(self, cast_to=None, shape=(224, 224), image_mode='RGB', label=None):
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
self.cast_to = cast_to
self.image_width = shape[0]
self.image_height = shape[1]
self.num_channels = shape[2]
self.image_mode = image_mode
self.scale = scale
self.shift = shift
self.cropper_aspect_ratio = "false" if cropper_aspect_ratio is None else cropper_aspect_ratio
super().__init__(label)
def get_validation_inputs(self):
@ -294,7 +300,14 @@ class Image(AbstractInput):
im = im.convert(self.image_mode)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
return np.array(im)
im = np.array(im).flatten()
im = im * self.scale + self.shift
if self.num_channels is None:
array = im.reshape(self.image_width, self.image_height)
else:
array = im.reshape(self.image_width, self.image_height, \
self.num_channels)
return array
def process_example(self, example):
if os.path.exists(example):

View File

@ -17,6 +17,8 @@ import random
import time
import inspect
from IPython import get_ipython
import sys
import weakref
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
@ -27,6 +29,7 @@ class Interface:
The Interface class represents a general input/output interface for a machine learning model. During construction,
the appropriate inputs and outputs
"""
instances = weakref.WeakSet()
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
@ -82,6 +85,9 @@ class Interface:
self.description = description
self.thumbnail = thumbnail
self.examples = examples
self.server_port = None
self.simple_server = None
Interface.instances.add(self)
def get_config_file(self):
config = {
@ -143,8 +149,7 @@ class Interface:
raise exception
duration = time.time() - start
if len(self.output_interfaces) / \
len(self.predict) == 1:
if len(self.output_interfaces) == len(self.predict):
prediction = [prediction]
durations.append(duration)
predictions.extend(prediction)
@ -204,7 +209,12 @@ class Interface:
return
raise RuntimeError("Validation did not pass")
def launch(self, inline=None, inbrowser=None, share=False, validate=True):
def close(self):
if self.simple_server and not(self.simple_server.fileno() == -1): # checks to see if server is running
print("Closing Gradio server on port {}...".format(self.server_port))
networking.close_server(self.simple_server)
def launch(self, inline=None, inbrowser=None, share=False, validate=True, debug=False):
"""
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
:param inline: boolean. If True, then a gradio interface is created inline (e.g. in jupyter or colab notebook)
@ -223,22 +233,13 @@ class Interface:
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
pass
# If an existing interface is running with this instance, close it.
if self.status == "RUNNING":
if self.verbose:
print("Closing existing server...")
if self.simple_server is not None:
try:
networking.close_server(self.simple_server)
except OSError:
pass
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
self.server_port = server_port
self.status = "RUNNING"
self.simple_server = httpd
@ -247,6 +248,7 @@ class Interface:
from_ipynb = get_ipython()
if "google.colab" in str(from_ipynb):
is_colab = True
print("Google colab notebook detected.")
except NameError:
pass
@ -264,11 +266,17 @@ class Interface:
if not is_colab:
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
else:
if debug:
print("This cell will run indefinitely so that you can see errors and logs. To turn off, "
"set debug=False in launch().")
else:
print("To show errors in colab notebook, set debug=True in launch()")
if share:
try:
share_url = networking.setup_tunnel(server_port)
print("External URL:", share_url)
print("Running on External URL:", share_url)
except RuntimeError:
share_url = None
if self.verbose:
@ -330,4 +338,18 @@ class Interface:
networking.set_config(config, output_directory)
if debug:
while True:
sys.stdout.flush()
time.sleep(0.1)
return httpd, path_to_local_server, share_url
@classmethod
def get_instances(cls):
return list(Interface.instances) #Returns list of all current instances
def reset_all():
for io in Interface.get_instances():
io.close()

View File

@ -44,6 +44,10 @@ class AbstractOutput(ABC):
class Label(AbstractOutput):
LABEL_KEY = "label"
CONFIDENCE_KEY = "confidence"
CONFIDENCES_KEY = "confidences"
def __init__(self, num_top_classes=None, label=None):
self.num_top_classes = num_top_classes
super().__init__(label)
@ -60,16 +64,19 @@ class Label(AbstractOutput):
if self.num_top_classes is not None:
sorted_pred = sorted_pred[:self.num_top_classes]
return {
"label": sorted_pred[0][0],
"confidences": [
self.LABEL_KEY: sorted_pred[0][0],
self.CONFIDENCES_KEY: [
{
"label": pred[0],
"confidence" : pred[1]
self.LABEL_KEY: pred[0],
self.CONFIDENCE_KEY: pred[1]
} for pred in sorted_pred
]
}
elif isinstance(prediction, int) or isinstance(prediction, float):
return {self.LABEL_KEY: str(prediction)}
else:
raise ValueError("Function output should be string or dict")
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
"float label, or a dictionary whose keys are labels and values are confidences.")
@classmethod
def get_shortcut_implementations(cls):
@ -82,6 +89,13 @@ class KeyValues(AbstractOutput):
def __init__(self, label=None):
super().__init__(label)
def postprocess(self, prediction):
if isinstance(prediction, dict):
return prediction
else:
raise ValueError("The `KeyValues` output interface expects an output that is a dictionary whose keys are "
"labels and values are corresponding values.")
@classmethod
def get_shortcut_implementations(cls):
return {
@ -111,9 +125,11 @@ class Textbox(AbstractOutput):
}
def postprocess(self, prediction):
"""
"""
return prediction
if isinstance(prediction, str) or isinstance(prediction, int) or isinstance(prediction, float):
return str(prediction)
else:
raise ValueError("The `Textbox` output interface expects an output that is one of: a string, or"
"an int/float that can be converted to a string.")
class Image(AbstractOutput):
@ -132,9 +148,16 @@ class Image(AbstractOutput):
"""
"""
if self.plot:
return preprocessing_utils.encode_plot_to_base64(prediction)
try:
return preprocessing_utils.encode_plot_to_base64(prediction)
except:
raise ValueError("The `Image` output interface expects a `matplotlib.pyplot` object"
"if plt=True.")
else:
return preprocessing_utils.encode_array_to_base64(prediction)
try:
return preprocessing_utils.encode_array_to_base64(prediction)
except:
raise ValueError("The `Image` output interface (with plt=False) expects a numpy array.")
def rebuild_flagged(self, dir, msg):
"""

View File

@ -28,7 +28,7 @@ var io_master_template = {
this.fn(this.last_input).then((output) => {
io.output(output);
}).catch((error) => {
console.error(error)
console.error(error);
this.target.find(".loading_in_progress").hide();
this.target.find(".loading_failed").show();
})
@ -38,8 +38,14 @@ var io_master_template = {
for (let i = 0; i < this.output_interfaces.length; i++) {
this.output_interfaces[i].output(data["data"][i]);
// this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i}"]`).text("Latency: " + ((data["durations"][i])).toFixed(2) + "s");
}
if (data["durations"]) {
let ratio = this.output_interfaces.length / data["durations"].length;
for (let i = 0; i < this.output_interfaces.length; i = i + ratio) {
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i + ratio - 1}"]`).text("Latency: " + ((data["durations"][i / ratio])).toFixed(2) + "s");
}
}
if (this.config.live) {
this.gather();
} else {

View File

@ -132,4 +132,16 @@ function gradio(config, fn, target) {
}
return io_master;
}
function gradio_url(config, url, target) {
return gradio(config, function(data) {
return new Promise((resolve, reject) => {
$.ajax({type: "POST",
url: url,
data: JSON.stringify({"data": data}),
success: resolve,
error: reject,
});
});
}, target);
}

View File

@ -1,6 +1,6 @@
en = {
"BETA_MESSAGE": "NOTE: Gradio is in beta stage, please report all bugs to: gradio.app@gmail.com",
"RUNNING_LOCALLY": "Model is running locally at: {}",
"RUNNING_LOCALLY": "Running locally at: {}",
"NGROK_NO_INTERNET": "Unable to create public link for interface, please check internet connection or try "
"restarting python interpreter.",
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",

View File

@ -81,16 +81,7 @@
<script src="../static/js/gradio.js"></script>
<script>
$.getJSON("static/config.json", function(config) {
io = gradio(config, function(data) {
return new Promise((resolve, reject) => {
$.ajax({type: "POST",
url: "/api/predict/",
data: JSON.stringify({"data": data}),
success: resolve,
error: reject,
});
});
}, "#interface_target");
io = gradio_url(config, "/api/predict/", "#interface_target");
if (config["title"]) {
$("#title").text(config["title"]);
}

View File

@ -137,10 +137,11 @@ class Webcam(AbstractInput):
class Textbox(AbstractInput):
def __init__(self, sample_inputs=None, lines=1, placeholder=None, label=None, numeric=False):
def __init__(self, sample_inputs=None, lines=1, placeholder=None, default=None, label=None, numeric=False):
self.sample_inputs = sample_inputs
self.lines = lines
self.placeholder = placeholder
self.default = default
self.numeric = numeric
super().__init__(label)
@ -151,6 +152,7 @@ class Textbox(AbstractInput):
return {
"lines": self.lines,
"placeholder": self.placeholder,
"default": self.default,
**super().get_template_context()
}
@ -242,16 +244,9 @@ class Checkbox(AbstractInput):
class Image(AbstractInput):
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
self.cast_to = cast_to
def __init__(self, cast_to=None, shape=(224, 224), label=None):
self.image_width = shape[0]
self.image_height = shape[1]
self.num_channels = shape[2]
self.image_mode = image_mode
self.scale = scale
self.shift = shift
self.cropper_aspect_ratio = "false" if cropper_aspect_ratio is None else cropper_aspect_ratio
super().__init__(label)
def get_validation_inputs(self):
@ -269,29 +264,10 @@ class Image(AbstractInput):
**super().get_template_context()
}
def cast_to_base64(self, inp):
return inp
def cast_to_im(self, inp):
return preprocessing_utils.decode_base64_to_image(inp)
def cast_to_numpy(self, inp):
im = self.cast_to_im(inp)
arr = np.array(im).flatten()
return arr
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
cast_to_type = {
"base64": self.cast_to_base64,
"numpy": self.cast_to_numpy,
"pillow": self.cast_to_im
}
if self.cast_to:
return cast_to_type[self.cast_to](inp)
im = preprocessing_utils.decode_base64_to_image(inp)
with warnings.catch_warnings():
warnings.simplefilter("ignore")

View File

@ -149,8 +149,7 @@ class Interface:
raise exception
duration = time.time() - start
if len(self.output_interfaces) / \
len(self.predict) == 1:
if len(self.output_interfaces) == len(self.predict):
prediction = [prediction]
durations.append(duration)
predictions.extend(prediction)

View File

@ -28,7 +28,7 @@ var io_master_template = {
this.fn(this.last_input).then((output) => {
io.output(output);
}).catch((error) => {
console.error(error)
console.error(error);
this.target.find(".loading_in_progress").hide();
this.target.find(".loading_failed").show();
})
@ -39,16 +39,8 @@ var io_master_template = {
for (let i = 0; i < this.output_interfaces.length; i++) {
this.output_interfaces[i].output(data["data"][i]);
}
let ratio;
if (data["durations"].length === 1) {
this.output_interfaces[0].target.parent().find(`.loading_time[interface="${this.output_interfaces.length - 1}"]`).text("Latency: " + ((data["durations"][0])).toFixed(2) + "s");
} else if (this.output_interfaces.length === data["durations"].length) {
for (let i = 0; i < this.output_interfaces.length; i++) {
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i}"]`).text("Latency: " + ((data["durations"][i])).toFixed(2) + "s");
}
} else {
ratio = this.output_interfaces.length / data["durations"].length;
if (data["durations"]) {
let ratio = this.output_interfaces.length / data["durations"].length;
for (let i = 0; i < this.output_interfaces.length; i = i + ratio) {
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i + ratio - 1}"]`).text("Latency: " + ((data["durations"][i / ratio])).toFixed(2) + "s");
}

View File

@ -132,4 +132,16 @@ function gradio(config, fn, target) {
}
return io_master;
}
function gradio_url(config, url, target) {
return gradio(config, function(data) {
return new Promise((resolve, reject) => {
$.ajax({type: "POST",
url: url,
data: JSON.stringify({"data": data}),
success: resolve,
error: reject,
});
});
}, target);
}

View File

@ -81,16 +81,7 @@
<script src="../static/js/gradio.js"></script>
<script>
$.getJSON("static/config.json", function(config) {
io = gradio(config, function(data) {
return new Promise((resolve, reject) => {
$.ajax({type: "POST",
url: "/api/predict/",
data: JSON.stringify({"data": data}),
success: resolve,
error: reject,
});
});
}, "#interface_target");
io = gradio_url(config, "/api/predict/", "#interface_target");
if (config["title"]) {
$("#title").text(config["title"]);
}