mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
Merge branch 'master' of https://github.com/gradio-app/gradio
This commit is contained in:
commit
2497a2e2a8
@ -66,7 +66,8 @@ class AbstractInput(ABC):
|
||||
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0,
|
||||
def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
|
||||
flatten=False, scale=1/255, shift=0,
|
||||
dtype='float64', sample_inputs=None, label=None):
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
@ -272,8 +273,9 @@ class Checkbox(AbstractInput):
|
||||
|
||||
|
||||
class ImageIn(AbstractInput):
|
||||
def __init__(self, shape=(224, 224, 3), image_mode='RGB',
|
||||
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
|
||||
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
|
||||
self.cast_to = cast_to
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
self.num_channels = shape[2]
|
||||
@ -298,10 +300,29 @@ class ImageIn(AbstractInput):
|
||||
**super().get_template_context()
|
||||
}
|
||||
|
||||
def cast_to_base64(self, inp):
|
||||
return inp
|
||||
|
||||
def cast_to_im(self, inp):
|
||||
return preprocessing_utils.decode_base64_to_image(inp)
|
||||
|
||||
def cast_to_numpy(self, inp):
|
||||
im = self.cast_to_im(inp)
|
||||
arr = np.array(im).flatten()
|
||||
return arr
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
cast_to_type = {
|
||||
"base64": self.cast_to_base64,
|
||||
"numpy": self.cast_to_numpy,
|
||||
"pillow": self.cast_to_im
|
||||
}
|
||||
if self.cast_to:
|
||||
return cast_to_type[self.cast_to](inp)
|
||||
|
||||
im = preprocessing_utils.decode_base64_to_image(inp)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
|
@ -82,8 +82,8 @@ class Interface:
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
|
||||
def update_config_file(self, output_directory):
|
||||
config = {
|
||||
def get_config_file(self):
|
||||
return {
|
||||
"input_interfaces": [
|
||||
(iface.__class__.__name__.lower(), iface.get_template_context())
|
||||
for iface in self.input_interfaces],
|
||||
@ -95,7 +95,38 @@ class Interface:
|
||||
"show_input": self.show_input,
|
||||
"show_output": self.show_output,
|
||||
}
|
||||
networking.set_config(config, output_directory)
|
||||
|
||||
def process(self, raw_input):
|
||||
processed_input = [input_interface.preprocess(
|
||||
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in self.predict:
|
||||
if self.context:
|
||||
if self.capture_session:
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input,
|
||||
self.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input,
|
||||
self.context)
|
||||
else:
|
||||
if self.capture_session:
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(self.output_interfaces) / \
|
||||
len(self.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output
|
||||
|
||||
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
@ -181,8 +212,7 @@ class Interface:
|
||||
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
|
||||
self.update_config_file(output_directory)
|
||||
networking.set_config(self.get_config_file(), output_directory)
|
||||
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
@ -43,7 +43,8 @@ def build_template(temp_dir):
|
||||
:param temp_dir: string with path to temp directory in which the html file should be built
|
||||
"""
|
||||
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
|
||||
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(temp_dir, STATIC_PATH_TEMP))
|
||||
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
|
||||
temp_dir, STATIC_PATH_TEMP))
|
||||
|
||||
# Move association file to root of temporary directory.
|
||||
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
|
||||
@ -81,6 +82,7 @@ def render_string_or_list_with_tags(old_lines, context):
|
||||
new_lines.append(line)
|
||||
return new_lines
|
||||
|
||||
|
||||
def set_config(config, temp_dir):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
with open(config_file, "w") as output:
|
||||
@ -133,36 +135,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
if self.path == "/api/predict/":
|
||||
# Make the prediction.
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
processed_input = [input_interface.preprocess(raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
if interface.context:
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) / \
|
||||
len(interface.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
output = {"action": "output", "data": processed_output}
|
||||
output = {"action": "output", "data": interface.process(raw_input)}
|
||||
if interface.saliency is not None:
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
output['saliency'] = saliency.tolist()
|
||||
@ -182,7 +159,8 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
|
||||
elif self.path == "/api/flag/":
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
flag_dir = os.path.join(FLAGGING_DIRECTORY,
|
||||
str(interface.flag_hash))
|
||||
@ -207,9 +185,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
import numpy as np
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(
|
||||
msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
|
||||
@ -219,8 +199,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
for deg in range(-180, 180+45, 45):
|
||||
img = img_orig.rotate(deg)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
prediction = interface.predict(
|
||||
np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(
|
||||
prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
@ -240,9 +222,11 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
from PIL import ImageEnhance
|
||||
|
||||
self._set_headers()
|
||||
data_string = self.rfile.read(int(self.headers["Content-Length"]))
|
||||
data_string = self.rfile.read(
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(msg["data"])
|
||||
img_orig = preprocessing_utils.decode_base64_to_image(
|
||||
msg["data"])
|
||||
img_orig = img_orig.convert('RGB')
|
||||
img_orig = img_orig.resize((224, 224))
|
||||
enhancer = ImageEnhance.Brightness(img_orig)
|
||||
@ -253,8 +237,10 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
for i in range(9):
|
||||
img = enhancer.enhance(i/4)
|
||||
img_array = np.array(img) / 127.5 - 1
|
||||
prediction = interface.predict(np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(prediction)
|
||||
prediction = interface.predict(
|
||||
np.expand_dims(img_array, axis=0))
|
||||
processed_output = interface.output_interface.postprocess(
|
||||
prediction)
|
||||
output = {'input': interface.input_interface.save_to_file(flag_dir, img),
|
||||
'output': interface.output_interface.rebuild_flagged(
|
||||
flag_dir, {'data': {'output': processed_output}}),
|
||||
@ -299,7 +285,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None):
|
||||
port = get_first_available_port(
|
||||
INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS
|
||||
)
|
||||
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
|
||||
httpd = serve_files_in_background(
|
||||
interface, port, directory_to_serve, server_name)
|
||||
return port, httpd
|
||||
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
align-items: center;
|
||||
font-size: 18px;
|
||||
}
|
||||
.checkbox_group input {
|
||||
.checkbox_group input, .checkbox {
|
||||
margin: 0px 4px 0px 0px;
|
||||
}
|
||||
.checkbox_group label {
|
||||
|
@ -1,6 +1,5 @@
|
||||
body {
|
||||
font-family: 'Open Sans', sans-serif;
|
||||
font-size: 12px;
|
||||
margin: 0;
|
||||
}
|
||||
button, input[type="submit"], input[type="reset"], input[type="text"], input[type="button"], select[type="submit"] {
|
||||
|
@ -1,5 +1,5 @@
|
||||
const checkbox = {
|
||||
html: `<input type="checkbox">`,
|
||||
html: `<input class="checkbox" type="checkbox">`,
|
||||
init: function(opts) {
|
||||
this.target.css("height", "auto");
|
||||
},
|
||||
|
@ -82,8 +82,8 @@ class Interface:
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
|
||||
def update_config_file(self, output_directory):
|
||||
config = {
|
||||
def get_config_file(self):
|
||||
return {
|
||||
"input_interfaces": [
|
||||
(iface.__class__.__name__.lower(), iface.get_template_context())
|
||||
for iface in self.input_interfaces],
|
||||
@ -95,7 +95,38 @@ class Interface:
|
||||
"show_input": self.show_input,
|
||||
"show_output": self.show_output,
|
||||
}
|
||||
networking.set_config(config, output_directory)
|
||||
|
||||
def process(self, raw_input):
|
||||
processed_input = [input_interface.preprocess(
|
||||
raw_input[i]) for i, input_interface in enumerate(self.input_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in self.predict:
|
||||
if self.context:
|
||||
if self.capture_session:
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input,
|
||||
self.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input,
|
||||
self.context)
|
||||
else:
|
||||
if self.capture_session:
|
||||
graph, sess = self.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(self.output_interfaces) / \
|
||||
len(self.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output
|
||||
|
||||
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
@ -181,8 +212,7 @@ class Interface:
|
||||
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
networking.build_template(output_directory)
|
||||
|
||||
self.update_config_file(output_directory)
|
||||
networking.set_config(self.get_config_file(), output_directory)
|
||||
|
||||
self.status = "RUNNING"
|
||||
self.simple_server = httpd
|
||||
|
@ -139,35 +139,7 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
|
||||
int(self.headers["Content-Length"]))
|
||||
msg = json.loads(data_string)
|
||||
raw_input = msg["data"]
|
||||
processed_input = [input_interface.preprocess(
|
||||
raw_input[i]) for i, input_interface in enumerate(interface.input_interfaces)]
|
||||
predictions = []
|
||||
for predict_fn in interface.predict:
|
||||
if interface.context:
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input,
|
||||
interface.context)
|
||||
else:
|
||||
if interface.capture_session:
|
||||
graph, sess = interface.session
|
||||
with graph.as_default():
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
if len(interface.output_interfaces) / \
|
||||
len(interface.predict) == 1:
|
||||
prediction = [prediction]
|
||||
predictions.extend(prediction)
|
||||
processed_output = [output_interface.postprocess(
|
||||
predictions[i]) for i, output_interface in enumerate(interface.output_interfaces)]
|
||||
output = {"action": "output", "data": processed_output}
|
||||
output = {"action": "output", "data": interface.process(raw_input)}
|
||||
if interface.saliency is not None:
|
||||
saliency = interface.saliency(raw_input, prediction)
|
||||
output['saliency'] = saliency.tolist()
|
||||
|
Loading…
Reference in New Issue
Block a user