mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
changes
This commit is contained in:
parent
c9edbb3eb8
commit
b1b03eb953
@ -1 +1 @@
|
||||
from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
|
@ -67,7 +67,7 @@ class AbstractInput(ABC):
|
||||
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
|
||||
def __init__(self, shape=(28, 28), invert_colors=True,
|
||||
flatten=False, scale=1/255, shift=0,
|
||||
dtype='float64', sample_inputs=None, label=None):
|
||||
self.image_width = shape[0]
|
||||
@ -110,10 +110,10 @@ class Sketchpad(AbstractInput):
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
def __init__(self, shape=(224, 224), label=None):
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
self.num_channels = 3
|
||||
def __init__(self, image_width=224, image_height=224, num_channels=3, label=None):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
super().__init__(label)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
@ -132,7 +132,8 @@ class Webcam(AbstractInput):
|
||||
im = preprocessing_utils.decode_base64_to_image(inp)
|
||||
im = im.convert('RGB')
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
return np.array(im)
|
||||
array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
@ -243,11 +244,16 @@ class Checkbox(AbstractInput):
|
||||
|
||||
|
||||
class Image(AbstractInput):
|
||||
def __init__(self, cast_to=None, shape=(224, 224), image_mode='RGB', label=None):
|
||||
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
|
||||
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
|
||||
self.cast_to = cast_to
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
self.num_channels = shape[2]
|
||||
self.image_mode = image_mode
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
self.cropper_aspect_ratio = "false" if cropper_aspect_ratio is None else cropper_aspect_ratio
|
||||
super().__init__(label)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
@ -294,7 +300,14 @@ class Image(AbstractInput):
|
||||
im = im.convert(self.image_mode)
|
||||
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
return np.array(im)
|
||||
im = np.array(im).flatten()
|
||||
im = im * self.scale + self.shift
|
||||
if self.num_channels is None:
|
||||
array = im.reshape(self.image_width, self.image_height)
|
||||
else:
|
||||
array = im.reshape(self.image_width, self.image_height, \
|
||||
self.num_channels)
|
||||
return array
|
||||
|
||||
def process_example(self, example):
|
||||
if os.path.exists(example):
|
||||
|
@ -17,6 +17,8 @@ import random
|
||||
import time
|
||||
import inspect
|
||||
from IPython import get_ipython
|
||||
import sys
|
||||
import weakref
|
||||
|
||||
|
||||
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
|
||||
@ -27,6 +29,7 @@ class Interface:
|
||||
The Interface class represents a general input/output interface for a machine learning model. During construction,
|
||||
the appropriate inputs and outputs
|
||||
"""
|
||||
instances = weakref.WeakSet()
|
||||
|
||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
|
||||
live=False, show_input=True, show_output=True,
|
||||
@ -82,6 +85,9 @@ class Interface:
|
||||
self.description = description
|
||||
self.thumbnail = thumbnail
|
||||
self.examples = examples
|
||||
self.server_port = None
|
||||
self.simple_server = None
|
||||
Interface.instances.add(self)
|
||||
|
||||
def get_config_file(self):
|
||||
config = {
|
||||
@ -143,8 +149,7 @@ class Interface:
|
||||
raise exception
|
||||
duration = time.time() - start
|
||||
|
||||
if len(self.output_interfaces) / \
|
||||
len(self.predict) == 1:
|
||||
if len(self.output_interfaces) == len(self.predict):
|
||||
prediction = [prediction]
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
@ -204,7 +209,12 @@ class Interface:
|
||||
return
|
||||
raise RuntimeError("Validation did not pass")
|
||||
|
||||
def launch(self, inline=None, inbrowser=None, share=False, validate=True):
|
||||
def close(self):
|
||||
if self.simple_server and not(self.simple_server.fileno() == -1): # checks to see if server is running
|
||||
print("Closing Gradio server on port {}...".format(self.server_port))
|
||||
networking.close_server(self.simple_server)
|
||||
|
||||
def launch(self, inline=None, inbrowser=None, share=False, validate=True, debug=False):
|
||||
"""
|
||||
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
|
||||
:param inline: boolean. If True, then a gradio interface is created inline (e.g. in jupyter or colab notebook)
|
||||
@ -223,22 +233,13 @@ class Interface:
|
||||
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
|
||||
pass
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == "RUNNING":
|
||||
if self.verbose:
|
||||
print("Closing existing server...")
|
||||
if self.simple_server is not None:
|
||||
try:
|
||||
networking.close_server(self.simple_server)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
output_directory = tempfile.mkdtemp()
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
||||
@ -247,6 +248,7 @@ class Interface:
|
||||
from_ipynb = get_ipython()
|
||||
if "google.colab" in str(from_ipynb):
|
||||
is_colab = True
|
||||
print("Google colab notebook detected.")
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
@ -264,11 +266,17 @@ class Interface:
|
||||
|
||||
if not is_colab:
|
||||
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
|
||||
else:
|
||||
if debug:
|
||||
print("This cell will run indefinitely so that you can see errors and logs. To turn off, "
|
||||
"set debug=False in launch().")
|
||||
else:
|
||||
print("To show errors in colab notebook, set debug=True in launch()")
|
||||
|
||||
if share:
|
||||
try:
|
||||
share_url = networking.setup_tunnel(server_port)
|
||||
print("External URL:", share_url)
|
||||
print("Running on External URL:", share_url)
|
||||
except RuntimeError:
|
||||
share_url = None
|
||||
if self.verbose:
|
||||
@ -330,4 +338,18 @@ class Interface:
|
||||
|
||||
networking.set_config(config, output_directory)
|
||||
|
||||
if debug:
|
||||
while True:
|
||||
sys.stdout.flush()
|
||||
time.sleep(0.1)
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
||||
@classmethod
|
||||
def get_instances(cls):
|
||||
return list(Interface.instances) #Returns list of all current instances
|
||||
|
||||
|
||||
def reset_all():
|
||||
for io in Interface.get_instances():
|
||||
io.close()
|
||||
|
@ -44,6 +44,10 @@ class AbstractOutput(ABC):
|
||||
|
||||
|
||||
class Label(AbstractOutput):
|
||||
LABEL_KEY = "label"
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, label=None):
|
||||
self.num_top_classes = num_top_classes
|
||||
super().__init__(label)
|
||||
@ -60,16 +64,19 @@ class Label(AbstractOutput):
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[:self.num_top_classes]
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
self.LABEL_KEY: sorted_pred[0][0],
|
||||
self.CONFIDENCES_KEY: [
|
||||
{
|
||||
"label": pred[0],
|
||||
"confidence" : pred[1]
|
||||
self.LABEL_KEY: pred[0],
|
||||
self.CONFIDENCE_KEY: pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
elif isinstance(prediction, int) or isinstance(prediction, float):
|
||||
return {self.LABEL_KEY: str(prediction)}
|
||||
else:
|
||||
raise ValueError("Function output should be string or dict")
|
||||
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences.")
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
@ -82,6 +89,13 @@ class KeyValues(AbstractOutput):
|
||||
def __init__(self, label=None):
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, prediction):
|
||||
if isinstance(prediction, dict):
|
||||
return prediction
|
||||
else:
|
||||
raise ValueError("The `KeyValues` output interface expects an output that is a dictionary whose keys are "
|
||||
"labels and values are corresponding values.")
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
@ -111,9 +125,11 @@ class Textbox(AbstractOutput):
|
||||
}
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
return prediction
|
||||
if isinstance(prediction, str) or isinstance(prediction, int) or isinstance(prediction, float):
|
||||
return str(prediction)
|
||||
else:
|
||||
raise ValueError("The `Textbox` output interface expects an output that is one of: a string, or"
|
||||
"an int/float that can be converted to a string.")
|
||||
|
||||
|
||||
class Image(AbstractOutput):
|
||||
@ -132,9 +148,16 @@ class Image(AbstractOutput):
|
||||
"""
|
||||
"""
|
||||
if self.plot:
|
||||
return preprocessing_utils.encode_plot_to_base64(prediction)
|
||||
try:
|
||||
return preprocessing_utils.encode_plot_to_base64(prediction)
|
||||
except:
|
||||
raise ValueError("The `Image` output interface expects a `matplotlib.pyplot` object"
|
||||
"if plt=True.")
|
||||
else:
|
||||
return preprocessing_utils.encode_array_to_base64(prediction)
|
||||
try:
|
||||
return preprocessing_utils.encode_array_to_base64(prediction)
|
||||
except:
|
||||
raise ValueError("The `Image` output interface (with plt=False) expects a numpy array.")
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
|
@ -28,7 +28,7 @@ var io_master_template = {
|
||||
this.fn(this.last_input).then((output) => {
|
||||
io.output(output);
|
||||
}).catch((error) => {
|
||||
console.error(error)
|
||||
console.error(error);
|
||||
this.target.find(".loading_in_progress").hide();
|
||||
this.target.find(".loading_failed").show();
|
||||
})
|
||||
@ -38,8 +38,14 @@ var io_master_template = {
|
||||
|
||||
for (let i = 0; i < this.output_interfaces.length; i++) {
|
||||
this.output_interfaces[i].output(data["data"][i]);
|
||||
// this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i}"]`).text("Latency: " + ((data["durations"][i])).toFixed(2) + "s");
|
||||
}
|
||||
if (data["durations"]) {
|
||||
let ratio = this.output_interfaces.length / data["durations"].length;
|
||||
for (let i = 0; i < this.output_interfaces.length; i = i + ratio) {
|
||||
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i + ratio - 1}"]`).text("Latency: " + ((data["durations"][i / ratio])).toFixed(2) + "s");
|
||||
}
|
||||
}
|
||||
|
||||
if (this.config.live) {
|
||||
this.gather();
|
||||
} else {
|
||||
|
@ -132,4 +132,16 @@ function gradio(config, fn, target) {
|
||||
}
|
||||
|
||||
return io_master;
|
||||
}
|
||||
function gradio_url(config, url, target) {
|
||||
return gradio(config, function(data) {
|
||||
return new Promise((resolve, reject) => {
|
||||
$.ajax({type: "POST",
|
||||
url: url,
|
||||
data: JSON.stringify({"data": data}),
|
||||
success: resolve,
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, target);
|
||||
}
|
@ -1,6 +1,6 @@
|
||||
en = {
|
||||
"BETA_MESSAGE": "NOTE: Gradio is in beta stage, please report all bugs to: gradio.app@gmail.com",
|
||||
"RUNNING_LOCALLY": "Model is running locally at: {}",
|
||||
"RUNNING_LOCALLY": "Running locally at: {}",
|
||||
"NGROK_NO_INTERNET": "Unable to create public link for interface, please check internet connection or try "
|
||||
"restarting python interpreter.",
|
||||
"COLAB_NO_LOCAL": "Cannot display local interface on google colab, public link created.",
|
||||
|
@ -81,16 +81,7 @@
|
||||
<script src="../static/js/gradio.js"></script>
|
||||
<script>
|
||||
$.getJSON("static/config.json", function(config) {
|
||||
io = gradio(config, function(data) {
|
||||
return new Promise((resolve, reject) => {
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/predict/",
|
||||
data: JSON.stringify({"data": data}),
|
||||
success: resolve,
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, "#interface_target");
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target");
|
||||
if (config["title"]) {
|
||||
$("#title").text(config["title"]);
|
||||
}
|
||||
|
@ -137,10 +137,11 @@ class Webcam(AbstractInput):
|
||||
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
def __init__(self, sample_inputs=None, lines=1, placeholder=None, label=None, numeric=False):
|
||||
def __init__(self, sample_inputs=None, lines=1, placeholder=None, default=None, label=None, numeric=False):
|
||||
self.sample_inputs = sample_inputs
|
||||
self.lines = lines
|
||||
self.placeholder = placeholder
|
||||
self.default = default
|
||||
self.numeric = numeric
|
||||
super().__init__(label)
|
||||
|
||||
@ -151,6 +152,7 @@ class Textbox(AbstractInput):
|
||||
return {
|
||||
"lines": self.lines,
|
||||
"placeholder": self.placeholder,
|
||||
"default": self.default,
|
||||
**super().get_template_context()
|
||||
}
|
||||
|
||||
@ -242,16 +244,9 @@ class Checkbox(AbstractInput):
|
||||
|
||||
|
||||
class Image(AbstractInput):
|
||||
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
|
||||
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
|
||||
self.cast_to = cast_to
|
||||
def __init__(self, cast_to=None, shape=(224, 224), label=None):
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
self.num_channels = shape[2]
|
||||
self.image_mode = image_mode
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
self.cropper_aspect_ratio = "false" if cropper_aspect_ratio is None else cropper_aspect_ratio
|
||||
super().__init__(label)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
@ -269,29 +264,10 @@ class Image(AbstractInput):
|
||||
**super().get_template_context()
|
||||
}
|
||||
|
||||
def cast_to_base64(self, inp):
|
||||
return inp
|
||||
|
||||
def cast_to_im(self, inp):
|
||||
return preprocessing_utils.decode_base64_to_image(inp)
|
||||
|
||||
def cast_to_numpy(self, inp):
|
||||
im = self.cast_to_im(inp)
|
||||
arr = np.array(im).flatten()
|
||||
return arr
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
cast_to_type = {
|
||||
"base64": self.cast_to_base64,
|
||||
"numpy": self.cast_to_numpy,
|
||||
"pillow": self.cast_to_im
|
||||
}
|
||||
if self.cast_to:
|
||||
return cast_to_type[self.cast_to](inp)
|
||||
|
||||
im = preprocessing_utils.decode_base64_to_image(inp)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
@ -149,8 +149,7 @@ class Interface:
|
||||
raise exception
|
||||
duration = time.time() - start
|
||||
|
||||
if len(self.output_interfaces) / \
|
||||
len(self.predict) == 1:
|
||||
if len(self.output_interfaces) == len(self.predict):
|
||||
prediction = [prediction]
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
|
@ -28,7 +28,7 @@ var io_master_template = {
|
||||
this.fn(this.last_input).then((output) => {
|
||||
io.output(output);
|
||||
}).catch((error) => {
|
||||
console.error(error)
|
||||
console.error(error);
|
||||
this.target.find(".loading_in_progress").hide();
|
||||
this.target.find(".loading_failed").show();
|
||||
})
|
||||
@ -39,16 +39,8 @@ var io_master_template = {
|
||||
for (let i = 0; i < this.output_interfaces.length; i++) {
|
||||
this.output_interfaces[i].output(data["data"][i]);
|
||||
}
|
||||
|
||||
let ratio;
|
||||
if (data["durations"].length === 1) {
|
||||
this.output_interfaces[0].target.parent().find(`.loading_time[interface="${this.output_interfaces.length - 1}"]`).text("Latency: " + ((data["durations"][0])).toFixed(2) + "s");
|
||||
} else if (this.output_interfaces.length === data["durations"].length) {
|
||||
for (let i = 0; i < this.output_interfaces.length; i++) {
|
||||
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i}"]`).text("Latency: " + ((data["durations"][i])).toFixed(2) + "s");
|
||||
}
|
||||
} else {
|
||||
ratio = this.output_interfaces.length / data["durations"].length;
|
||||
if (data["durations"]) {
|
||||
let ratio = this.output_interfaces.length / data["durations"].length;
|
||||
for (let i = 0; i < this.output_interfaces.length; i = i + ratio) {
|
||||
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i + ratio - 1}"]`).text("Latency: " + ((data["durations"][i / ratio])).toFixed(2) + "s");
|
||||
}
|
||||
|
@ -132,4 +132,16 @@ function gradio(config, fn, target) {
|
||||
}
|
||||
|
||||
return io_master;
|
||||
}
|
||||
function gradio_url(config, url, target) {
|
||||
return gradio(config, function(data) {
|
||||
return new Promise((resolve, reject) => {
|
||||
$.ajax({type: "POST",
|
||||
url: url,
|
||||
data: JSON.stringify({"data": data}),
|
||||
success: resolve,
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, target);
|
||||
}
|
@ -81,16 +81,7 @@
|
||||
<script src="../static/js/gradio.js"></script>
|
||||
<script>
|
||||
$.getJSON("static/config.json", function(config) {
|
||||
io = gradio(config, function(data) {
|
||||
return new Promise((resolve, reject) => {
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/predict/",
|
||||
data: JSON.stringify({"data": data}),
|
||||
success: resolve,
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, "#interface_target");
|
||||
io = gradio_url(config, "/api/predict/", "#interface_target");
|
||||
if (config["title"]) {
|
||||
$("#title").text(config["title"]);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user