added similarity code

This commit is contained in:
Abubakar Abid 2020-11-11 11:57:39 -06:00
parent fe8ea3f5f9
commit 02ab39b3ad
12 changed files with 81 additions and 9 deletions

View File

@ -4,6 +4,7 @@ import tensorflow as tf
import gradio import gradio
import gradio as gr import gradio as gr
from urllib.request import urlretrieve from urllib.request import urlretrieve
import os
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5") urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
model = tf.keras.models.load_model("mnist-model.h5") model = tf.keras.models.load_model("mnist-model.h5")
@ -14,12 +15,15 @@ def recognize_digit(image):
prediction = model.predict(image).tolist()[0] prediction = model.predict(image).tolist()[0]
return {str(i): prediction[i] for i in range(10)} return {str(i): prediction[i] for i in range(10)}
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=True)
io = gr.Interface( io = gr.Interface(
recognize_digit, recognize_digit,
"sketchpad", im,
gradio.outputs.Label(num_top_classes=3), gradio.outputs.Label(num_top_classes=3),
live=True, examples=[['digits/' + x] for x in os.listdir('digits/')],
# live=True,
interpretation="default",
capture_session=True, capture_session=True,
) )

BIN
demo/digits/ex1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

BIN
demo/digits/ex2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 KiB

BIN
demo/digits/ex3.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

BIN
demo/digits/ex4.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

View File

@ -1,7 +1,7 @@
# Demo: (Image) -> (Label) # Demo: (Image) -> (Label)
import gradio as gr import gradio as gr
import tensorflow as tf # import tensorflow as tf
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -13,12 +13,13 @@ import json
with open("files/imagenet_labels.json") as labels_file: with open("files/imagenet_labels.json") as labels_file:
labels = json.load(labels_file) labels = json.load(labels_file)
mobile_net = tf.keras.applications.MobileNetV2() # mobile_net = tf.keras.applications.MobileNetV2()
def image_classifier(im): def image_classifier(im):
return 0
arr = np.expand_dims(im, axis=0) arr = np.expand_dims(im, axis=0)
arr = tf.keras.applications.mobilenet.preprocess_input(arr) # arr = tf.keras.applications.mobilenet.preprocess_input(arr)
prediction = mobile_net.predict(arr).flatten() prediction = mobile_net.predict(arr).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)} return {labels[i]: float(prediction[i]) for i in range(1000)}

View File

@ -640,6 +640,8 @@ class Image(InputComponent):
output_scores = (output_scores - min_val) / (max_val - min_val) output_scores = (output_scores - min_val) / (max_val - min_val)
return output_scores.tolist() return output_scores.tolist()
def embed(self, x):
return x.flatten()
class Audio(InputComponent): class Audio(InputComponent):
""" """

View File

@ -15,6 +15,7 @@ import inspect
import sys import sys
import weakref import weakref
import analytics import analytics
import numpy as np
import os import os
import copy import copy
@ -42,6 +43,7 @@ class Interface:
description=None, thumbnail=None, server_port=None, description=None, thumbnail=None, server_port=None,
server_name=networking.LOCALHOST_NAME, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True, allow_screenshot=True, allow_flagging=True,
embedding_fn="default",
flagging_dir="flagged", analytics_enabled=True): flagging_dir="flagged", analytics_enabled=True):
""" """
@ -241,6 +243,16 @@ class Interface:
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations return processed_output, durations
def embedding_fn(self, raw_input):
if self.interpretation == "default":
processed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(self.input_interfaces)]
embedding = np.concatenate([input_interface.embed(processed_input[i])
for i, input_interface in enumerate(self.input_interfaces)])
else:
raise NotImplementedError("Only default embedding is currently supported")
return embedding
def interpret(self, raw_input): def interpret(self, raw_input):
""" """
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box

View File

@ -11,16 +11,17 @@ from flask_cors import CORS
import threading import threading
import pkg_resources import pkg_resources
from distutils import dir_util from distutils import dir_util
import gradio as gr
import time import time
import json import json
from gradio.tunneling import create_tunnel
import urllib.request import urllib.request
from shutil import copyfile from shutil import copyfile
import requests import requests
import sys import sys
import csv import csv
import logging import logging
import gradio as gr
from gradio.similarity import calculate_similarity
from gradio.tunneling import create_tunnel
INITIAL_PORT_VALUE = int(os.getenv( INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc. 'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
@ -120,6 +121,19 @@ def predict():
return jsonify(output) return jsonify(output)
@app.route("/api/score_similarity/", methods=["POST"])
def score_similarity():
raw_input = request.json["data"]
input_embedding = app.interface.embedding_fn(raw_input)
scores = list()
for example in app.interface.examples:
preprocessed_example = [iface.preprocess_example(example)
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embedding_fn(preprocessed_example)
scores.append(calculate_similarity(input_embedding, example_embedding))
return jsonify({"data": scores})
@app.route("/api/predict_examples/", methods=["POST"]) @app.route("/api/predict_examples/", methods=["POST"])
def predict_examples(): def predict_examples():
example_ids = request.json["data"] example_ids = request.json["data"]

8
gradio/similarity.py Normal file
View File

@ -0,0 +1,8 @@
import numpy as np
def calculate_similarity(embedding1, embedding2):
"""
Scores the similarity between two embeddings by taking the L2 distance
"""
return np.linalg.norm(np.array(embedding1) - np.array(embedding2))

View File

@ -36,8 +36,33 @@ var io_master_template = {
this.target.find(".loading_in_progress").hide(); this.target.find(".loading_in_progress").hide();
this.target.find(".loading_failed").show(); this.target.find(".loading_failed").show();
}); });
},
score_similarity: function() {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
this.fn(this.last_input, "score_similarity").then((output) => {
console.log(output.data)
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
let html = "<th>DIFFS</th>"
this.target.find(".examples > table > thead > tr").append(html);
for (let i = 0; i < output["data"].length; i++) {
let html = "<td>" + output["data"][i] + "</td>"
this.target.find(".examples_body tr[row='" + i + "']").append(html);
}
})
}, },
submit_examples: function() { submit_examples: function() {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
let example_ids = []; let example_ids = [];
if (this.loaded_examples == null) { if (this.loaded_examples == null) {
this.loaded_examples = {}; this.loaded_examples = {};
@ -48,6 +73,9 @@ var io_master_template = {
} }
} }
this.fn(example_ids, "predict_examples").then((output) => { this.fn(example_ids, "predict_examples").then((output) => {
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
output = output["data"]; output = output["data"];
if (!this.has_loaded_examples) { if (!this.has_loaded_examples) {
this.has_loaded_examples = true; this.has_loaded_examples = true;
@ -147,3 +175,5 @@ var io_master_template = {
} }
} }
}; };

View File

@ -320,7 +320,8 @@ function gradio(config, fn, target, example_file_path) {
target.find(".interpret").click(function() { target.find(".interpret").click(function() {
target.find(".interpretation_explained").removeClass("invisible"); target.find(".interpretation_explained").removeClass("invisible");
if (io_master.last_output) { if (io_master.last_output) {
io_master.interpret(); io_master.score_similarity();
// io_master.interpret(); // TODO(UNDO)
} }
}); });
target.find(".run_examples").click(function() { target.find(".run_examples").click(function() {