mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-17 11:29:58 +08:00
fix examples
This commit is contained in:
commit
3f359358d6
15
README.md
15
README.md
@ -40,20 +40,15 @@ import gradio
|
||||
import tensorflow as tf
|
||||
from imagenetlabels import idx_to_labels
|
||||
|
||||
graph = tf.get_default_graph()
|
||||
sess = tf.keras.backend.get_session()
|
||||
|
||||
def classify_image(inp):
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
inp = inp.reshape((1, 224, 224, 3))
|
||||
prediction = mobile_net.predict(inp).flatten()
|
||||
return {idx_to_labels[i].split(',')[0]: float(prediction[i]) for i in range(1000)}
|
||||
inp = inp.reshape((1, 224, 224, 3))
|
||||
prediction = mobile_net.predict(inp).flatten()
|
||||
return {idx_to_labels[i].split(',')[0]: float(prediction[i]) for i in range(1000)}
|
||||
|
||||
imagein = gradio.inputs.ImageIn(shape=(224, 224, 3))
|
||||
imagein = gradio.inputs.Image(shape=(224, 224, 3))
|
||||
label = gradio.outputs.Label(num_top_classes=3)
|
||||
|
||||
gr.Interface(classify_image, imagein, label).launch();
|
||||
gr.Interface(classify_image, imagein, label, capture_session=True).launch();
|
||||
```
|
||||
|
||||

|
||||
|
@ -209,15 +209,17 @@ class CheckboxGroup(AbstractInput):
|
||||
|
||||
|
||||
class Slider(AbstractInput):
|
||||
def __init__(self, minimum=0, maximum=100, label=None):
|
||||
def __init__(self, minimum=0, maximum=100, default=None, label=None):
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
self.default = minimum if default is None else default
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"minimum": self.minimum,
|
||||
"maximum": self.maximum,
|
||||
"default": self.default,
|
||||
**super().get_template_context()
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ class Interface:
|
||||
"""
|
||||
def get_input_instance(iface):
|
||||
if isinstance(iface, str):
|
||||
return gradio.inputs.shortcuts[iface]
|
||||
return gradio.inputs.shortcuts[iface.lower()]
|
||||
elif isinstance(iface, gradio.inputs.AbstractInput):
|
||||
return iface
|
||||
else:
|
||||
@ -49,7 +49,7 @@ class Interface:
|
||||
|
||||
def get_output_instance(iface):
|
||||
if isinstance(iface, str):
|
||||
return gradio.outputs.shortcuts[iface]
|
||||
return gradio.outputs.shortcuts[iface.lower()]
|
||||
elif isinstance(iface, gradio.outputs.AbstractOutput):
|
||||
return iface
|
||||
else:
|
||||
@ -115,8 +115,10 @@ class Interface:
|
||||
raw_input[i]) for i, input_interface in
|
||||
enumerate(self.input_interfaces)]
|
||||
predictions = []
|
||||
durations = []
|
||||
for predict_fn in self.predict:
|
||||
if self.capture_session:
|
||||
start = time.time()
|
||||
if self.capture_session and not(self.session is None):
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
@ -135,14 +137,16 @@ class Interface:
|
||||
"error.")
|
||||
else:
|
||||
raise exception
|
||||
duration = time.time() - start
|
||||
|
||||
if len(self.output_interfaces) / \
|
||||
len(self.predict) == 1:
|
||||
prediction = [prediction]
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output
|
||||
return processed_output, durations
|
||||
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
@ -208,9 +212,12 @@ class Interface:
|
||||
# self.validate()
|
||||
|
||||
if self.capture_session:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
try:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
|
||||
pass
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == "RUNNING":
|
||||
@ -257,7 +264,7 @@ class Interface:
|
||||
if share:
|
||||
try:
|
||||
share_url = networking.setup_tunnel(server_port)
|
||||
print(share_url)
|
||||
print("External URL:", share_url)
|
||||
except RuntimeError:
|
||||
share_url = None
|
||||
if self.verbose:
|
||||
@ -299,6 +306,7 @@ class Interface:
|
||||
is_colab
|
||||
): # Embed the remote interface page if on google colab;
|
||||
# otherwise, embed the local page.
|
||||
time.sleep(1)
|
||||
display(IFrame(share_url, width=1000, height=500))
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
@ -139,7 +139,9 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
output = {"data": interface.process(raw_input)}
|
||||
prediction, durations = interface.process(raw_input)
|
||||
|
||||
output = {"data": prediction, "durations": durations}
|
||||
if interface.saliency is not None:
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
output['saliency'] = saliency.tolist()
|
||||
|
@ -13,6 +13,14 @@ button, input[type="submit"], input[type="reset"], input[type="text"], input[typ
|
||||
-webkit-appearance: none;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.loading_time {
|
||||
font-size: large;
|
||||
color: #EEA45D;
|
||||
text-align: right;
|
||||
padding-top: 5px;
|
||||
}
|
||||
|
||||
nav {
|
||||
text-align: center;
|
||||
padding: 16px 0 4px;
|
||||
|
@ -32,16 +32,18 @@ var io_master_template = {
|
||||
this.target.find(".loading_failed").show();
|
||||
})
|
||||
},
|
||||
output: function(data) {
|
||||
output: function(data) {
|
||||
this.last_output = data["data"];
|
||||
|
||||
for (let i = 0; i < this.output_interfaces.length; i++) {
|
||||
this.output_interfaces[i].output(data["data"][i]);
|
||||
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i}"]`).text("Latency: " + ((data["durations"][i])).toFixed(2) + "s");
|
||||
}
|
||||
if (this.config.live) {
|
||||
this.gather();
|
||||
} else {
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interface").removeClass("invisible");
|
||||
this.target.find(".output_interface").removeClass("invisible");
|
||||
this.target.find(".output_interfaces .panel_header").removeClass("invisible");
|
||||
}
|
||||
},
|
||||
|
@ -91,6 +91,9 @@ function gradio(config, fn, target) {
|
||||
${output_interface.html}
|
||||
</div>
|
||||
`);
|
||||
target.find(".output_interfaces").append(`
|
||||
<div class="loading_time" interface="${i}"> </div>
|
||||
`);
|
||||
output_interface.target = target.find(`.output_interface[interface_id=${_id}]`);
|
||||
set_interface_id(output_interface, _id);
|
||||
output_interface.io_master = io_master;
|
||||
@ -110,6 +113,7 @@ function gradio(config, fn, target) {
|
||||
target.find(".flag").removeClass("flagged");
|
||||
target.find(".flag_message").empty();
|
||||
target.find(".loading").addClass("invisible");
|
||||
target.find(".loading_time").text("");
|
||||
target.find(".output_interface").removeClass("invisible");
|
||||
io_master.last_input = null;
|
||||
io_master.last_output = null;
|
||||
|
@ -12,6 +12,7 @@ const slider = {
|
||||
this.target.css("height", "auto");
|
||||
this.target.find(".min").text(opts.minimum);
|
||||
this.target.find(".max").text(opts.maximum);
|
||||
this.target.find(".value").text(opts.default);
|
||||
let difference = opts.maximum - opts.minimum;
|
||||
if (difference <= 1) {
|
||||
step = 0.01;
|
||||
@ -23,6 +24,7 @@ const slider = {
|
||||
this.target.find(".slider")
|
||||
.attr("min", opts.minimum)
|
||||
.attr("max", opts.maximum)
|
||||
.attr("value", opts.default)
|
||||
.attr("step", step)
|
||||
.on("change", function() {
|
||||
io.target.find(".value").text($(this).val());
|
||||
@ -33,7 +35,7 @@ const slider = {
|
||||
this.io_master.input(this.id, parseFloat(value));
|
||||
},
|
||||
clear: function() {
|
||||
this.target.find("input").val(this.minimum);
|
||||
this.target.find("input").val(this.default);
|
||||
},
|
||||
load_example: function(data) {
|
||||
this.target.find("input").val(data);
|
||||
|
BIN
dist/gradio-0.9.4.tar.gz
vendored
BIN
dist/gradio-0.9.4.tar.gz
vendored
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
dist/gradio-0.9.5.tar.gz
vendored
Normal file
BIN
dist/gradio-0.9.5.tar.gz
vendored
Normal file
Binary file not shown.
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 1.0
|
||||
Name: gradio
|
||||
Version: 0.9.4
|
||||
Version: 0.9.5
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
Home-page: https://github.com/abidlabs/gradio
|
||||
Author: Abubakar Abid
|
||||
|
@ -209,15 +209,17 @@ class CheckboxGroup(AbstractInput):
|
||||
|
||||
|
||||
class Slider(AbstractInput):
|
||||
def __init__(self, minimum=0, maximum=100, label=None):
|
||||
def __init__(self, minimum=0, maximum=100, default=None, label=None):
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
self.default = minimum if default is None else default
|
||||
super().__init__(label)
|
||||
|
||||
def get_template_context(self):
|
||||
return {
|
||||
"minimum": self.minimum,
|
||||
"maximum": self.maximum,
|
||||
"default": self.default,
|
||||
**super().get_template_context()
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ class Interface:
|
||||
"""
|
||||
def get_input_instance(iface):
|
||||
if isinstance(iface, str):
|
||||
return gradio.inputs.shortcuts[iface]
|
||||
return gradio.inputs.shortcuts[iface.lower()]
|
||||
elif isinstance(iface, gradio.inputs.AbstractInput):
|
||||
return iface
|
||||
else:
|
||||
@ -49,7 +49,7 @@ class Interface:
|
||||
|
||||
def get_output_instance(iface):
|
||||
if isinstance(iface, str):
|
||||
return gradio.outputs.shortcuts[iface]
|
||||
return gradio.outputs.shortcuts[iface.lower()]
|
||||
elif isinstance(iface, gradio.outputs.AbstractOutput):
|
||||
return iface
|
||||
else:
|
||||
@ -115,8 +115,10 @@ class Interface:
|
||||
raw_input[i]) for i, input_interface in
|
||||
enumerate(self.input_interfaces)]
|
||||
predictions = []
|
||||
durations = []
|
||||
for predict_fn in self.predict:
|
||||
if self.capture_session:
|
||||
start = time.time()
|
||||
if self.capture_session and not(self.session is None):
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
@ -135,14 +137,16 @@ class Interface:
|
||||
"error.")
|
||||
else:
|
||||
raise exception
|
||||
duration = time.time() - start
|
||||
|
||||
if len(self.output_interfaces) / \
|
||||
len(self.predict) == 1:
|
||||
prediction = [prediction]
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output
|
||||
return processed_output, durations
|
||||
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
@ -208,9 +212,12 @@ class Interface:
|
||||
# self.validate()
|
||||
|
||||
if self.capture_session:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
try:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
|
||||
pass
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == "RUNNING":
|
||||
@ -257,7 +264,7 @@ class Interface:
|
||||
if share:
|
||||
try:
|
||||
share_url = networking.setup_tunnel(server_port)
|
||||
print(share_url)
|
||||
print("External URL:", share_url)
|
||||
except RuntimeError:
|
||||
share_url = None
|
||||
if self.verbose:
|
||||
@ -299,6 +306,7 @@ class Interface:
|
||||
is_colab
|
||||
): # Embed the remote interface page if on google colab;
|
||||
# otherwise, embed the local page.
|
||||
time.sleep(1)
|
||||
display(IFrame(share_url, width=1000, height=500))
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
@ -139,7 +139,9 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
output = {"data": interface.process(raw_input)}
|
||||
prediction, durations = interface.process(raw_input)
|
||||
|
||||
output = {"data": prediction, "durations": durations}
|
||||
if interface.saliency is not None:
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
output['saliency'] = saliency.tolist()
|
||||
|
@ -13,6 +13,14 @@ button, input[type="submit"], input[type="reset"], input[type="text"], input[typ
|
||||
-webkit-appearance: none;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.loading_time {
|
||||
font-size: large;
|
||||
color: #EEA45D;
|
||||
text-align: right;
|
||||
padding-top: 5px;
|
||||
}
|
||||
|
||||
nav {
|
||||
text-align: center;
|
||||
padding: 16px 0 4px;
|
||||
|
@ -32,16 +32,18 @@ var io_master_template = {
|
||||
this.target.find(".loading_failed").show();
|
||||
})
|
||||
},
|
||||
output: function(data) {
|
||||
output: function(data) {
|
||||
this.last_output = data["data"];
|
||||
|
||||
for (let i = 0; i < this.output_interfaces.length; i++) {
|
||||
this.output_interfaces[i].output(data["data"][i]);
|
||||
this.output_interfaces[i].target.parent().find(`.loading_time[interface="${i}"]`).text("Latency: " + ((data["durations"][i])).toFixed(2) + "s");
|
||||
}
|
||||
if (this.config.live) {
|
||||
this.gather();
|
||||
} else {
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interface").removeClass("invisible");
|
||||
this.target.find(".output_interface").removeClass("invisible");
|
||||
this.target.find(".output_interfaces .panel_header").removeClass("invisible");
|
||||
}
|
||||
},
|
||||
|
@ -91,6 +91,9 @@ function gradio(config, fn, target) {
|
||||
${output_interface.html}
|
||||
</div>
|
||||
`);
|
||||
target.find(".output_interfaces").append(`
|
||||
<div class="loading_time" interface="${i}"> </div>
|
||||
`);
|
||||
output_interface.target = target.find(`.output_interface[interface_id=${_id}]`);
|
||||
set_interface_id(output_interface, _id);
|
||||
output_interface.io_master = io_master;
|
||||
@ -110,6 +113,7 @@ function gradio(config, fn, target) {
|
||||
target.find(".flag").removeClass("flagged");
|
||||
target.find(".flag_message").empty();
|
||||
target.find(".loading").addClass("invisible");
|
||||
target.find(".loading_time").text("");
|
||||
target.find(".output_interface").removeClass("invisible");
|
||||
io_master.last_input = null;
|
||||
io_master.last_output = null;
|
||||
|
@ -12,6 +12,7 @@ const slider = {
|
||||
this.target.css("height", "auto");
|
||||
this.target.find(".min").text(opts.minimum);
|
||||
this.target.find(".max").text(opts.maximum);
|
||||
this.target.find(".value").text(opts.default);
|
||||
let difference = opts.maximum - opts.minimum;
|
||||
if (difference <= 1) {
|
||||
step = 0.01;
|
||||
@ -23,6 +24,7 @@ const slider = {
|
||||
this.target.find(".slider")
|
||||
.attr("min", opts.minimum)
|
||||
.attr("max", opts.maximum)
|
||||
.attr("value", opts.default)
|
||||
.attr("step", step)
|
||||
.on("change", function() {
|
||||
io.target.find(".value").text($(this).val());
|
||||
@ -33,7 +35,7 @@ const slider = {
|
||||
this.io_master.input(this.id, parseFloat(value));
|
||||
},
|
||||
clear: function() {
|
||||
this.target.find("input").val(this.minimum);
|
||||
this.target.find("input").val(this.default);
|
||||
},
|
||||
load_example: function(data) {
|
||||
this.target.find("input").val(data);
|
||||
|
2
setup.py
2
setup.py
@ -5,7 +5,7 @@ except ImportError:
|
||||
|
||||
setup(
|
||||
name='gradio',
|
||||
version='0.9.4',
|
||||
version='0.9.5',
|
||||
include_package_data=True,
|
||||
description='Python library for easily interacting with trained machine learning models',
|
||||
author='Abubakar Abid',
|
||||
|
@ -11,7 +11,7 @@ PACKAGE_NAME = 'gradio'
|
||||
class TestSketchpad(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Sketchpad()
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__)
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
@ -23,7 +23,7 @@ class TestSketchpad(unittest.TestCase):
|
||||
class TestWebcam(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Webcam()
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__)
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
@ -35,7 +35,8 @@ class TestWebcam(unittest.TestCase):
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Textbox()
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__)
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(
|
||||
inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
@ -47,7 +48,7 @@ class TestTextbox(unittest.TestCase):
|
||||
class TestImageUpload(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Image()
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__)
|
||||
path = inputs.BASE_INPUT_INTERFACE_JS_PATH.format(inp.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
|
@ -6,7 +6,6 @@ import socket
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
LOCALHOST_NAME = 'localhost'
|
||||
|
||||
|
||||
class TestGetAvailablePort(unittest.TestCase):
|
||||
@ -17,7 +16,7 @@ class TestGetAvailablePort(unittest.TestCase):
|
||||
for port in range(initial, final):
|
||||
try:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.close()
|
||||
port_found = True
|
||||
break
|
||||
@ -25,7 +24,7 @@ class TestGetAvailablePort(unittest.TestCase):
|
||||
pass
|
||||
if port_found:
|
||||
s = socket.socket() # create a socket object
|
||||
s.bind((LOCALHOST_NAME, port)) # Bind to the port
|
||||
s.bind((networking.LOCALHOST_NAME, port)) # Bind to the port
|
||||
new_port = networking.get_first_available_port(initial, final)
|
||||
s.close()
|
||||
self.assertFalse(port==new_port)
|
||||
|
@ -11,7 +11,7 @@ BASE64_IMG = "data:image/jpeg;base64,/9j/4AAQSkZJRgABAQEASABIAAD/2wBDAAYEBQYFBAY
|
||||
class TestLabel(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Label()
|
||||
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__)
|
||||
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
# def test_postprocessing_string(self):
|
||||
@ -50,7 +50,7 @@ class TestLabel(unittest.TestCase):
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Textbox()
|
||||
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__)
|
||||
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing(self):
|
||||
@ -63,7 +63,7 @@ class TestTextbox(unittest.TestCase):
|
||||
class TestImage(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Image()
|
||||
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__qualname__)
|
||||
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__qualname__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing(self):
|
||||
|
Loading…
Reference in New Issue
Block a user