mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
added similarity code
This commit is contained in:
parent
fe8ea3f5f9
commit
02ab39b3ad
@ -4,6 +4,7 @@ import tensorflow as tf
|
||||
import gradio
|
||||
import gradio as gr
|
||||
from urllib.request import urlretrieve
|
||||
import os
|
||||
|
||||
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
|
||||
model = tf.keras.models.load_model("mnist-model.h5")
|
||||
@ -14,12 +15,15 @@ def recognize_digit(image):
|
||||
prediction = model.predict(image).tolist()[0]
|
||||
return {str(i): prediction[i] for i in range(10)}
|
||||
|
||||
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=True)
|
||||
|
||||
io = gr.Interface(
|
||||
recognize_digit,
|
||||
"sketchpad",
|
||||
im,
|
||||
gradio.outputs.Label(num_top_classes=3),
|
||||
live=True,
|
||||
examples=[['digits/' + x] for x in os.listdir('digits/')],
|
||||
# live=True,
|
||||
interpretation="default",
|
||||
capture_session=True,
|
||||
)
|
||||
|
||||
|
BIN
demo/digits/ex1.png
Normal file
BIN
demo/digits/ex1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.7 KiB |
BIN
demo/digits/ex2.png
Normal file
BIN
demo/digits/ex2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.6 KiB |
BIN
demo/digits/ex3.png
Normal file
BIN
demo/digits/ex3.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.7 KiB |
BIN
demo/digits/ex4.png
Normal file
BIN
demo/digits/ex4.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.6 KiB |
@ -1,7 +1,7 @@
|
||||
# Demo: (Image) -> (Label)
|
||||
|
||||
import gradio as gr
|
||||
import tensorflow as tf
|
||||
# import tensorflow as tf
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@ -13,12 +13,13 @@ import json
|
||||
with open("files/imagenet_labels.json") as labels_file:
|
||||
labels = json.load(labels_file)
|
||||
|
||||
mobile_net = tf.keras.applications.MobileNetV2()
|
||||
# mobile_net = tf.keras.applications.MobileNetV2()
|
||||
|
||||
|
||||
def image_classifier(im):
|
||||
return 0
|
||||
arr = np.expand_dims(im, axis=0)
|
||||
arr = tf.keras.applications.mobilenet.preprocess_input(arr)
|
||||
# arr = tf.keras.applications.mobilenet.preprocess_input(arr)
|
||||
prediction = mobile_net.predict(arr).flatten()
|
||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||
|
||||
|
@ -640,6 +640,8 @@ class Image(InputComponent):
|
||||
output_scores = (output_scores - min_val) / (max_val - min_val)
|
||||
return output_scores.tolist()
|
||||
|
||||
def embed(self, x):
|
||||
return x.flatten()
|
||||
|
||||
class Audio(InputComponent):
|
||||
"""
|
||||
|
@ -15,6 +15,7 @@ import inspect
|
||||
import sys
|
||||
import weakref
|
||||
import analytics
|
||||
import numpy as np
|
||||
import os
|
||||
import copy
|
||||
|
||||
@ -42,6 +43,7 @@ class Interface:
|
||||
description=None, thumbnail=None, server_port=None,
|
||||
server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
embedding_fn="default",
|
||||
flagging_dir="flagged", analytics_enabled=True):
|
||||
|
||||
"""
|
||||
@ -241,6 +243,16 @@ class Interface:
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output, durations
|
||||
|
||||
def embedding_fn(self, raw_input):
|
||||
if self.interpretation == "default":
|
||||
processed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)]
|
||||
embedding = np.concatenate([input_interface.embed(processed_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)])
|
||||
else:
|
||||
raise NotImplementedError("Only default embedding is currently supported")
|
||||
return embedding
|
||||
|
||||
def interpret(self, raw_input):
|
||||
"""
|
||||
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
|
||||
|
@ -11,16 +11,17 @@ from flask_cors import CORS
|
||||
import threading
|
||||
import pkg_resources
|
||||
from distutils import dir_util
|
||||
import gradio as gr
|
||||
import time
|
||||
import json
|
||||
from gradio.tunneling import create_tunnel
|
||||
import urllib.request
|
||||
from shutil import copyfile
|
||||
import requests
|
||||
import sys
|
||||
import csv
|
||||
import logging
|
||||
import gradio as gr
|
||||
from gradio.similarity import calculate_similarity
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
@ -120,6 +121,19 @@ def predict():
|
||||
return jsonify(output)
|
||||
|
||||
|
||||
@app.route("/api/score_similarity/", methods=["POST"])
|
||||
def score_similarity():
|
||||
raw_input = request.json["data"]
|
||||
input_embedding = app.interface.embedding_fn(raw_input)
|
||||
scores = list()
|
||||
for example in app.interface.examples:
|
||||
preprocessed_example = [iface.preprocess_example(example)
|
||||
for iface, example in zip(app.interface.input_interfaces, example)]
|
||||
example_embedding = app.interface.embedding_fn(preprocessed_example)
|
||||
scores.append(calculate_similarity(input_embedding, example_embedding))
|
||||
return jsonify({"data": scores})
|
||||
|
||||
|
||||
@app.route("/api/predict_examples/", methods=["POST"])
|
||||
def predict_examples():
|
||||
example_ids = request.json["data"]
|
||||
|
8
gradio/similarity.py
Normal file
8
gradio/similarity.py
Normal file
@ -0,0 +1,8 @@
|
||||
import numpy as np
|
||||
|
||||
def calculate_similarity(embedding1, embedding2):
|
||||
"""
|
||||
Scores the similarity between two embeddings by taking the L2 distance
|
||||
"""
|
||||
return np.linalg.norm(np.array(embedding1) - np.array(embedding2))
|
||||
|
@ -36,8 +36,33 @@ var io_master_template = {
|
||||
this.target.find(".loading_in_progress").hide();
|
||||
this.target.find(".loading_failed").show();
|
||||
});
|
||||
},
|
||||
score_similarity: function() {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
this.target.find(".loading_failed").hide();
|
||||
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||
|
||||
this.fn(this.last_input, "score_similarity").then((output) => {
|
||||
console.log(output.data)
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
let html = "<th>DIFFS</th>"
|
||||
this.target.find(".examples > table > thead > tr").append(html);
|
||||
for (let i = 0; i < output["data"].length; i++) {
|
||||
let html = "<td>" + output["data"][i] + "</td>"
|
||||
this.target.find(".examples_body tr[row='" + i + "']").append(html);
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
},
|
||||
submit_examples: function() {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
this.target.find(".loading_failed").hide();
|
||||
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||
|
||||
let example_ids = [];
|
||||
if (this.loaded_examples == null) {
|
||||
this.loaded_examples = {};
|
||||
@ -48,6 +73,9 @@ var io_master_template = {
|
||||
}
|
||||
}
|
||||
this.fn(example_ids, "predict_examples").then((output) => {
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
|
||||
output = output["data"];
|
||||
if (!this.has_loaded_examples) {
|
||||
this.has_loaded_examples = true;
|
||||
@ -147,3 +175,5 @@ var io_master_template = {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
@ -320,7 +320,8 @@ function gradio(config, fn, target, example_file_path) {
|
||||
target.find(".interpret").click(function() {
|
||||
target.find(".interpretation_explained").removeClass("invisible");
|
||||
if (io_master.last_output) {
|
||||
io_master.interpret();
|
||||
io_master.score_similarity();
|
||||
// io_master.interpret(); // TODO(UNDO)
|
||||
}
|
||||
});
|
||||
target.find(".run_examples").click(function() {
|
||||
|
Loading…
Reference in New Issue
Block a user