merge bug

This commit is contained in:
Ali Abid 2020-06-29 12:22:32 -07:00
parent 4472e27438
commit 1a6eec7a57
12 changed files with 100 additions and 267 deletions

View File

@ -7,7 +7,7 @@ automatically added to a registry, which allows them to be easily referenced in
from abc import ABC, abstractmethod
from gradio import preprocessing_utils, validation_data
import numpy as np
from PIL import Image, ImageOps
import PIL.Image, PIL.ImageOps
import time
import warnings
import json
@ -58,11 +58,12 @@ class AbstractInput(ABC):
"""
return {}
def rebuild_flagged(self, dir, msg):
@classmethod
def process_example(self, example):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
Proprocess example for UI
"""
pass
return example
class Sketchpad(AbstractInput):
@ -84,11 +85,11 @@ class Sketchpad(AbstractInput):
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
"""
im_transparent = preprocessing_utils.decode_base64_to_image(inp)
im = Image.new("RGBA", im_transparent.size, "WHITE") # Create a white background for the alpha channel
im = PIL.Image.new("RGBA", im_transparent.size, "WHITE") # Create a white background for the alpha channel
im.paste(im_transparent, (0, 0), im_transparent)
im = im.convert('L')
if self.invert_colors:
im = ImageOps.invert(im)
im = PIL.ImageOps.invert(im)
im = im.resize((self.image_width, self.image_height))
if self.flatten:
array = np.array(im).flatten().reshape(1, self.image_width * self.image_height)
@ -98,30 +99,6 @@ class Sketchpad(AbstractInput):
array = array.astype(self.dtype)
return array
# TODO(abidlabs): clean this up
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
def get_sample_inputs(self):
encoded_images = []
if self.sample_inputs is not None:
for input in self.sample_inputs:
if self.flatten:
input = input.reshape((self.image_width, self.image_height))
if self.invert_colors:
input = 1 - input
encoded_images.append(preprocessing_utils.encode_array_to_base64(input))
return encoded_images
class Webcam(AbstractInput):
def __init__(self, image_width=224, image_height=224, num_channels=3, label=None):
@ -149,17 +126,6 @@ class Webcam(AbstractInput):
array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels)
return array
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method to decode a base64 image
"""
inp = msg['data']['input']
im = preprocessing_utils.decode_base64_to_image(inp)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
class Textbox(AbstractInput):
def __init__(self, sample_inputs=None, lines=1, placeholder=None, label=None, numeric=False):
@ -196,15 +162,6 @@ class Textbox(AbstractInput):
else:
return inp
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for text saves it .txt file
"""
return json.loads(msg)
def get_sample_inputs(self):
return self.sample_inputs
class Radio(AbstractInput):
def __init__(self, choices, label=None):
@ -261,6 +218,7 @@ class Slider(AbstractInput):
"checkbox": {},
}
class Checkbox(AbstractInput):
def __init__(self, label=None):
super().__init__(label)
@ -272,7 +230,7 @@ class Checkbox(AbstractInput):
}
class ImageIn(AbstractInput):
class Image(AbstractInput):
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
self.cast_to = cast_to
@ -338,43 +296,6 @@ class ImageIn(AbstractInput):
self.num_channels)
return array
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
# TODO(abidlabs): clean this up
def save_to_file(self, dir, img):
"""
"""
timestamp = time.time()*1000
filename = 'input_{}.png'.format(timestamp)
img.save('{}/{}'.format(dir, filename), 'PNG')
return filename
class CSV(AbstractInput):
def get_name(self):
return 'csv'
def preprocess(self, inp):
"""
By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this)
"""
return inp
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for csv
"""
return json.loads(msg)
class Microphone(AbstractInput):
@ -386,12 +307,6 @@ class Microphone(AbstractInput):
mfcc_array = preprocessing_utils.generate_mfcc_features_from_audio_file(file_obj.name)
return mfcc_array
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for csv
"""
return json.loads(msg)
# Automatically adds all shortcut implementations in AbstractInput into a dictionary.
shortcuts = {}

View File

@ -16,7 +16,6 @@ import requests
import random
import time
from IPython import get_ipython
import tensorflow as tf
LOCALHOST_IP = "0.0.0.0"
TRY_NUM_PORTS = 100
@ -29,9 +28,9 @@ class Interface:
the appropriate inputs and outputs
"""
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
load_fn=None, capture_session=False,
load_fn=None, capture_session=False, title=None, description=None,
server_name=LOCALHOST_IP):
"""
:param fn: a function that will process the input panel data from the interface and return the output panel data.
@ -81,6 +80,9 @@ class Interface:
self.capture_session = capture_session
self.session = None
self.server_name = server_name
self.title = title
self.description = description
self.examples = examples
def get_config_file(self):
return {
@ -93,7 +95,9 @@ class Interface:
"function_count": len(self.predict),
"live": self.live,
"show_input": self.show_input,
"show_output": self.show_output,
"show_output": self.show_output,
"title": self.title,
"description": self.description,
}
def process(self, raw_input):
@ -109,8 +113,15 @@ class Interface:
prediction = predict_fn(*processed_input,
self.context)
else:
prediction = predict_fn(*processed_input,
self.context)
try:
prediction = predict_fn(*processed_input, self.context)
except ValueError:
print("It looks like you might be "
"using tensorflow < 2.0. Please pass "
"capture_session=True in Interface to avoid "
"a 'Tensor is not an element of this graph.' "
"error.")
prediction = predict_fn(*processed_input, self.context)
else:
if self.capture_session:
graph, sess = self.session
@ -118,7 +129,16 @@ class Interface:
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
prediction = predict_fn(*processed_input)
try:
prediction = predict_fn(*processed_input)
except ValueError:
print("It looks like you might be "
"using tensorflow < 2.0. Please pass "
"capture_session=True in Interface to avoid "
"a 'Tensor is not an element of this graph.' "
"error.")
prediction = predict_fn(*processed_input)
if len(self.output_interfaces) / \
len(self.predict) == 1:
prediction = [prediction]
@ -127,7 +147,6 @@ class Interface:
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output
def validate(self):
if self.validate_flag:
if self.verbose:
@ -180,11 +199,7 @@ class Interface:
return
raise RuntimeError("Validation did not pass")
<<<<<<< HEAD
def launch(self, inline=None, inbrowser=None, share=False, validate=True, title=None, description=None):
=======
def launch(self, inline=None, inbrowser=None, share=False, validate=True):
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
"""
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
:param inline: boolean. If True, then a gradio interface is created inline (e.g. in jupyter or colab notebook)
@ -198,6 +213,7 @@ class Interface:
self.context = context
if self.capture_session:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
@ -294,11 +310,7 @@ class Interface:
config = self.get_config_file()
config["share_url"] = share_url
<<<<<<< HEAD
config["title"] = title
config["description"] = description
=======
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
config["examples"] = self.examples
networking.set_config(config, output_directory)
return httpd, path_to_local_server, share_url

