labeled segments built

This commit is contained in:
Ali Abid 2020-11-23 15:17:42 -08:00
parent b1b4bac317
commit 82e49e72d9
11 changed files with 277 additions and 40 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -44,7 +44,7 @@ class Interface:
title=None, description=None, thumbnail=None,
server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
embedding_fn="default",
embedding="default",
flagging_dir="flagged", analytics_enabled=True):
"""
@ -128,7 +128,7 @@ class Interface:
self.analytics_enabled=analytics_enabled
self.save_to = None
self.share = None
self.embedding_fn = embedding_fn
self.embedding = embedding
data = {'fn': fn,
'inputs': inputs,
@ -252,12 +252,12 @@ class Interface:
return processed_output, durations
def embed(self, processed_input):
if self.embedding_fn == "default":
embedding = np.concatenate([input_interface.embed(processed_input[i])
if self.embedding == "default":
embeddings = np.concatenate([input_interface.embed(processed_input[i])
for i, input_interface in enumerate(self.input_interfaces)])
else:
embedding = self.embedding_fn(*processed_input)
return embedding
embeddings = self.embedding(*processed_input)
return embeddings
def interpret(self, raw_input):
"""

View File

@ -20,7 +20,7 @@ import sys
import csv
import logging
import gradio as gr
from gradio.embeddings import calculate_similarity
from gradio.embeddings import calculate_similarity, fit_pca_to_embeddings, transform_with_pca
from gradio.tunneling import create_tunnel
INITIAL_PORT_VALUE = int(os.getenv(
@ -124,18 +124,57 @@ def predict():
@app.route("/api/score_similarity/", methods=["POST"])
def score_similarity():
raw_input = request.json["data"]
preprocessed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
input_embedding = app.interface.embed(preprocessed_input)
scores = list()
for example in app.interface.examples:
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embed(preprocessed_example)
scores.append(calculate_similarity(input_embedding, example_embedding))
return jsonify({"data": scores})
@app.route("/api/view_embeddings/", methods=["POST"])
def view_embeddings():
sample_embedding = []
if "data" in request.json:
raw_input = request.json["data"]
preprocessed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
sample_embedding.append(app.interface.embed(preprocessed_input))
example_embeddings = []
for example in app.interface.examples:
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
for iface, example in zip(app.interface.input_interfaces, example)]
example_embedding = app.interface.embed(preprocessed_example)
example_embeddings.append(example_embedding)
pca_model, embeddings_2d = fit_pca_to_embeddings(sample_embedding + example_embeddings)
sample_embedding_2d = embeddings_2d[:len(sample_embedding)]
example_embeddings_2d = embeddings_2d[len(sample_embedding):]
app.pca_model = pca_model
return jsonify({"sample_embedding_2d": sample_embedding_2d, "example_embeddings_2d": example_embeddings_2d})
@app.route("/api/update_embeddings/", methods=["POST"])
def update_embeddings():
sample_embedding, sample_embedding_2d = [], []
if "data" in request.json:
raw_input = request.json["data"]
preprocessed_input = [input_interface.preprocess(raw_input[i])
for i, input_interface in enumerate(app.interface.input_interfaces)]
sample_embedding.append(app.interface.embed(preprocessed_input))
sample_embedding_2d = transform_with_pca(app.pca_model, sample_embedding)
return jsonify({"sample_embedding_2d": sample_embedding_2d})
@app.route("/api/predict_examples/", methods=["POST"])
def predict_examples():
example_ids = request.json["data"]

View File

@ -83,16 +83,19 @@ def decode_base64_to_file(encoding):
# AUDIO FILES
##################
def generate_mfcc_features_from_audio_file(wav_filename,
def generate_mfcc_features_from_audio_file(wav_filename=None,
pre_emphasis=0.95,
frame_size=0.025,
frame_stride=0.01,
NFFT=512,
nfilt=40,
num_ceps=12,
cep_lifter=22):
cep_lifter=22,
sample_rate=None,
signal=None,
downsample_to=None):
"""
Loads and preprocesses a .wav audio file into mfcc coefficients, the typical inputs to models.
Loads and preprocesses a .wav audio file (or alternatively, a sample rate & signal) into mfcc coefficients, the typical inputs to models.
Adapted from: https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
:param wav_filename: string name of audio file to process.
:param pre_emphasis: a float factor, typically 0.95 or 0.97, which amplifies high frequencies.
@ -102,9 +105,21 @@ def generate_mfcc_features_from_audio_file(wav_filename,
:param nfilt: The number of filters on the Mel-scale to extract frequency bands.
:param num_ceps: the number of cepstral coefficients to retain.
:param cep_lifter: the int factor by which to de-emphasize higher-frequency coefficients.
:return: a numpy array of mfcc coefficients.
:param sample_rate: optional param representing the sample rate, used if `wav_filename` is not provided
:param signal: optional param representing the signal data, used if `wav_filename` is not provided
:param downsample_to: optional param. If provided, the signal is downsampled to this number of frames.
:return: a 3D numpy array of mfcc coefficients, of the shape 1 x num_frames x num_coeffs.
"""
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
if (wav_filename is None) and (sample_rate is None or signal is None):
raise ValueError("Either a wav_filename or both a sample_rate and signal must be provided")
elif wav_filename is not None:
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
if downsample_to is not None:
signal = scipy.signal.resample(signal, downsample_to)
emphasized_signal = np.append(signal[0], signal[1:] - pre_emphasis * signal[:-1])
frame_length, frame_step = frame_size * sample_rate, frame_stride * sample_rate # Convert from seconds to samples
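A short usage sketch of the extended signature, assuming the function is imported from its module (the filename is not shown in this view) and using placeholder inputs:

import numpy as np
from gradio.processing_utils import generate_mfcc_features_from_audio_file  # assumed module path

# From a file on disk (path is a placeholder):
mfcc = generate_mfcc_features_from_audio_file("speech_sample.wav", downsample_to=8000)

# Or from an in-memory signal, skipping the file read:
mfcc = generate_mfcc_features_from_audio_file(sample_rate=16000, signal=np.random.randn(16000))
print(mfcc.shape)  # (1, num_frames, num_coeffs) per the docstring above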

View File

@ -52,6 +52,30 @@ var io_master_template = {
callback();
})
},
view_embeddings: function(callback) {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
this.fn(this.last_input, "view_embeddings").then((output) => {
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
callback(output)
})
},
update_embeddings: function(callback) {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();
this.target.find(".loading_failed").hide();
this.target.find(".output_interfaces").css("opacity", 0.5);
this.fn(this.last_input, "update_embeddings").then((output) => {
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
callback(output)
})
},
submit_examples: function(callback) {
this.target.find(".loading").removeClass("invisible");
this.target.find(".loading_in_progress").show();

View File

@ -48,13 +48,17 @@ function gradio(config, fn, target, example_file_path) {
</div>
<div class="examples invisible">
<h4>Examples</h4>
<button class="run_examples">Run All</button>
<button class="load_prev">Load Previous <em>(CTRL + &larr;)</em></button>
<button class="load_next">Load Next <em>(CTRL + &rarr;)</em></button>
<button class="order_similar">Order by Similarity</em></button>
<button class="run_examples examples-content">Run All</button>
<button class="load_prev examples-content">Load Previous <em>(CTRL + &larr;)</em></button>
<button class="load_next examples-content">Load Next <em>(CTRL + &rarr;)</em></button>
<button class="order_similar examples-content">Order by Similarity</button>
<button class="view_embeddings examples-content">View Embeddings</button>
<button class="update_embeddings embeddings-content invisible">Update Embeddings</button>
<button class="view_examples embeddings-content invisible">View Examples</button>
<div class="pages invisible">Page:</div>
<table>
<table class="examples-content">
</table>
<div class="plot embeddings-content invisible"><canvas id="canvas" width="400px" height="300px"></canvas></div>
</div>`);
let io_master = Object.create(io_master_template);
io_master.fn = fn
@ -93,7 +97,8 @@ function gradio(config, fn, target, example_file_path) {
"dataframe" : dataframe_output,
}
let id_to_interface_map = {}
let embedding_chart;
function set_interface_id(interface, id) {
interface.id = id;
id_to_interface_map[id] = interface;
@ -265,7 +270,6 @@ function gradio(config, fn, target, example_file_path) {
let html = "";
for (let i = page_start; i < page_start + config["examples_per_page"] && i < config.examples.length; i++) {
let example_id = io_master.order_mapping[i];
console.log(example_id)
let example = config["examples"][example_id];
html += "<tr row=" + example_id + ">";
for (let [j, col] of example.entries()) {
@ -330,6 +334,59 @@ function gradio(config, fn, target, example_file_path) {
load_page();
})
});
target.find(".view_examples").click(function() {
target.find(".examples-content").removeClass("invisible");
target.find(".embeddings-content").addClass("invisible");
});
target.find(".update_embeddings").click(function() {
io_master.update_embeddings(function(output) {
embedding_chart.data.datasets[0].data.push(output["sample_embedding_2d"][0]);
console.log(output["sample_embedding_2d"][0])
embedding_chart.update();
})
});
target.find(".view_embeddings").click(function() {
io_master.view_embeddings(function(output) {
let ctx = $('#canvas')[0].getContext('2d');
let backgroundColors = getBackgroundColors(io_master);
embedding_chart = new Chart(ctx, {
type: 'scatter',
data: {
datasets: [{
label: 'Sample Embedding',
data: output["sample_embedding_2d"],
backgroundColor: 'rgb(0, 0, 0)',
borderColor: 'rgb(0, 0, 0)',
pointRadius: 13,
pointHoverRadius: 13,
pointStyle: 'rectRot',
showLine: true,
fill: false,
}, {
label: 'Dataset Embeddings',
data: output["example_embeddings_2d"],
backgroundColor: backgroundColors,
borderColor: backgroundColors,
pointRadius: 7,
pointHoverRadius: 7
}]
},
options: {
legend: {display: false}
}
});
$("#canvas")[0].onclick = function(evt){
var activePoints = embedding_chart.getElementsAtEvent(evt);
var firstPoint = activePoints[0];
if (firstPoint._datasetIndex==1) { // if it's from the dataset embeddings
load_example(firstPoint._index)
}
};
target.find(".examples-content").addClass("invisible");
target.find(".embeddings-content").removeClass("invisible");
})
});
$("body").keydown(function(e) {
if ($(document.activeElement).attr("type") == "text" || $(document.activeElement).attr("type") == "textarea") {
return;
@ -345,8 +402,9 @@ function gradio(config, fn, target, example_file_path) {
}
}
});
});
};
target.find(".screenshot").click(function() {
$(".screenshot, .record").hide();
$(".screenshot_logo").removeClass("invisible");

View File

@ -84,6 +84,67 @@ function paintSaliency(data, ctx, width, height) {
})
}
function getBackgroundColors(io_master){
//Gets the background colors for the embedding plot
// If labels aren't loaded, or it's not a label output interface:
if (!io_master.loaded_examples || io_master["config"]["output_interfaces"][0][0]!="label") {
return 'rgb(54, 162, 235)'
}
// If it is a label interface, get the labels
let labels = []
let isConfidencesPresent = false;
for (let i=0; i<Object.keys(io_master.loaded_examples).length; i++) {
let label = io_master.loaded_examples[i][0]["label"];
if ("confidences" in io_master.loaded_examples[i][0]){
isConfidencesPresent = true;
}
labels.push(label);
}
// If they are all numbers, and there are no confidences, then it's a regression
const isNumeric = (currentValue) => !isNaN(currentValue);
let isNumericArray = labels.every(isNumeric);
if (isNumericArray && !isConfidencesPresent) {
let backgroundColors = [];
labels = labels.map(Number);
let max = Math.max(...labels);
let min = Math.min(...labels);
let rgb_max = [255, 178, 102]
let rgb_min = [204, 255, 255]
for (let i=0; i<labels.length; i++) {
let frac = (Number(labels[i])-min)/(max-min)
let color = [rgb_min[0]+frac*(rgb_max[0]-rgb_min[0]),
rgb_min[1]+frac*(rgb_max[1]-rgb_min[1]),
rgb_min[2]+frac*(rgb_max[2]-rgb_min[2])]
backgroundColors.push(color);
}
return backgroundColors;
}
// Otherwise, it's a classification
let colorArray = ['#FF6633', '#FFB399', '#FF33FF', '#00B3E6',
'#E6B333', '#3366E6', '#999966', '#99FF99', '#B34D4D',
'#80B300', '#809900', '#E6B3B3', '#6680B3', '#66991A',
'#FF99E6', '#CCFF1A', '#FF1A66', '#E6331A', '#33FFCC',
'#66994D', '#B366CC', '#4D8000', '#B33300', '#CC80CC',
'#66664D', '#991AFF', '#E666FF', '#4DB3FF', '#1AB399',
'#E666B3', '#33991A', '#CC9999', '#B3B31A', '#00E680',
'#4D8066', '#809980', '#E6FF80', '#1AFF33', '#999933',
'#FF3380', '#CCCC00', '#66E64D', '#4D80CC', '#9900B3',
'#E64D66', '#4DB380', '#FF4D4D', '#99E6E6', '#6666FF'];
let backgroundColors = [];
let label_list = [];
for (let i=0; i<labels.length; i++) {
if (!(label_list.includes(labels[i]))){
label_list.push(labels[i]);
}
backgroundColors.push(colorArray[label_list.indexOf(labels[i]) % colorArray.length]);
}
return backgroundColors
}
function getSaliencyColor(value) {
if (value < 0) {
var color = [52, 152, 219];

File diff suppressed because one or more lines are too long

View File

@ -91,6 +91,7 @@
<script src="{{ vendor_prefix }}/static/js/vendor/sketchpad.js"></script>
<script src="{{ vendor_prefix }}/static/js/vendor/webcam.min.js"></script>
<script src="{{ vendor_prefix }}/static/js/vendor/gifcap/gifencoder.js"></script>
<script src="{{ vendor_prefix }}/static/js/vendor/Chart.min.js"></script>
<script src="{{ url_for('static', filename='js/utils.js') }}"></script>
<script src="{{ url_for('static', filename='js/all_io.js') }}"></script>

View File

@ -98,6 +98,7 @@ gradio/static/js/interfaces/output/json.js
gradio/static/js/interfaces/output/key_values.js
gradio/static/js/interfaces/output/label.js
gradio/static/js/interfaces/output/textbox.js
gradio/static/js/vendor/Chart.min.js
gradio/static/js/vendor/FileSaver.min.js
gradio/static/js/vendor/black-theme.js
gradio/static/js/vendor/cropper.min.js