added similarity code

This commit is contained in:
Abubakar Abid 2020-11-11 11:57:39 -06:00
parent fe8ea3f5f9
commit 02ab39b3ad
12 changed files with 81 additions and 9 deletions

View File

@ -4,6 +4,7 @@ import tensorflow as tf
import gradio
import gradio as gr
from urllib.request import urlretrieve
import os
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
model = tf.keras.models.load_model("mnist-model.h5")
@ -14,12 +15,15 @@ def recognize_digit(image):
prediction = model.predict(image).tolist()[0]
return {str(i): prediction[i] for i in range(10)}
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=True)
io = gr.Interface(
recognize_digit,
"sketchpad",
im,
gradio.outputs.Label(num_top_classes=3),
live=True,
examples=[['digits/' + x] for x in os.listdir('digits/')],
# live=True,
interpretation="default",
capture_session=True,
)

BIN
demo/digits/ex1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

BIN
demo/digits/ex2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 KiB

BIN
demo/digits/ex3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

BIN
demo/digits/ex4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -1,7 +1,7 @@
# Demo: (Image) -> (Label)
import gradio as gr
import tensorflow as tf
# import tensorflow as tf
import numpy as np
from PIL import Image
@ -13,12 +13,13 @@ import json
with open("files/imagenet_labels.json") as labels_file:
labels = json.load(labels_file)
mobile_net = tf.keras.applications.MobileNetV2()
# mobile_net = tf.keras.applications.MobileNetV2()
def image_classifier(im):
return 0
arr = np.expand_dims(im, axis=0)
arr = tf.keras.applications.mobilenet.preprocess_input(arr)
# arr = tf.keras.applications.mobilenet.preprocess_input(arr)
prediction = mobile_net.predict(arr).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}

View File

@ -640,6 +640,8 @@ class Image(InputComponent):
output_scores = (output_scores - min_val) / (max_val - min_val)
return output_scores.tolist()
def embed(self, x):
return x.flatten()
class Audio(InputComponent):
"""

View File

@ -15,6 +15,7 @@ import inspect
import sys
import weakref
import analytics
import numpy as np
import os
import copy
@ -42,6 +43,7 @@ class Interface:
description=None, thumbnail=None, server_port=None,
server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
embedding_fn="default",
flagging_dir="flagged", analytics_enabled=True):
"""
@ -241,6 +243,16 @@ class Interface:
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations
def embedding_fn(self, raw_input):
if self.interpretation == "default":
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
embedding = np.concatenate([input_interface.embed(processed_input[i])
for i, input_interface in enumerate(self.input_interfaces)])
else:
raise NotImplementedError("Only default embedding is currently supported")
return embedding
def interpret(self, raw_input):
"""
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box

View File

@ -11,16 +11,17 @@ from flask_cors import CORS
import threading
import pkg_resources
from distutils import dir_util
import gradio as gr
import time
import json
from gradio.tunneling import create_tunnel
import urllib.request
from shutil import copyfile
import requests
import sys
import csv
import logging
import gradio as gr
from gradio.similarity import calculate_similarity
from gradio.tunneling import create_tunnel
INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -120,6 +121,19 @@ def predict():
return jsonify(output)
@app.route("/api/score_similarity/", methods=["POST"])
def score_similarity():
raw_input = request.json["data"]
input_embedding = app.interface.embedding_fn(raw_input)
scores = list()
for example in app.interface.examples:
preprocessed_example = [iface.preprocess_example(example)
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embedding_fn(preprocessed_example)
scores.append(calculate_similarity(input_embedding, example_embedding))
return jsonify({"data": scores})
@app.route("/api/predict_examples/", methods=["POST"])
def predict_examples():
example_ids = request.json["data"]

8
gradio/similarity.py Normal file
View File

@ -0,0 +1,8 @@
import numpy as np
def calculate_similarity(embedding1, embedding2):
"""
Scores the similarity between two embeddings by taking the L2 distance
"""
return np.linalg.norm(np.array(embedding1) - np.array(embedding2))

View File

@ -36,8 +36,33 @@ var io_master_template = {
this.target.find(".loading_in_progress").hide();
this.target.find(".loading_failed").show();
});
},
score_similarity: function() {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
this.fn(this.last_input, "score_similarity").then((output) => {
console.log(output.data)
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
let html = "<th>DIFFS</th>"
this.target.find(".examples > table > thead > tr").append(html);
for (let i = 0; i < output["data"].length; i++) {
let html = "<td>" + output["data"][i] + "</td>"
this.target.find(".examples_body tr[row='" + i + "']").append(html);
}
})
},
submit_examples: function() {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
let example_ids = [];
if (this.loaded_examples == null) {
this.loaded_examples = {};
@ -48,6 +73,9 @@ var io_master_template = {
}
}
this.fn(example_ids, "predict_examples").then((output) => {
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
output = output["data"];
if (!this.has_loaded_examples) {
this.has_loaded_examples = true;
@ -147,3 +175,5 @@ var io_master_template = {
}
}
};

View File

@ -320,7 +320,8 @@ function gradio(config, fn, target, example_file_path) {
target.find(".interpret").click(function() {
target.find(".interpretation_explained").removeClass("invisible");
if (io_master.last_output) {
io_master.interpret();
io_master.score_similarity();
// io_master.interpret(); // TODO(UNDO)
}
});
target.find(".run_examples").click(function() {