mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
more changes
This commit is contained in:
parent
55f0bef538
commit
cf877ba744
@ -311,9 +311,10 @@ class ImageIn(AbstractInput):
|
||||
im = np.array(im).flatten()
|
||||
im = im * self.scale + self.shift
|
||||
if self.num_channels is None:
|
||||
array = im.reshape(self.image_width, self.image_height)
|
||||
array = im.reshape(1, self.image_width, self.image_height)
|
||||
else:
|
||||
array = im.reshape(self.image_width, self.image_height, self.num_channels)
|
||||
array = im.reshape(1, self.image_width, self.image_height, \
|
||||
self.num_channels)
|
||||
return array
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
|
@ -16,6 +16,7 @@ import requests
|
||||
import random
|
||||
import time
|
||||
from IPython import get_ipython
|
||||
import tensorflow as tf
|
||||
|
||||
LOCALHOST_IP = "0.0.0.0"
|
||||
TRY_NUM_PORTS = 100
|
||||
@ -29,8 +30,9 @@ class Interface:
|
||||
"""
|
||||
|
||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
|
||||
live=False, show_input=True, show_output=True,
|
||||
load_fn=None, server_name=LOCALHOST_IP):
|
||||
live=False, show_input=True, show_output=True,
|
||||
load_fn=None, capture_session=False,
|
||||
server_name=LOCALHOST_IP):
|
||||
"""
|
||||
:param fn: a function that will process the input panel data from the interface and return the output panel data.
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
@ -42,7 +44,9 @@ class Interface:
|
||||
elif isinstance(iface, gradio.inputs.AbstractInput):
|
||||
return iface
|
||||
else:
|
||||
raise ValueError("Input interface must be of type `str` or `AbstractInput`")
|
||||
raise ValueError("Input interface must be of type `str` or "
|
||||
"`AbstractInput`")
|
||||
|
||||
def get_output_instance(iface):
|
||||
if isinstance(iface, str):
|
||||
return gradio.outputs.shortcuts[iface]
|
||||
@ -50,7 +54,8 @@ class Interface:
|
||||
return iface
|
||||
else:
|
||||
raise ValueError(
|
||||
"Output interface must be of type `str` or `AbstractOutput`"
|
||||
"Output interface must be of type `str` or "
|
||||
"`AbstractOutput`"
|
||||
)
|
||||
if isinstance(inputs, list):
|
||||
self.input_interfaces = [get_input_instance(i) for i in inputs]
|
||||
@ -73,6 +78,8 @@ class Interface:
|
||||
self.show_input = show_input
|
||||
self.show_output = show_output
|
||||
self.flag_hash = random.getrandbits(32)
|
||||
self.capture_session = capture_session
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
|
||||
def update_config_file(self, output_directory):
|
||||
@ -155,6 +162,10 @@ class Interface:
|
||||
context = self.load_fn() if self.load_fn else None
|
||||
self.context = context
|
||||
|
||||
if self.capture_session:
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
|
||||
# If an existing interface is running with this instance, close it.
|
||||
if self.status == "RUNNING":
|
||||
if self.verbose:
|
||||
@ -241,7 +252,8 @@ class Interface:
|
||||
|
||||
if (
|
||||
is_colab
|
||||
): # Embed the remote interface page if on google colab; otherwise, embed the local page.
|
||||
): # Embed the remote interface page if on google colab;
|
||||
# otherwise, embed the local page.
|
||||
display(IFrame(share_url, width=1000, height=500))
|
||||
else:
|
||||
display(IFrame(path_to_local_server, width=1000, height=500))
|
||||
|
@ -140,11 +140,25 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
if interface.context:
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) / len(interface.predict) == 1:
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) / \
|
||||
len(interface.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
|
@ -19,7 +19,7 @@ function gradio(config, fn, target) {
|
||||
</div>
|
||||
</div>
|
||||
</div>`);
|
||||
let io_master = Object.create(io_master_template);
|
||||
io_master = Object.create(io_master_template);
|
||||
io_master.fn = fn
|
||||
io_master.target = target;
|
||||
io_master.config = config;
|
||||
@ -68,9 +68,9 @@ function gradio(config, fn, target) {
|
||||
`);
|
||||
input_interface.target = target.find(`.input_interface[interface_id=${_id}]`);
|
||||
set_interface_id(input_interface, _id);
|
||||
input_interface.io_master = io_master;
|
||||
input_interface.init(input_interface_data[1]);
|
||||
input_interfaces.push(input_interface);
|
||||
input_interface.io_master = io_master;
|
||||
_id++;
|
||||
}
|
||||
for (let i = 0; i < config["output_interfaces"].length; i++) {
|
||||
@ -92,9 +92,9 @@ function gradio(config, fn, target) {
|
||||
`);
|
||||
output_interface.target = target.find(`.output_interface[interface_id=${_id}]`);
|
||||
set_interface_id(output_interface, _id);
|
||||
output_interface.io_master = io_master;
|
||||
output_interface.init(output_interface_data[1]);
|
||||
output_interfaces.push(output_interface);
|
||||
output_interface.io_master = io_master;
|
||||
_id++;
|
||||
}
|
||||
io_master.input_interfaces = input_interfaces;
|
||||
|
@ -6,40 +6,21 @@ const sketchpad_input = {
|
||||
<div id="brush_3" size="24" class="brush"></div>
|
||||
</div>
|
||||
<div class="view_holders">
|
||||
<div class="saliency_holder hide">
|
||||
<canvas class="saliency"></canvas>
|
||||
</div>
|
||||
<div class="canvas_holder">
|
||||
<canvas class="sketch"></canvas>
|
||||
</div>
|
||||
</div>`,
|
||||
disabled_html: `
|
||||
<div class="view_holders">
|
||||
<div class="saliency_holder hide">
|
||||
<canvas class="saliency"></canvas>
|
||||
</div>
|
||||
<div class="canvas_holder">
|
||||
<canvas></canvas>
|
||||
</div>
|
||||
</div>`,
|
||||
init: function() {
|
||||
var io = this;
|
||||
var dimension = Math.min(this.target.find(".canvas_holder").width(),
|
||||
this.target.find(".canvas_holder").height()) - 2 // dimension - border
|
||||
var id = this.id;
|
||||
if (this.io_master.config.disabled) {
|
||||
this.target.find('.canvas_holder canvas')
|
||||
.attr("width", dimension).attr("height", dimension);
|
||||
} else {
|
||||
this.sketchpad = new Sketchpad({
|
||||
element: '.interface[interface_id=' + id + '] .sketch',
|
||||
width: dimension,
|
||||
height: dimension
|
||||
});
|
||||
this.sketchpad.penSize = this.target.find(".brush.selected").attr("size");
|
||||
}
|
||||
this.target.find(".saliency")
|
||||
.attr("width", dimension+"px").attr("height", dimension+"px");
|
||||
this.sketchpad = new Sketchpad({
|
||||
element: '.interface[interface_id=' + id + '] .sketch',
|
||||
width: dimension,
|
||||
height: dimension
|
||||
});
|
||||
this.sketchpad.penSize = this.target.find(".brush.selected").attr("size");
|
||||
this.canvas = this.target.find('.canvas_holder canvas')[0];
|
||||
this.context = this.canvas.getContext("2d");
|
||||
this.target.find(".brush").click(function (e) {
|
||||
@ -52,17 +33,9 @@ const sketchpad_input = {
|
||||
var dataURL = this.canvas.toDataURL("image/png");
|
||||
this.io_master.input(this.id, dataURL);
|
||||
},
|
||||
output: function(data) {
|
||||
this.target.find(".saliency_holder").removeClass("hide");
|
||||
let ctx = this.target.find(".saliency")[0].getContext('2d');
|
||||
let dimension = this.target.find(".saliency").width();
|
||||
ctx.clearRect(0,0,dimension,dimension);
|
||||
paintSaliency(data, dimension, dimension, ctx);
|
||||
},
|
||||
clear: function() {
|
||||
this.context.clearRect(0, 0, this.context.canvas.width, this.context.
|
||||
canvas.height);
|
||||
this.target.find(".saliency_holder").addClass("hide");
|
||||
},
|
||||
renderFeatured: function(data) {
|
||||
return `<img src=${data}>`;
|
||||
|
@ -84,8 +84,8 @@
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/predict/",
|
||||
data: JSON.stringify({"data": data}),
|
||||
success: (data) => {console.log("y"); resolve(data)},
|
||||
error: (data) => {console.log("n"); reject()},
|
||||
success: resolve,
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, "#interface_target");
|
||||
|
@ -43,7 +43,8 @@ def build_template(temp_dir):
|
||||
:param temp_dir: string with path to temp directory in which the html file should be built
|
||||
"""
|
||||
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
|
||||
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
|
||||
temp_dir, STATIC_PATH_TEMP))
|
||||
|
||||
# Move association file to root of temporary directory.
|
||||
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
|
||||
@ -81,6 +82,7 @@ def render_string_or_list_with_tags(old_lines, context):
|
||||
new_lines.append(line)
|
||||
return new_lines
|
||||
|
||||
|
||||
def set_config(config, temp_dir):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
with open(config_file, "w") as output:
|
||||
@ -133,10 +135,12 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
if self.path == "/api/predict/":
|
||||
# Make the prediction.
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
processed_input = [input_interface.preprocess(
|
||||
raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
if interface.context:
|
||||
@ -161,7 +165,8 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
len(interface.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
output = {"action": "output", "data": processed_output}
|
||||
if interface.saliency is not None:
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
@ -182,34 +187,37 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
|
||||
elif self.path == "/api/flag/":
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
flag_dir = os.path.join(FLAGGING_DIRECTORY,
|
||||
str(interface.flag_hash))
|
||||
os.makedirs(flag_dir, exist_ok=True)
|
||||
output = {'inputs': [interface.input_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['input_data']) for i
|
||||
flag_dir, msg['data']['input_data']) for i
|
||||
in range(len(interface.input_interfaces))],
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['output_data']) for i
|
||||
'outputs': [interface.output_interfaces[
|
||||
i].rebuild_flagged(
|
||||
flag_dir, msg['data']['output_data']) for i
|
||||
in range(len(interface.output_interfaces))],
|
||||
'message': msg['data']['message']}
|
||||
'message': msg['data']['message']}
|
||||
|
||||
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
|
||||
f.write(json.dumps(output))
|
||||
f.write("\n")
|
||||
|
||||
#TODO(abidlabs): clean this up
|
||||
# TODO(abidlabs): clean this up
|
||||
elif self.path == "/api/auto/rotation":
|
||||
from gradio import validation_data, preprocessing_utils
|
||||
import numpy as np
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(
|
||||
msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
|
||||
@ -219,8 +227,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
for deg in range(-180, 180+45, 45):
|
||||
img = img_orig.rotate(deg)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
prediction = interface.predict(
|
||||
np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(
|
||||
prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
@ -240,9 +250,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
from PIL import ImageEnhance
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(
|
||||
msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
enhancer = ImageEnhance.Brightness(img_orig)
|
||||
@ -253,8 +265,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
for i in range(9):
|
||||
img = enhancer.enhance(i/4)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
prediction = interface.predict(
|
||||
np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(
|
||||
prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
@ -299,7 +313,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None):
|
||||
port = get_first_available_port(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
|
||||
)
|
||||
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
|
||||
httpd = serve_files_in_background(
|
||||
interface, port, directory_to_serve, server_name)
|
||||
return port, httpd
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
align-items: center;
|
||||
font-size: 18px;
|
||||
}
|
||||
.checkbox_group input {
|
||||
.checkbox_group input, .checkbox {
|
||||
margin: 0px 4px 0px 0px;
|
||||
}
|
||||
.checkbox_group label {
|
||||
|
@ -1,6 +1,5 @@
|
||||
body {
|
||||
font-family: 'Open Sans', sans-serif;
|
||||
font-size: 12px;
|
||||
margin: 0;
|
||||
}
|
||||
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {
|
||||
|
@ -19,7 +19,7 @@ function gradio(config, fn, target) {
|
||||
</div>
|
||||
</div>
|
||||
</div>`);
|
||||
let io_master = Object.create(io_master_template);
|
||||
io_master = Object.create(io_master_template);
|
||||
io_master.fn = fn
|
||||
io_master.target = target;
|
||||
io_master.config = config;
|
||||
@ -68,9 +68,9 @@ function gradio(config, fn, target) {
|
||||
`);
|
||||
input_interface.target = target.find(`.input_interface[interface_id=${_id}]`);
|
||||
set_interface_id(input_interface, _id);
|
||||
input_interface.io_master = io_master;
|
||||
input_interface.init(input_interface_data[1]);
|
||||
input_interfaces.push(input_interface);
|
||||
input_interface.io_master = io_master;
|
||||
_id++;
|
||||
}
|
||||
for (let i = 0; i < config["output_interfaces"].length; i++) {
|
||||
@ -92,9 +92,9 @@ function gradio(config, fn, target) {
|
||||
`);
|
||||
output_interface.target = target.find(`.output_interface[interface_id=${_id}]`);
|
||||
set_interface_id(output_interface, _id);
|
||||
output_interface.io_master = io_master;
|
||||
output_interface.init(output_interface_data[1]);
|
||||
output_interfaces.push(output_interface);
|
||||
output_interface.io_master = io_master;
|
||||
_id++;
|
||||
}
|
||||
io_master.input_interfaces = input_interfaces;
|
||||
|
@ -1,5 +1,5 @@
|
||||
const checkbox = {
|
||||
html: `<input type="checkbox">`,
|
||||
html: `<input class="checkbox" type="checkbox">`,
|
||||
init: function(opts) {
|
||||
this.target.css("height", "auto");
|
||||
},
|
||||
|
@ -6,40 +6,21 @@ const sketchpad_input = {
|
||||
<div id="brush_3" size="24" class="brush"></div>
|
||||
</div>
|
||||
<div class="view_holders">
|
||||
<div class="saliency_holder hide">
|
||||
<canvas class="saliency"></canvas>
|
||||
</div>
|
||||
<div class="canvas_holder">
|
||||
<canvas class="sketch"></canvas>
|
||||
</div>
|
||||
</div>`,
|
||||
disabled_html: `
|
||||
<div class="view_holders">
|
||||
<div class="saliency_holder hide">
|
||||
<canvas class="saliency"></canvas>
|
||||
</div>
|
||||
<div class="canvas_holder">
|
||||
<canvas></canvas>
|
||||
</div>
|
||||
</div>`,
|
||||
init: function() {
|
||||
var io = this;
|
||||
var dimension = Math.min(this.target.find(".canvas_holder").width(),
|
||||
this.target.find(".canvas_holder").height()) - 2 // dimension - border
|
||||
var id = this.id;
|
||||
if (this.io_master.config.disabled) {
|
||||
this.target.find('.canvas_holder canvas')
|
||||
.attr("width", dimension).attr("height", dimension);
|
||||
} else {
|
||||
this.sketchpad = new Sketchpad({
|
||||
element: '.interface[interface_id=' + id + '] .sketch',
|
||||
width: dimension,
|
||||
height: dimension
|
||||
});
|
||||
this.sketchpad.penSize = this.target.find(".brush.selected").attr("size");
|
||||
}
|
||||
this.target.find(".saliency")
|
||||
.attr("width", dimension+"px").attr("height", dimension+"px");
|
||||
this.sketchpad = new Sketchpad({
|
||||
element: '.interface[interface_id=' + id + '] .sketch',
|
||||
width: dimension,
|
||||
height: dimension
|
||||
});
|
||||
this.sketchpad.penSize = this.target.find(".brush.selected").attr("size");
|
||||
this.canvas = this.target.find('.canvas_holder canvas')[0];
|
||||
this.context = this.canvas.getContext("2d");
|
||||
this.target.find(".brush").click(function (e) {
|
||||
@ -52,17 +33,9 @@ const sketchpad_input = {
|
||||
var dataURL = this.canvas.toDataURL("image/png");
|
||||
this.io_master.input(this.id, dataURL);
|
||||
},
|
||||
output: function(data) {
|
||||
this.target.find(".saliency_holder").removeClass("hide");
|
||||
let ctx = this.target.find(".saliency")[0].getContext('2d');
|
||||
let dimension = this.target.find(".saliency").width();
|
||||
ctx.clearRect(0,0,dimension,dimension);
|
||||
paintSaliency(data, dimension, dimension, ctx);
|
||||
},
|
||||
clear: function() {
|
||||
this.context.clearRect(0, 0, this.context.canvas.width, this.context.
|
||||
canvas.height);
|
||||
this.target.find(".saliency_holder").addClass("hide");
|
||||
},
|
||||
renderFeatured: function(data) {
|
||||
return `<img src=${data}>`;
|
||||
|
@ -84,8 +84,8 @@
|
||||
$.ajax({type: "POST",
|
||||
url: "/api/predict/",
|
||||
data: JSON.stringify({"data": data}),
|
||||
success: (data) => {console.log("y"); resolve(data)},
|
||||
error: (data) => {console.log("n"); reject()},
|
||||
success: resolve,
|
||||
error: reject,
|
||||
});
|
||||
});
|
||||
}, "#interface_target");
|
||||
|
Loading…
Reference in New Issue
Block a user