mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
added embedding plots
This commit is contained in:
parent
5501e5f579
commit
fd0c4d34c3
@ -52,6 +52,7 @@ function gradio(config, fn, target, example_file_path) {
|
||||
<button class="load_prev">Load Previous <em>(CTRL + ←)</em></button>
|
||||
<button class="load_next">Load Next <em>(CTRL + →)</em></button>
|
||||
<button class="order_similar">Order by Similarity</button>
|
||||
<button class="view_embeddings">Plot Embeddings</button>
|
||||
<div class="pages invisible">Page:</div>
|
||||
<table>
|
||||
</table>
|
||||
@ -338,7 +339,7 @@ function gradio(config, fn, target, example_file_path) {
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
});
|
||||
|
||||
target.find(".screenshot").click(function() {
|
||||
$(".screenshot, .record").hide();
|
||||
|
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
SMALL_CONST = 1e-10
|
||||
|
||||
@ -19,3 +20,18 @@ def calculate_similarity(embedding1, embedding2):
|
||||
cosine_similarity = np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2) + SMALL_CONST)
|
||||
return cosine_similarity
|
||||
|
||||
def fit_pca_to_embeddings(embeddings):
|
||||
"""
|
||||
Computes 2D tsne embeddings from a list of higher-dimensional embeddings
|
||||
"""
|
||||
pca_model = PCA(n_components=2, random_state=0)
|
||||
embeddings = np.array(embeddings)
|
||||
embeddings_2D = pca_model.fit_transform(embeddings)
|
||||
return pca_model, [{'x': e[0], 'y': e[1]} for e in embeddings_2D.tolist()]
|
||||
|
||||
def transform_with_pca(pca_model, embeddings):
|
||||
"""
|
||||
Computes 2D tsne embeddings from a list of higher-dimensional embeddings
|
||||
"""
|
||||
embeddings_2D = pca_model.transform(embeddings)
|
||||
return [{'x': e[0], 'y': e[1]} for e in embeddings_2D.tolist()]
|
||||
|
@ -453,7 +453,7 @@ class CheckboxGroup(InputComponent):
|
||||
if self.type == "value":
|
||||
return [choice in x for choice in self.choices]
|
||||
elif self.type == "index":
|
||||
return [index in x for index in range(len(choices))]
|
||||
return [index in x for index in range(len(self.choices))]
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
@ -514,7 +514,7 @@ class Radio(InputComponent):
|
||||
if self.type == "value":
|
||||
return [choice==x for choice in self.choices]
|
||||
elif self.type == "index":
|
||||
return [index==x for index in range(len(choices))]
|
||||
return [index==x for index in range(len(self.choices))]
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
@ -574,7 +574,7 @@ class Dropdown(InputComponent):
|
||||
if self.type == "value":
|
||||
return [choice==x for choice in self.choices]
|
||||
elif self.type == "index":
|
||||
return [index==x for index in range(len(choices))]
|
||||
return [index==x for index in range(len(self.choices))]
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
@ -791,7 +791,23 @@ class Audio(InputComponent):
|
||||
return scores
|
||||
|
||||
def embed(self, x):
|
||||
raise NotImplementedError("Audio doesn't currently support embeddings")
|
||||
"""
|
||||
Resamples each audio signal to be 1,000 frames and then returns the flattened vectors
|
||||
"""
|
||||
num_frames = 1000
|
||||
if self.type == "file":
|
||||
filename = x.name
|
||||
mfcc = processing_utils.generate_mfcc_features_from_audio_file(filename, downsample_to=num_frames)
|
||||
return mfcc.flatten()
|
||||
elif self.type == "numpy":
|
||||
sample_rate, signal = x
|
||||
mfcc = processing_utils.generate_mfcc_features_from_audio_file(wav_filename=None, sample_rate=sample_rate, signal=signal, downsample_to=num_frames)
|
||||
return mfcc.flatten()
|
||||
elif self.type == "mfcc":
|
||||
mfcc = scipy.signal.resample(x, num_frames, axis=1)
|
||||
return mfcc.flatten()
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'numpy', 'mfcc', 'file'.")
|
||||
|
||||
|
||||
class File(InputComponent):
|
||||
|
@ -20,7 +20,7 @@ import sys
|
||||
import csv
|
||||
import logging
|
||||
import gradio as gr
|
||||
from gradio.embeddings import calculate_similarity
|
||||
from gradio.embeddings import calculate_similarity, fit_pca_to_embeddings, transform_with_pca
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
@ -124,18 +124,57 @@ def predict():
|
||||
@app.route("/api/score_similarity/", methods=["POST"])
|
||||
def score_similarity():
|
||||
raw_input = request.json["data"]
|
||||
|
||||
preprocessed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
input_embedding = app.interface.embed(preprocessed_input)
|
||||
scores = list()
|
||||
|
||||
for example in app.interface.examples:
|
||||
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
|
||||
for iface, example in zip(app.interface.input_interfaces, example)]
|
||||
example_embedding = app.interface.embed(preprocessed_example)
|
||||
scores.append(calculate_similarity(input_embedding, example_embedding))
|
||||
|
||||
return jsonify({"data": scores})
|
||||
|
||||
|
||||
@app.route("/api/view_embeddings/", methods=["POST"])
|
||||
def view_embeddings():
|
||||
sample_embedding = []
|
||||
if "data" in request.json:
|
||||
raw_input = request.json["data"]
|
||||
preprocessed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
sample_embedding.append(app.interface.embed(preprocessed_input))
|
||||
|
||||
example_embeddings = []
|
||||
for example in app.interface.examples:
|
||||
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
|
||||
for iface, example in zip(app.interface.input_interfaces, example)]
|
||||
example_embedding = app.interface.embed(preprocessed_example)
|
||||
example_embeddings.append(example_embedding)
|
||||
|
||||
pca_model, embeddings_2d = fit_pca_to_embeddings(sample_embedding + example_embeddings)
|
||||
sample_embedding_2d = embeddings_2d[:len(sample_embedding)]
|
||||
example_embeddings_2d = embeddings_2d[len(sample_embedding):]
|
||||
app.pca_model = pca_model
|
||||
return jsonify({"sample_embedding_2d": sample_embedding_2d, "example_embeddings_2d": example_embeddings_2d})
|
||||
|
||||
|
||||
@app.route("/api/update_embeddings/", methods=["POST"])
|
||||
def update_embeddings():
|
||||
sample_embedding, sample_embedding_2d = [], []
|
||||
if "data" in request.json:
|
||||
raw_input = request.json["data"]
|
||||
preprocessed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
sample_embedding.append(app.interface.embed(preprocessed_input))
|
||||
sample_embedding_2d = transform_with_pca(app.pca_model, sample_embedding)
|
||||
|
||||
return jsonify({"sample_embedding_2d": sample_embedding_2d})
|
||||
|
||||
|
||||
@app.route("/api/predict_examples/", methods=["POST"])
|
||||
def predict_examples():
|
||||
example_ids = request.json["data"]
|
||||
|
@ -83,16 +83,19 @@ def decode_base64_to_file(encoding):
|
||||
# AUDIO FILES
|
||||
##################
|
||||
|
||||
def generate_mfcc_features_from_audio_file(wav_filename,
|
||||
def generate_mfcc_features_from_audio_file(wav_filename=None,
|
||||
pre_emphasis=0.95,
|
||||
frame_size= 0.025,
|
||||
frame_stride=0.01,
|
||||
NFFT=512,
|
||||
nfilt=40,
|
||||
num_ceps=12,
|
||||
cep_lifter=22):
|
||||
cep_lifter=22,
|
||||
sample_rate=None,
|
||||
signal=None,
|
||||
downsample_to=None):
|
||||
"""
|
||||
Loads and preprocesses a .wav audio file into mfcc coefficients, the typical inputs to models.
|
||||
Loads and preprocesses a .wav audio file (or alternatively, a sample rate & signal) into mfcc coefficients, the typical inputs to models.
|
||||
Adapted from: https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
|
||||
:param wav_filename: string name of audio file to process.
|
||||
:param pre_emphasis: a float factor, typically 0.95 or 0.97, which amplifies high frequencies.
|
||||
@ -102,9 +105,21 @@ def generate_mfcc_features_from_audio_file(wav_filename,
|
||||
:param nfilt: The number of filters on the Mel-scale to extract frequency bands.
|
||||
:param num_ceps: the number of cepstral coefficients to retrain.
|
||||
:param cep_lifter: the int factor, by which to de-emphasize higher-frequency.
|
||||
:return: a numpy array of mfcc coefficients.
|
||||
:param sample_rate: optional param represnting sample rate that is used if `wav_filename` is not provided
|
||||
:param signal: optional param representing sample data that is used if `wav_filename` is not provided
|
||||
:param downsample_to: optional param. If provided, audio file is downsampled to this many frames.
|
||||
:return: a 3D numpy array of mfcc coefficients, of the shape 1 x num_frames x num_coeffs.
|
||||
"""
|
||||
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
|
||||
if (wav_filename is None) and (sample_rate is None or signal is None):
|
||||
raise ValueError("Either a wav_filename must be provdied or a sample_rate and signal")
|
||||
elif wav_filename is None:
|
||||
pass
|
||||
else:
|
||||
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
|
||||
|
||||
if not(downsample_to is None):
|
||||
signal = scipy.signal.resample(signal, downsample_to)
|
||||
|
||||
emphasized_signal = np.append(signal[0], signal[1:] - pre_emphasis * signal[:-1])
|
||||
|
||||
frame_length, frame_step = frame_size * sample_rate, frame_stride * sample_rate # Convert from seconds to samples
|
||||
|
@ -52,6 +52,30 @@ var io_master_template = {
|
||||
callback();
|
||||
})
|
||||
},
|
||||
view_embeddings: function(callback) {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
this.target.find(".loading_failed").hide();
|
||||
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||
|
||||
this.fn(this.last_input, "view_embeddings").then((output) => {
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
callback(output)
|
||||
})
|
||||
},
|
||||
update_embeddings: function(callback) {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
this.target.find(".loading_failed").hide();
|
||||
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||
|
||||
this.fn(this.last_input, "update_embeddings").then((output) => {
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
callback(output)
|
||||
})
|
||||
},
|
||||
submit_examples: function(callback) {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
|
@ -52,9 +52,13 @@ function gradio(config, fn, target, example_file_path) {
|
||||
<button class="load_prev">Load Previous <em>(CTRL + ←)</em></button>
|
||||
<button class="load_next">Load Next <em>(CTRL + →)</em></button>
|
||||
<button class="order_similar">Order by Similarity</button>
|
||||
<button class="view_embeddings">View Embeddings</button>
|
||||
<button class="update_embeddings invisible">Update Embeddings</button>
|
||||
<button class="view_examples invisible">View Examples</button>
|
||||
<div class="pages invisible">Page:</div>
|
||||
<table>
|
||||
</table>
|
||||
<div class="plot invisible"><canvas id="canvas" width="400px" height="300px"></canvas></div>
|
||||
</div>`);
|
||||
let io_master = Object.create(io_master_template);
|
||||
io_master.fn = fn
|
||||
@ -93,6 +97,7 @@ function gradio(config, fn, target, example_file_path) {
|
||||
"dataframe" : dataframe_output,
|
||||
}
|
||||
let id_to_interface_map = {}
|
||||
let embedding_chart;
|
||||
|
||||
function set_interface_id(interface, id) {
|
||||
interface.id = id;
|
||||
@ -283,6 +288,63 @@ function gradio(config, fn, target, example_file_path) {
|
||||
}
|
||||
target.find(".examples > table > tbody").html(html);
|
||||
}
|
||||
function getBackgroundColors(){
|
||||
//Gets the background colors for the embedding plot
|
||||
console.log("io", io_master)
|
||||
// If labels aren't loaded, or it's not a label output interface:
|
||||
if (!io_master.loaded_examples || io_master["config"]["output_interfaces"][0][0]!="label") {
|
||||
return 'rgb(54, 162, 235)'
|
||||
}
|
||||
// If it is a label interface, get the labels
|
||||
let labels = []
|
||||
let isConfidencesPresent = false;
|
||||
for (let i=0; i<Object.keys(io_master.loaded_examples).length; i++) {
|
||||
console.log(io_master.loaded_examples[i])
|
||||
let label = io_master.loaded_examples[i][0]["label"];
|
||||
if ("confidences" in io_master.loaded_examples[i][0]){
|
||||
isConfidencesPresent = true;
|
||||
}
|
||||
labels.push(label);
|
||||
}
|
||||
const isNumeric = (currentValue) => !isNaN(currentValue);
|
||||
let isNumericArray = labels.every(isNumeric);
|
||||
// If they are all numbers, and there are no confidences, then it's a regression
|
||||
if (isNumericArray && !isConfidencesPresent) {
|
||||
let backgroundColors = [];
|
||||
labels = labels.map(Number);
|
||||
let max = Math.max(...labels);
|
||||
let min = Math.min(...labels);
|
||||
let rgb_max = [255, 178, 102]
|
||||
let rgb_min = [204, 255, 255]
|
||||
for (let i=0; i<labels.length; i++) {
|
||||
let frac = (Number(labels[i])-min)/(max-min)
|
||||
let color = [rgb_min[0]+frac*(rgb_max[0]-rgb_min[0]),
|
||||
rgb_min[1]+frac*(rgb_max[1]-rgb_min[1]),
|
||||
rgb_min[2]+frac*(rgb_max[2]-rgb_min[2])]
|
||||
backgroundColors.push(color);
|
||||
}
|
||||
}
|
||||
// Otherwise, it's a classification
|
||||
let colorArray = ['#FF6633', '#FFB399', '#FF33FF', '#00B3E6',
|
||||
'#E6B333', '#3366E6', '#999966', '#99FF99', '#B34D4D',
|
||||
'#80B300', '#809900', '#E6B3B3', '#6680B3', '#66991A',
|
||||
'#FF99E6', '#CCFF1A', '#FF1A66', '#E6331A', '#33FFCC',
|
||||
'#66994D', '#B366CC', '#4D8000', '#B33300', '#CC80CC',
|
||||
'#66664D', '#991AFF', '#E666FF', '#4DB3FF', '#1AB399',
|
||||
'#E666B3', '#33991A', '#CC9999', '#B3B31A', '#00E680',
|
||||
'#4D8066', '#809980', '#E6FF80', '#1AFF33', '#999933',
|
||||
'#FF3380', '#CCCC00', '#66E64D', '#4D80CC', '#9900B3',
|
||||
'#E64D66', '#4DB380', '#FF4D4D', '#99E6E6', '#6666FF'];
|
||||
let backgroundColors = [];
|
||||
let label_list = [];
|
||||
for (let i=0; i<labels.length; i++) {
|
||||
if (!(label_list.includes(labels[i]))){
|
||||
label_list.push(labels[i]);
|
||||
}
|
||||
backgroundColors.push(colorArray[label_list.indexOf(labels[i]) % colorArray.length]);
|
||||
}
|
||||
return backgroundColors
|
||||
}
|
||||
if (config["examples"]) {
|
||||
target.find(".examples").removeClass("invisible");
|
||||
let html = "<thead>"
|
||||
@ -323,6 +385,75 @@ function gradio(config, fn, target, example_file_path) {
|
||||
load_page();
|
||||
})
|
||||
});
|
||||
target.find(".view_examples").click(function() {
|
||||
target.find(".examples > table").removeClass("invisible");
|
||||
target.find(".examples > .plot").addClass("invisible");
|
||||
target.find(".run_examples").removeClass("invisible");
|
||||
target.find(".view_embeddings").removeClass("invisible");
|
||||
target.find(".load_prev").removeClass("invisible");
|
||||
target.find(".load_next").removeClass("invisible");
|
||||
target.find(".order_similar").removeClass("invisible");
|
||||
target.find(".pages").removeClass("invisible");
|
||||
target.find(".update_embeddings").addClass("invisible");
|
||||
target.find(".view_examples").addClass("invisible");
|
||||
});
|
||||
target.find(".update_embeddings").click(function() {
|
||||
io_master.update_embeddings(function(output) {
|
||||
embedding_chart.data.datasets[0].data.push(output["sample_embedding_2d"][0]);
|
||||
console.log(output["sample_embedding_2d"][0])
|
||||
embedding_chart.update();
|
||||
})
|
||||
});
|
||||
target.find(".view_embeddings").click(function() {
|
||||
io_master.view_embeddings(function(output) {
|
||||
let ctx = $('#canvas')[0].getContext('2d');
|
||||
let backgroundColors = getBackgroundColors();
|
||||
embedding_chart = new Chart(ctx, {
|
||||
type: 'scatter',
|
||||
data: {
|
||||
datasets: [{
|
||||
label: 'Sample Embedding',
|
||||
data: output["sample_embedding_2d"],
|
||||
backgroundColor: 'rgb(0, 0, 0)',
|
||||
borderColor: 'rgb(0, 0, 0)',
|
||||
pointRadius: 13,
|
||||
pointHoverRadius: 13,
|
||||
pointStyle: 'rectRot',
|
||||
showLine: true,
|
||||
fill: false,
|
||||
}, {
|
||||
label: 'Dataset Embeddings',
|
||||
data: output["example_embeddings_2d"],
|
||||
backgroundColor: backgroundColors,
|
||||
borderColor: backgroundColors,
|
||||
pointRadius: 7,
|
||||
pointHoverRadius: 7
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
legend: {display: false}
|
||||
}
|
||||
});
|
||||
$("#canvas")[0].onclick = function(evt){
|
||||
var activePoints = embedding_chart.getElementsAtEvent(evt);
|
||||
var firstPoint = activePoints[0];
|
||||
if (firstPoint._datasetIndex==1) { // if it's from the sample embeddings dataset
|
||||
load_example(firstPoint._index)
|
||||
}
|
||||
};
|
||||
|
||||
target.find(".examples > table").addClass("invisible");
|
||||
target.find(".examples > .plot").removeClass("invisible");
|
||||
target.find(".run_examples").addClass("invisible");
|
||||
target.find(".view_embeddings").addClass("invisible");
|
||||
target.find(".load_prev").addClass("invisible");
|
||||
target.find(".load_next").addClass("invisible");
|
||||
target.find(".order_similar").addClass("invisible");
|
||||
target.find(".pages").addClass("invisible");
|
||||
target.find(".update_embeddings").removeClass("invisible");
|
||||
target.find(".view_examples").removeClass("invisible");
|
||||
})
|
||||
});
|
||||
$("body").keydown(function(e) {
|
||||
if ($(document.activeElement).attr("type") == "text" || $(document.activeElement).attr("type") == "textarea") {
|
||||
return;
|
||||
@ -340,6 +471,7 @@ function gradio(config, fn, target, example_file_path) {
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
target.find(".screenshot").click(function() {
|
||||
$(".screenshot, .record").hide();
|
||||
$(".screenshot_logo").removeClass("invisible");
|
||||
|
7
gradio/static/js/vendor/Chart.min.js
vendored
Normal file
7
gradio/static/js/vendor/Chart.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -91,6 +91,7 @@
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/sketchpad.js"></script>
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/webcam.min.js"></script>
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/gifcap/gifencoder.js"></script>
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/Chart.min.js"></script>
|
||||
|
||||
<script src="{{ url_for('static', filename='js/utils.js') }}"></script>
|
||||
<script src="{{ url_for('static', filename='js/all_io.js') }}"></script>
|
||||
|
Loading…
Reference in New Issue
Block a user