View File

@ -76,12 +76,6 @@ class Label(AbstractOutput):
"label": {},
}
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for label
"""
return json.loads(msg)
class KeyValues(AbstractOutput):
def __init__(self, label=None):
@ -120,12 +114,6 @@ class Textbox(AbstractOutput):
"""
return prediction
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for label
"""
return json.loads(msg)
class Image(AbstractOutput):
def __init__(self, label=None, plot=False):

View File

@ -31,7 +31,6 @@ nav img {
padding: 4px;
border-radius: 2px;
}
<<<<<<< HEAD
#title {
text-align: center;
}
@ -40,19 +39,11 @@ nav img {
width: 100%;
margin: 0 auto;
}
=======
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
.panels {
display: flex;
flex-flow: row;
flex-wrap: wrap;
justify-content: center;
<<<<<<< HEAD
=======
max-width: 1028px;
width: 100%;
margin: 0 auto;
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
}
button.primary {
color: white;

View File

@ -1,11 +1,7 @@
function gradio(config, fn, target) {
target = $(target);
target.html(`
<<<<<<< HEAD
<div class="panels container">
=======
<div class="panels">
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
<div class="panel input_panel">
<div class="input_interfaces">
</div>
@ -30,7 +26,7 @@ function gradio(config, fn, target) {
let input_to_object_map = {
"csv" : {},
"imagein" : image_input,
"image" : image_input,
"sketchpad" : sketchpad_input,
"textbox" : textbox_input,
"webcam" : webcam,

View File

@ -20,27 +20,16 @@ const webcam = {
},
submit: function() {
var io = this;
<<<<<<< HEAD
Webcam.snap(function(image_data) {
io.io_master.input(io.id, image_data);
});
// Webcam.freeze();
=======
Webcam.freeze();
Webcam.snap(function(image_data) {
io.io_master.input(io.id, image_data);
});
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
this.state = "SNAPPED";
},
clear: function() {
if (this.state == "SNAPPED") {
this.state = "CAMERA_ON";
<<<<<<< HEAD
// Webcam.unfreeze();
=======
Webcam.unfreeze();
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
}
},
state: "NOT_STARTED",

View File

@ -34,14 +34,16 @@
Live at <a id="share-link"></a>.
<button id="share-copy">Copy Link</button>
</div>
<<<<<<< HEAD
<div class="container">
<h1 id="title"></h1>
<p id="description"></p>
</div>
=======
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
<div id="interface_target"></div>
<div id="examples" class="container invisible">
<h3>Examples</h3>
<table>
</table>
</div>
<script src="../static/js/vendor/jquery.min.js"></script>
<!-- TUI EDITOR -->
<script src="../static/js/vendor/fabric.js"></script>
@ -89,27 +91,40 @@
});
});
}, "#interface_target");
<<<<<<< HEAD
if (config["title"]) {
$("#title").text(config["title"]);
}
if (config["description"]) {
$("#description").text(config["description"]);
}
=======
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
if (config["share_url"]) {
let share_url = config["share_url"];
$("#share").removeClass("invisible");
$("#share-link").text(share_url).attr("href", share_url);
$("#share-copy").click(function() {
copyToClipboard(share_url);
<<<<<<< HEAD
$("#share-copy").text("Copied!");
=======
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
})
}
};
if (config["examples"]) {
$("#examples").removeClass("invisible");
let html = "<thead>"
for (let i = 0; i < config["input_interfaces"].length; i++) {
label = config["input_interfaces"][i][1]["label"];
html += "<th>" + label + "</th>";
}
html += "</thead>";
html += "<tbody>";
for (let example of config["examples"]) {
html += "<tr>";
for (let col of example) {
html += "<td>" + col + "</td>";
}
html += "</tr>";
}
html += "</tbody>";
$("#examples table").html(html);
};
});
const copyToClipboard = str => {
const el = document.createElement('textarea');

View File

@ -17,9 +17,8 @@ gr.Interface(answer_question,
], [
gr.outputs.Textbox(label="out", lines=8),
"key_values"
], examples=[
["things1", "things2"],
["things10", "things20"],
]
<<<<<<< HEAD
).launch(title="Demo", description="Trying out a funky model!")
=======
).launch(share=True)
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
).launch()

View File

@ -58,11 +58,12 @@ class AbstractInput(ABC):
"""
return {}
def rebuild_flagged(self, dir, msg):
@classmethod
def process_example(self, example):
"""
All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64)
Proprocess example for UI
"""
pass
return example
class Sketchpad(AbstractInput):
@ -98,30 +99,6 @@ class Sketchpad(AbstractInput):
array = array.astype(self.dtype)
return array
# TODO(abidlabs): clean this up
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
def get_sample_inputs(self):
encoded_images = []
if self.sample_inputs is not None:
for input in self.sample_inputs:
if self.flatten:
input = input.reshape((self.image_width, self.image_height))
if self.invert_colors:
input = 1 - input
encoded_images.append(preprocessing_utils.encode_array_to_base64(input))
return encoded_images
class Webcam(AbstractInput):
def __init__(self, image_width=224, image_height=224, num_channels=3, label=None):
@ -149,17 +126,6 @@ class Webcam(AbstractInput):
array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels)
return array
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method to decode a base64 image
"""
inp = msg['data']['input']
im = preprocessing_utils.decode_base64_to_image(inp)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
class Textbox(AbstractInput):
def __init__(self, sample_inputs=None, lines=1, placeholder=None, label=None, numeric=False):
@ -196,15 +162,6 @@ class Textbox(AbstractInput):
else:
return inp
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for text saves it .txt file
"""
return json.loads(msg)
def get_sample_inputs(self):
return self.sample_inputs
class Radio(AbstractInput):
def __init__(self, choices, label=None):
@ -339,43 +296,6 @@ class Image(AbstractInput):
self.num_channels)
return array
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method to decode a base64 image
"""
im = preprocessing_utils.decode_base64_to_image(msg)
timestamp = datetime.datetime.now()
filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png'
im.save(f'{dir}/{filename}', 'PNG')
return filename
# TODO(abidlabs): clean this up
def save_to_file(self, dir, img):
"""
"""
timestamp = time.time()*1000
filename = 'input_{}.png'.format(timestamp)
img.save('{}/{}'.format(dir, filename), 'PNG')
return filename
class CSV(AbstractInput):
def get_name(self):
return 'csv'
def preprocess(self, inp):
"""
By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this)
"""
return inp
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for csv
"""
return json.loads(msg)
class Microphone(AbstractInput):
@ -387,12 +307,6 @@ class Microphone(AbstractInput):
mfcc_array = preprocessing_utils.generate_mfcc_features_from_audio_file(file_obj.name)
return mfcc_array
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for csv
"""
return json.loads(msg)
# Automatically adds all shortcut implementations in AbstractInput into a dictionary.
shortcuts = {}

View File

@ -28,7 +28,7 @@ class Interface:
the appropriate inputs and outputs
"""
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
load_fn=None, capture_session=False, title=None, description=None,
server_name=LOCALHOST_IP):
@ -82,6 +82,7 @@ class Interface:
self.server_name = server_name
self.title = title
self.description = description
self.examples = examples
def get_config_file(self):
return {
@ -309,6 +310,7 @@ class Interface:
config = self.get_config_file()
config["share_url"] = share_url
config["examples"] = self.examples
networking.set_config(config, output_directory)
return httpd, path_to_local_server, share_url

View File

@ -76,12 +76,6 @@ class Label(AbstractOutput):
"label": {},
}
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for label
"""
return json.loads(msg)
class KeyValues(AbstractOutput):
def __init__(self, label=None):
@ -120,12 +114,6 @@ class Textbox(AbstractOutput):
"""
return prediction
def rebuild_flagged(self, dir, msg):
"""
Default rebuild method for label
"""
return json.loads(msg)
class Image(AbstractOutput):
def __init__(self, label=None, plot=False):

View File

@ -39,6 +39,11 @@
<p id="description"></p>
</div>
<div id="interface_target"></div>
<div id="examples" class="container invisible">
<h3>Examples</h3>
<table>
</table>
</div>
<script src="../static/js/vendor/jquery.min.js"></script>
<!-- TUI EDITOR -->
<script src="../static/js/vendor/fabric.js"></script>
@ -100,7 +105,26 @@
copyToClipboard(share_url);
$("#share-copy").text("Copied!");
})
}
};
if (config["examples"]) {
$("#examples").removeClass("invisible");
let html = "<thead>"
for (let i = 0; i < config["input_interfaces"].length; i++) {
label = config["input_interfaces"][i][1]["label"];
html += "<th>" + label + "</th>";
}
html += "</thead>";
html += "<tbody>";
for (let example of config["examples"]) {
html += "<tr>";
for (let col of example) {
html += "<td>" + col + "</td>";
}
html += "</tr>";
}
html += "</tbody>";
$("#examples table").html(html);
};
});
const copyToClipboard = str => {
const el = document.createElement('textarea');