fixed bugs & added custom embedding fn

This commit is contained in:
Abubakar Abid 2020-11-12 06:10:51 -06:00
parent 0eab6a27a4
commit efe292cda9
10 changed files with 18 additions and 14 deletions

View File

@ -15,7 +15,7 @@ def recognize_digit(image):
prediction = model.predict(image).tolist()[0]
return {str(i): prediction[i] for i in range(10)}
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=True)
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=False)
io = gr.Interface(
recognize_digit,

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 KiB

File diff suppressed because one or more lines are too long

View File

@ -123,6 +123,7 @@ class Interface:
self.analytics_enabled=analytics_enabled
self.save_to = None
self.share = None
self.embedding_fn = embedding_fn
data = {'fn': fn,
'inputs': inputs,
@ -243,14 +244,12 @@ class Interface:
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations
def embedding_fn(self, raw_input):
if self.interpretation == "default":
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
def embed(self, processed_input):
if self.embedding_fn == "default":
embedding = np.concatenate([input_interface.embed(processed_input[i])
for i, input_interface in enumerate(self.input_interfaces)])
else:
raise NotImplementedError("Only default embedding is currently supported")
embedding = self.embedding_fn(processed_input)
return embedding
def interpret(self, raw_input):

View File

@ -124,12 +124,14 @@ def predict():
@app.route("/api/score_similarity/", methods=["POST"])
def score_similarity():
raw_input = request.json["data"]
input_embedding = app.interface.embedding_fn(raw_input)
preprocessed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
input_embedding = app.interface.embed(preprocessed_input)
scores = list()
for example in app.interface.examples:
preprocessed_example = [iface.preprocess_example(example)
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embedding_fn(preprocessed_example)
example_embedding = app.interface.embed(preprocessed_example)
scores.append(calculate_similarity(input_embedding, example_embedding))
return jsonify({"data": scores})

View File

@ -47,7 +47,7 @@ var io_master_template = {
console.log(output.data)
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
let html = "<th>DIFFS</th>"
let html = "<th>SIMILARITY</th>"
this.target.find(".examples > table > thead > tr").append(html);
for (let i = 0; i < output["data"].length; i++) {
let html = "<td>" + output["data"][i] + "</td>"

View File

@ -318,7 +318,7 @@ function gradio(config, fn, target, example_file_path) {
}
})
target.find(".interpret").click(function() {
target.find(".interpretation_explained").removeClass("invisible");
// target.find(".interpretation_explained").removeClass("invisible");
if (io_master.last_output) {
io_master.score_similarity();
// io_master.interpret(); // TODO(UNDO)