added embedding plots

Abubakar Abid 2020-11-23 07:57:36 -06:00
parent 5501e5f579
commit fd0c4d34c3
10 changed files with 266 additions and 14 deletions

View File

@@ -52,6 +52,7 @@ function gradio(config, fn, target, example_file_path) {
<button class="load_prev">Load Previous <em>(CTRL + &larr;)</em></button>
<button class="load_next">Load Next <em>(CTRL + &rarr;)</em></button>
<button class="order_similar">Order by Similarity</button>
<button class="view_embeddings">Plot Embeddings</button>
<div class="pages invisible">Page:</div>
<table>
</table>
@@ -338,7 +339,7 @@ function gradio(config, fn, target, example_file_path) {
}
}
});
};
});
target.find(".screenshot").click(function() {
$(".screenshot, .record").hide();

View File

@@ -1,4 +1,5 @@
import numpy as np
from sklearn.decomposition import PCA
SMALL_CONST = 1e-10
@@ -18,4 +19,19 @@ def calculate_similarity(embedding1, embedding2):
e1, e2 = np.array(embedding1), np.array(embedding2)
cosine_similarity = np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2) + SMALL_CONST)
return cosine_similarity
def fit_pca_to_embeddings(embeddings):
"""
Fits a PCA model to a list of higher-dimensional embeddings and returns the model along with their 2D projections
"""
pca_model = PCA(n_components=2, random_state=0)
embeddings = np.array(embeddings)
embeddings_2D = pca_model.fit_transform(embeddings)
return pca_model, [{'x': e[0], 'y': e[1]} for e in embeddings_2D.tolist()]
def transform_with_pca(pca_model, embeddings):
"""
Projects a list of higher-dimensional embeddings down to 2D using an already-fitted PCA model
"""
embeddings_2D = pca_model.transform(embeddings)
return [{'x': e[0], 'y': e[1]} for e in embeddings_2D.tolist()]
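
Taken together with calculate_similarity above, these helpers let the server fit one projection on the dataset embeddings and reuse it for later samples. A minimal sketch of the intended call pattern, assuming the module is importable as gradio.embeddings as in this diff, and using random vectors as stand-ins for real model embeddings:

import numpy as np
from gradio.embeddings import fit_pca_to_embeddings, transform_with_pca

# Stand-ins for real model embeddings: 20 examples in a 128-dim space.
example_embeddings = np.random.rand(20, 128).tolist()

# Fit PCA once on the dataset examples and keep the model around.
pca_model, examples_2d = fit_pca_to_embeddings(example_embeddings)
print(examples_2d[0])  # e.g. {'x': 0.12, 'y': -0.34}

# Later, project a new sample into the same 2D space without refitting.
new_sample = np.random.rand(1, 128)
print(transform_with_pca(pca_model, new_sample))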

View File

@@ -453,7 +453,7 @@ class CheckboxGroup(InputComponent):
if self.type == "value":
return [choice in x for choice in self.choices]
elif self.type == "index":
return [index in x for index in range(len(choices))]
return [index in x for index in range(len(self.choices))]
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
@@ -514,7 +514,7 @@ class Radio(InputComponent):
if self.type == "value":
return [choice==x for choice in self.choices]
elif self.type == "index":
return [index==x for index in range(len(choices))]
return [index==x for index in range(len(self.choices))]
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
@@ -574,7 +574,7 @@ class Dropdown(InputComponent):
if self.type == "value":
return [choice==x for choice in self.choices]
elif self.type == "index":
return [index==x for index in range(len(choices))]
return [index==x for index in range(len(self.choices))]
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
@@ -791,7 +791,23 @@ class Audio(InputComponent):
return scores
def embed(self, x):
raise NotImplementedError("Audio doesn't currently support embeddings")
"""
Resamples the audio signal to 1,000 frames and returns the flattened MFCC feature vector
"""
num_frames = 1000
if self.type == "file":
filename = x.name
mfcc = processing_utils.generate_mfcc_features_from_audio_file(filename, downsample_to=num_frames)
return mfcc.flatten()
elif self.type == "numpy":
sample_rate, signal = x
mfcc = processing_utils.generate_mfcc_features_from_audio_file(wav_filename=None, sample_rate=sample_rate, signal=signal, downsample_to=num_frames)
return mfcc.flatten()
elif self.type == "mfcc":
mfcc = scipy.signal.resample(x, num_frames, axis=1)
return mfcc.flatten()
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'numpy', 'mfcc', 'file'.")
class File(InputComponent):

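For reference, a rough sketch of how the new Audio.embed path might be exercised with the "numpy" input type; the interface construction here is a plausible usage, not taken from this diff:

import numpy as np
import gradio as gr

# Hypothetical setup: an Audio input that receives (sample_rate, signal) tuples.
audio_input = gr.inputs.Audio(type="numpy")

# One second of a 440 Hz tone at 8 kHz as a stand-in for real audio.
sample_rate = 8000
signal = np.sin(2 * np.pi * 440 * np.arange(sample_rate) / sample_rate)

# Returns the flattened MFCC feature vector described in the docstring above.
embedding = audio_input.embed((sample_rate, signal))
print(embedding.shape)
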
View File

@@ -20,7 +20,7 @@ import sys
import csv
import logging
import gradio as gr
from gradio.embeddings import calculate_similarity
from gradio.embeddings import calculate_similarity, fit_pca_to_embeddings, transform_with_pca
from gradio.tunneling import create_tunnel
INITIAL_PORT_VALUE = int(os.getenv(
@@ -124,18 +124,57 @@ def predict():
@app.route("/api/score_similarity/", methods=["POST"])
def score_similarity():
raw_input = request.json["data"]
preprocessed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
input_embedding = app.interface.embed(preprocessed_input)
scores = list()
for example in app.interface.examples:
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embed(preprocessed_example)
scores.append(calculate_similarity(input_embedding, example_embedding))
return jsonify({"data": scores})
@app.route("/api/view_embeddings/", methods=["POST"])
def view_embeddings():
sample_embedding = []
if "data" in request.json:
raw_input = request.json["data"]
preprocessed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
sample_embedding.append(app.interface.embed(preprocessed_input))
example_embeddings = []
for example in app.interface.examples:
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embed(preprocessed_example)
example_embeddings.append(example_embedding)
pca_model, embeddings_2d = fit_pca_to_embeddings(sample_embedding + example_embeddings)
sample_embedding_2d = embeddings_2d[:len(sample_embedding)]
example_embeddings_2d = embeddings_2d[len(sample_embedding):]
app.pca_model = pca_model
return jsonify({"sample_embedding_2d": sample_embedding_2d, "example_embeddings_2d": example_embeddings_2d})
@app.route("/api/update_embeddings/", methods=["POST"])
def update_embeddings():
sample_embedding, sample_embedding_2d = [], []
if "data" in request.json:
raw_input = request.json["data"]
preprocessed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
sample_embedding.append(app.interface.embed(preprocessed_input))
sample_embedding_2d = transform_with_pca(app.pca_model, sample_embedding)
return jsonify({"sample_embedding_2d": sample_embedding_2d})
@app.route("/api/predict_examples/", methods=["POST"])
def predict_examples():
example_ids = request.json["data"]

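On the wire, the new endpoints behave like the other Gradio API routes. A hedged sketch of a client round trip, assuming a locally running interface on the default port with a single text input (the payload and response shapes follow the handlers above):

import requests

BASE = "http://localhost:7860"  # assumed local Gradio server

# Fit PCA on the examples (plus the current sample, if "data" is sent)
# and get back 2D coordinates for plotting.
resp = requests.post(BASE + "/api/view_embeddings/", json={"data": ["some input"]})
payload = resp.json()
print(payload["sample_embedding_2d"], len(payload["example_embeddings_2d"]))

# Later, project another sample with the PCA model fit above.
resp = requests.post(BASE + "/api/update_embeddings/", json={"data": ["another input"]})
print(resp.json()["sample_embedding_2d"])
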
View File

@@ -83,16 +83,19 @@ def decode_base64_to_file(encoding):
# AUDIO FILES
##################
def generate_mfcc_features_from_audio_file(wav_filename,
def generate_mfcc_features_from_audio_file(wav_filename=None,
pre_emphasis=0.95,
frame_size= 0.025,
frame_stride=0.01,
NFFT=512,
nfilt=40,
num_ceps=12,
cep_lifter=22):
cep_lifter=22,
sample_rate=None,
signal=None,
downsample_to=None):
"""
Loads and preprocesses a .wav audio file into mfcc coefficients, the typical inputs to models.
Loads and preprocesses a .wav audio file (or alternatively, a sample rate & signal) into mfcc coefficients, the typical inputs to models.
Adapted from: https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
:param wav_filename: string name of audio file to process.
:param pre_emphasis: a float factor, typically 0.95 or 0.97, which amplifies high frequencies.
@@ -102,9 +105,21 @@ def generate_mfcc_features_from_audio_file(wav_filename,
:param nfilt: The number of filters on the Mel-scale to extract frequency bands.
:param num_ceps: the number of cepstral coefficients to retain.
:param cep_lifter: an int factor by which to de-emphasize higher-frequency coefficients.
:return: a numpy array of mfcc coefficients.
:param sample_rate: optional param representing the sample rate, used if `wav_filename` is not provided
:param signal: optional param representing the signal data, used if `wav_filename` is not provided
:param downsample_to: optional param; if provided, the audio signal is downsampled to this many frames
:return: a 3D numpy array of mfcc coefficients, of the shape 1 x num_frames x num_coeffs.
"""
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
if (wav_filename is None) and (sample_rate is None or signal is None):
raise ValueError("Either a wav_filename must be provided, or both a sample_rate and signal")
elif wav_filename is not None:
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
if downsample_to is not None:
signal = scipy.signal.resample(signal, downsample_to)
emphasized_signal = np.append(signal[0], signal[1:] - pre_emphasis * signal[:-1])
frame_length, frame_step = frame_size * sample_rate, frame_stride * sample_rate # Convert from seconds to samples

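A small sketch of the refactored entry points: the same function now accepts either a filename or an in-memory (sample_rate, signal) pair, with optional downsampling. The synthetic tone below is just an illustration:

import numpy as np
from gradio import processing_utils

# One second of a 300 Hz tone at 16 kHz as a stand-in for real speech.
sample_rate = 16000
signal = np.sin(2 * np.pi * 300 * np.arange(sample_rate) / sample_rate)

# New in-memory path: no file on disk required.
mfcc = processing_utils.generate_mfcc_features_from_audio_file(
    sample_rate=sample_rate, signal=signal, downsample_to=1000)
print(mfcc.shape)  # (1, num_frames, num_ceps) per the updated docstring

# The original file-based path is unchanged ("speech.wav" is hypothetical):
# mfcc = processing_utils.generate_mfcc_features_from_audio_file("speech.wav")
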
View File

@@ -52,6 +52,30 @@ var io_master_template = {
callback();
})
},
view_embeddings: function(callback) {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
this.fn(this.last_input, "view_embeddings").then((output) => {
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
callback(output)
})
},
update_embeddings: function(callback) {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
this.fn(this.last_input, "update_embeddings").then((output) => {
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
callback(output)
})
},
submit_examples: function(callback) {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();

View File

@@ -52,9 +52,13 @@ function gradio(config, fn, target, example_file_path) {
<button class="load_prev">Load Previous <em>(CTRL + &larr;)</em></button>
<button class="load_next">Load Next <em>(CTRL + &rarr;)</em></button>
<button class="order_similar">Order by Similarity</button>
<button class="view_embeddings">View Embeddings</button>
<button class="update_embeddings invisible">Update Embeddings</button>
<button class="view_examples invisible">View Examples</button>
<div class="pages invisible">Page:</div>
<table>
</table>
<div class="plot invisible"><canvas id="canvas" width="400px" height="300px"></canvas></div>
</div>`);
let io_master = Object.create(io_master_template);
io_master.fn = fn
@@ -93,7 +97,8 @@ function gradio(config, fn, target, example_file_path) {
"dataframe" : dataframe_output,
}
let id_to_interface_map = {}
let embedding_chart;
function set_interface_id(interface, id) {
interface.id = id;
id_to_interface_map[id] = interface;
@@ -283,6 +288,63 @@ function gradio(config, fn, target, example_file_path) {
}
target.find(".examples > table > tbody").html(html);
}
function getBackgroundColors(){
//Gets the background colors for the embedding plot
console.log("io", io_master)
// If labels aren't loaded, or it's not a label output interface:
if (!io_master.loaded_examples || io_master["config"]["output_interfaces"][0][0]!="label") {
return 'rgb(54, 162, 235)'
}
// If it is a label interface, get the labels
let labels = []
let isConfidencesPresent = false;
for (let i=0; i<Object.keys(io_master.loaded_examples).length; i++) {
console.log(io_master.loaded_examples[i])
let label = io_master.loaded_examples[i][0]["label"];
if ("confidences" in io_master.loaded_examples[i][0]){
isConfidencesPresent = true;
}
labels.push(label);
}
const isNumeric = (currentValue) => !isNaN(currentValue);
let isNumericArray = labels.every(isNumeric);
// If they are all numbers, and there are no confidences, then it's a regression
if (isNumericArray && !isConfidencesPresent) {
let backgroundColors = [];
labels = labels.map(Number);
let max = Math.max(...labels);
let min = Math.min(...labels);
let rgb_max = [255, 178, 102]
let rgb_min = [204, 255, 255]
for (let i=0; i<labels.length; i++) {
let frac = (Number(labels[i])-min)/(max-min)
let color = [rgb_min[0]+frac*(rgb_max[0]-rgb_min[0]),
rgb_min[1]+frac*(rgb_max[1]-rgb_min[1]),
rgb_min[2]+frac*(rgb_max[2]-rgb_min[2])]
backgroundColors.push(color);
}
return backgroundColors
}
// Otherwise, it's a classification
let colorArray = ['#FF6633', '#FFB399', '#FF33FF', '#00B3E6',
'#E6B333', '#3366E6', '#999966', '#99FF99', '#B34D4D',
'#80B300', '#809900', '#E6B3B3', '#6680B3', '#66991A',
'#FF99E6', '#CCFF1A', '#FF1A66', '#E6331A', '#33FFCC',
'#66994D', '#B366CC', '#4D8000', '#B33300', '#CC80CC',
'#66664D', '#991AFF', '#E666FF', '#4DB3FF', '#1AB399',
'#E666B3', '#33991A', '#CC9999', '#B3B31A', '#00E680',
'#4D8066', '#809980', '#E6FF80', '#1AFF33', '#999933',
'#FF3380', '#CCCC00', '#66E64D', '#4D80CC', '#9900B3',
'#E64D66', '#4DB380', '#FF4D4D', '#99E6E6', '#6666FF'];
let backgroundColors = [];
let label_list = [];
for (let i=0; i<labels.length; i++) {
if (!(label_list.includes(labels[i]))){
label_list.push(labels[i]);
}
backgroundColors.push(colorArray[label_list.indexOf(labels[i]) % colorArray.length]);
}
return backgroundColors
}
if (config["examples"]) {
target.find(".examples").removeClass("invisible");
let html = "<thead>"
@@ -323,6 +385,75 @@ function gradio(config, fn, target, example_file_path) {
load_page();
})
});
target.find(".view_examples").click(function() {
target.find(".examples > table").removeClass("invisible");
target.find(".examples > .plot").addClass("invisible");
target.find(".run_examples").removeClass("invisible");
target.find(".view_embeddings").removeClass("invisible");
target.find(".load_prev").removeClass("invisible");
target.find(".load_next").removeClass("invisible");
target.find(".order_similar").removeClass("invisible");
target.find(".pages").removeClass("invisible");
target.find(".update_embeddings").addClass("invisible");
target.find(".view_examples").addClass("invisible");
});
target.find(".update_embeddings").click(function() {
io_master.update_embeddings(function(output) {
embedding_chart.data.datasets[0].data.push(output["sample_embedding_2d"][0]);
console.log(output["sample_embedding_2d"][0])
embedding_chart.update();
})
});
target.find(".view_embeddings").click(function() {
io_master.view_embeddings(function(output) {
let ctx = $('#canvas')[0].getContext('2d');
let backgroundColors = getBackgroundColors();
embedding_chart = new Chart(ctx, {
type: 'scatter',
data: {
datasets: [{
label: 'Sample Embedding',
data: output["sample_embedding_2d"],
backgroundColor: 'rgb(0, 0, 0)',
borderColor: 'rgb(0, 0, 0)',
pointRadius: 13,
pointHoverRadius: 13,
pointStyle: 'rectRot',
showLine: true,
fill: false,
}, {
label: 'Dataset Embeddings',
data: output["example_embeddings_2d"],
backgroundColor: backgroundColors,
borderColor: backgroundColors,
pointRadius: 7,
pointHoverRadius: 7
}]
},
options: {
legend: {display: false}
}
});
$("#canvas")[0].onclick = function(evt){
var activePoints = embedding_chart.getElementsAtEvent(evt);
var firstPoint = activePoints[0];
if (firstPoint._datasetIndex==1) { // if it's from the dataset embeddings
load_example(firstPoint._index)
}
};
target.find(".examples > table").addClass("invisible");
target.find(".examples > .plot").removeClass("invisible");
target.find(".run_examples").addClass("invisible");
target.find(".view_embeddings").addClass("invisible");
target.find(".load_prev").addClass("invisible");
target.find(".load_next").addClass("invisible");
target.find(".order_similar").addClass("invisible");
target.find(".pages").addClass("invisible");
target.find(".update_embeddings").removeClass("invisible");
target.find(".view_examples").removeClass("invisible");
})
});
$("body").keydown(function(e) {
if ($(document.activeElement).attr("type") == "text" || $(document.activeElement).attr("type") == "textarea") {
return;
@@ -340,6 +471,7 @@ function gradio(config, fn, target, example_file_path) {
});
};
target.find(".screenshot").click(function() {
$(".screenshot, .record").hide();
$(".screenshot_logo").removeClass("invisible");

View File

gradio/static/js/vendor/Chart.min.js (new vendored file)

File diff suppressed because one or more lines are too long

View File

@@ -91,6 +91,7 @@
<script src="{{ vendor_prefix }}/static/js/vendor/sketchpad.js"></script>
<script src="{{ vendor_prefix }}/static/js/vendor/webcam.min.js"></script>
<script src="{{ vendor_prefix }}/static/js/vendor/gifcap/gifencoder.js"></script>
<script src="{{ vendor_prefix }}/static/js/vendor/Chart.min.js"></script>
<script src="{{ url_for('static', filename='js/utils.js') }}"></script>
<script src="{{ url_for('static', filename='js/all_io.js') }}"></script>

View File

@@ -22,6 +22,7 @@ setup(
'paramiko',
'scipy',
'IPython',
'scikit-learn',
'scikit-image',
'analytics-python',
'pandas'