mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
added similarity code
This commit is contained in:
parent
fe8ea3f5f9
commit
02ab39b3ad
@ -4,6 +4,7 @@ import tensorflow as tf
|
|||||||
import gradio
|
import gradio
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from urllib.request import urlretrieve
|
from urllib.request import urlretrieve
|
||||||
|
import os
|
||||||
|
|
||||||
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
|
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5", "mnist-model.h5")
|
||||||
model = tf.keras.models.load_model("mnist-model.h5")
|
model = tf.keras.models.load_model("mnist-model.h5")
|
||||||
@ -14,12 +15,15 @@ def recognize_digit(image):
|
|||||||
prediction = model.predict(image).tolist()[0]
|
prediction = model.predict(image).tolist()[0]
|
||||||
return {str(i): prediction[i] for i in range(10)}
|
return {str(i): prediction[i] for i in range(10)}
|
||||||
|
|
||||||
|
im = gradio.inputs.Image(shape=(28, 28), image_mode='L', invert_colors=True)
|
||||||
|
|
||||||
io = gr.Interface(
|
io = gr.Interface(
|
||||||
recognize_digit,
|
recognize_digit,
|
||||||
"sketchpad",
|
im,
|
||||||
gradio.outputs.Label(num_top_classes=3),
|
gradio.outputs.Label(num_top_classes=3),
|
||||||
live=True,
|
examples=[['digits/' + x] for x in os.listdir('digits/')],
|
||||||
|
# live=True,
|
||||||
|
interpretation="default",
|
||||||
capture_session=True,
|
capture_session=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
BIN
demo/digits/ex1.png
Normal file
BIN
demo/digits/ex1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.7 KiB |
BIN
demo/digits/ex2.png
Normal file
BIN
demo/digits/ex2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.6 KiB |
BIN
demo/digits/ex3.png
Normal file
BIN
demo/digits/ex3.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.7 KiB |
BIN
demo/digits/ex4.png
Normal file
BIN
demo/digits/ex4.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.6 KiB |
@ -1,7 +1,7 @@
|
|||||||
# Demo: (Image) -> (Label)
|
# Demo: (Image) -> (Label)
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import tensorflow as tf
|
# import tensorflow as tf
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -13,12 +13,13 @@ import json
|
|||||||
with open("files/imagenet_labels.json") as labels_file:
|
with open("files/imagenet_labels.json") as labels_file:
|
||||||
labels = json.load(labels_file)
|
labels = json.load(labels_file)
|
||||||
|
|
||||||
mobile_net = tf.keras.applications.MobileNetV2()
|
# mobile_net = tf.keras.applications.MobileNetV2()
|
||||||
|
|
||||||
|
|
||||||
def image_classifier(im):
|
def image_classifier(im):
|
||||||
|
return 0
|
||||||
arr = np.expand_dims(im, axis=0)
|
arr = np.expand_dims(im, axis=0)
|
||||||
arr = tf.keras.applications.mobilenet.preprocess_input(arr)
|
# arr = tf.keras.applications.mobilenet.preprocess_input(arr)
|
||||||
prediction = mobile_net.predict(arr).flatten()
|
prediction = mobile_net.predict(arr).flatten()
|
||||||
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
return {labels[i]: float(prediction[i]) for i in range(1000)}
|
||||||
|
|
||||||
|
@ -640,6 +640,8 @@ class Image(InputComponent):
|
|||||||
output_scores = (output_scores - min_val) / (max_val - min_val)
|
output_scores = (output_scores - min_val) / (max_val - min_val)
|
||||||
return output_scores.tolist()
|
return output_scores.tolist()
|
||||||
|
|
||||||
|
def embed(self, x):
|
||||||
|
return x.flatten()
|
||||||
|
|
||||||
class Audio(InputComponent):
|
class Audio(InputComponent):
|
||||||
"""
|
"""
|
||||||
|
@ -15,6 +15,7 @@ import inspect
|
|||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
import analytics
|
import analytics
|
||||||
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
@ -42,6 +43,7 @@ class Interface:
|
|||||||
description=None, thumbnail=None, server_port=None,
|
description=None, thumbnail=None, server_port=None,
|
||||||
server_name=networking.LOCALHOST_NAME,
|
server_name=networking.LOCALHOST_NAME,
|
||||||
allow_screenshot=True, allow_flagging=True,
|
allow_screenshot=True, allow_flagging=True,
|
||||||
|
embedding_fn="default",
|
||||||
flagging_dir="flagged", analytics_enabled=True):
|
flagging_dir="flagged", analytics_enabled=True):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -241,6 +243,16 @@ class Interface:
|
|||||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||||
return processed_output, durations
|
return processed_output, durations
|
||||||
|
|
||||||
|
def embedding_fn(self, raw_input):
|
||||||
|
if self.interpretation == "default":
|
||||||
|
processed_input = [input_interface.preprocess(raw_input[i])
|
||||||
|
for i, input_interface in enumerate(self.input_interfaces)]
|
||||||
|
embedding = np.concatenate([input_interface.embed(processed_input[i])
|
||||||
|
for i, input_interface in enumerate(self.input_interfaces)])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only default embedding is currently supported")
|
||||||
|
return embedding
|
||||||
|
|
||||||
def interpret(self, raw_input):
|
def interpret(self, raw_input):
|
||||||
"""
|
"""
|
||||||
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
|
Runs the interpretation command for the machine learning model. Handles both the "default" out-of-the-box
|
||||||
|
@ -11,16 +11,17 @@ from flask_cors import CORS
|
|||||||
import threading
|
import threading
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from distutils import dir_util
|
from distutils import dir_util
|
||||||
import gradio as gr
|
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
from gradio.tunneling import create_tunnel
|
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
import requests
|
import requests
|
||||||
import sys
|
import sys
|
||||||
import csv
|
import csv
|
||||||
import logging
|
import logging
|
||||||
|
import gradio as gr
|
||||||
|
from gradio.similarity import calculate_similarity
|
||||||
|
from gradio.tunneling import create_tunnel
|
||||||
|
|
||||||
INITIAL_PORT_VALUE = int(os.getenv(
|
INITIAL_PORT_VALUE = int(os.getenv(
|
||||||
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||||
@ -120,6 +121,19 @@ def predict():
|
|||||||
return jsonify(output)
|
return jsonify(output)
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/api/score_similarity/", methods=["POST"])
|
||||||
|
def score_similarity():
|
||||||
|
raw_input = request.json["data"]
|
||||||
|
input_embedding = app.interface.embedding_fn(raw_input)
|
||||||
|
scores = list()
|
||||||
|
for example in app.interface.examples:
|
||||||
|
preprocessed_example = [iface.preprocess_example(example)
|
||||||
|
for iface, example in zip(app.interface.input_interfaces, example)]
|
||||||
|
example_embedding = app.interface.embedding_fn(preprocessed_example)
|
||||||
|
scores.append(calculate_similarity(input_embedding, example_embedding))
|
||||||
|
return jsonify({"data": scores})
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/predict_examples/", methods=["POST"])
|
@app.route("/api/predict_examples/", methods=["POST"])
|
||||||
def predict_examples():
|
def predict_examples():
|
||||||
example_ids = request.json["data"]
|
example_ids = request.json["data"]
|
||||||
|
8
gradio/similarity.py
Normal file
8
gradio/similarity.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def calculate_similarity(embedding1, embedding2):
|
||||||
|
"""
|
||||||
|
Scores the similarity between two embeddings by taking the L2 distance
|
||||||
|
"""
|
||||||
|
return np.linalg.norm(np.array(embedding1) - np.array(embedding2))
|
||||||
|
|
@ -36,8 +36,33 @@ var io_master_template = {
|
|||||||
this.target.find(".loading_in_progress").hide();
|
this.target.find(".loading_in_progress").hide();
|
||||||
this.target.find(".loading_failed").show();
|
this.target.find(".loading_failed").show();
|
||||||
});
|
});
|
||||||
|
},
|
||||||
|
score_similarity: function() {
|
||||||
|
this.target.find(".loading").removeClass("invisible");
|
||||||
|
this.target.find(".loading_in_progress").show();
|
||||||
|
this.target.find(".loading_failed").hide();
|
||||||
|
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||||
|
|
||||||
|
this.fn(this.last_input, "score_similarity").then((output) => {
|
||||||
|
console.log(output.data)
|
||||||
|
this.target.find(".loading").addClass("invisible");
|
||||||
|
this.target.find(".output_interfaces").css("opacity", 1);
|
||||||
|
let html = "<th>DIFFS</th>"
|
||||||
|
this.target.find(".examples > table > thead > tr").append(html);
|
||||||
|
for (let i = 0; i < output["data"].length; i++) {
|
||||||
|
let html = "<td>" + output["data"][i] + "</td>"
|
||||||
|
this.target.find(".examples_body tr[row='" + i + "']").append(html);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
},
|
},
|
||||||
submit_examples: function() {
|
submit_examples: function() {
|
||||||
|
this.target.find(".loading").removeClass("invisible");
|
||||||
|
this.target.find(".loading_in_progress").show();
|
||||||
|
this.target.find(".loading_failed").hide();
|
||||||
|
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||||
|
|
||||||
let example_ids = [];
|
let example_ids = [];
|
||||||
if (this.loaded_examples == null) {
|
if (this.loaded_examples == null) {
|
||||||
this.loaded_examples = {};
|
this.loaded_examples = {};
|
||||||
@ -48,6 +73,9 @@ var io_master_template = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.fn(example_ids, "predict_examples").then((output) => {
|
this.fn(example_ids, "predict_examples").then((output) => {
|
||||||
|
this.target.find(".loading").addClass("invisible");
|
||||||
|
this.target.find(".output_interfaces").css("opacity", 1);
|
||||||
|
|
||||||
output = output["data"];
|
output = output["data"];
|
||||||
if (!this.has_loaded_examples) {
|
if (!this.has_loaded_examples) {
|
||||||
this.has_loaded_examples = true;
|
this.has_loaded_examples = true;
|
||||||
@ -147,3 +175,5 @@ var io_master_template = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -320,7 +320,8 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
target.find(".interpret").click(function() {
|
target.find(".interpret").click(function() {
|
||||||
target.find(".interpretation_explained").removeClass("invisible");
|
target.find(".interpretation_explained").removeClass("invisible");
|
||||||
if (io_master.last_output) {
|
if (io_master.last_output) {
|
||||||
io_master.interpret();
|
io_master.score_similarity();
|
||||||
|
// io_master.interpret(); // TODO(UNDO)
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
target.find(".run_examples").click(function() {
|
target.find(".run_examples").click(function() {
|
||||||
|
Loading…
Reference in New Issue
Block a user