mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
labeled segments built
This commit is contained in:
parent
b1b4bac317
commit
82e49e72d9
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -44,7 +44,7 @@ class Interface:
|
||||
title=None, description=None, thumbnail=None,
|
||||
server_port=None, server_name=networking.LOCALHOST_NAME,
|
||||
allow_screenshot=True, allow_flagging=True,
|
||||
embedding_fn="default",
|
||||
embedding="default",
|
||||
flagging_dir="flagged", analytics_enabled=True):
|
||||
|
||||
"""
|
||||
@ -128,7 +128,7 @@ class Interface:
|
||||
self.analytics_enabled=analytics_enabled
|
||||
self.save_to = None
|
||||
self.share = None
|
||||
self.embedding_fn = embedding_fn
|
||||
self.embedding = embedding
|
||||
|
||||
data = {'fn': fn,
|
||||
'inputs': inputs,
|
||||
@ -252,12 +252,12 @@ class Interface:
|
||||
return processed_output, durations
|
||||
|
||||
def embed(self, processed_input):
|
||||
if self.embedding_fn == "default":
|
||||
embedding = np.concatenate([input_interface.embed(processed_input[i])
|
||||
if self.embedding == "default":
|
||||
embeddings = np.concatenate([input_interface.embed(processed_input[i])
|
||||
for i, input_interface in enumerate(self.input_interfaces)])
|
||||
else:
|
||||
embedding = self.embedding_fn(*processed_input)
|
||||
return embedding
|
||||
embeddings = self.embedding(*processed_input)
|
||||
return embeddings
|
||||
|
||||
def interpret(self, raw_input):
|
||||
"""
|
||||
|
@ -20,7 +20,7 @@ import sys
|
||||
import csv
|
||||
import logging
|
||||
import gradio as gr
|
||||
from gradio.embeddings import calculate_similarity
|
||||
from gradio.embeddings import calculate_similarity, fit_pca_to_embeddings, transform_with_pca
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
@ -124,18 +124,57 @@ def predict():
|
||||
@app.route("/api/score_similarity/", methods=["POST"])
|
||||
def score_similarity():
|
||||
raw_input = request.json["data"]
|
||||
|
||||
preprocessed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
input_embedding = app.interface.embed(preprocessed_input)
|
||||
scores = list()
|
||||
|
||||
for example in app.interface.examples:
|
||||
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
|
||||
for iface, example in zip(app.interface.input_interfaces, example)]
|
||||
example_embedding = app.interface.embed(preprocessed_example)
|
||||
scores.append(calculate_similarity(input_embedding, example_embedding))
|
||||
scores.append(calculate_similarity(input_embedding, example_embedding))
|
||||
|
||||
return jsonify({"data": scores})
|
||||
|
||||
|
||||
@app.route("/api/view_embeddings/", methods=["POST"])
|
||||
def view_embeddings():
|
||||
sample_embedding = []
|
||||
if "data" in request.json:
|
||||
raw_input = request.json["data"]
|
||||
preprocessed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
sample_embedding.append(app.interface.embed(preprocessed_input))
|
||||
|
||||
example_embeddings = []
|
||||
for example in app.interface.examples:
|
||||
preprocessed_example = [iface.preprocess(iface.preprocess_example(example))
|
||||
for iface, example in zip(app.interface.input_interfaces, example)]
|
||||
example_embedding = app.interface.embed(preprocessed_example)
|
||||
example_embeddings.append(example_embedding)
|
||||
|
||||
pca_model, embeddings_2d = fit_pca_to_embeddings(sample_embedding + example_embeddings)
|
||||
sample_embedding_2d = embeddings_2d[:len(sample_embedding)]
|
||||
example_embeddings_2d = embeddings_2d[len(sample_embedding):]
|
||||
app.pca_model = pca_model
|
||||
return jsonify({"sample_embedding_2d": sample_embedding_2d, "example_embeddings_2d": example_embeddings_2d})
|
||||
|
||||
|
||||
@app.route("/api/update_embeddings/", methods=["POST"])
|
||||
def update_embeddings():
|
||||
sample_embedding, sample_embedding_2d = [], []
|
||||
if "data" in request.json:
|
||||
raw_input = request.json["data"]
|
||||
preprocessed_input = [input_interface.preprocess(raw_input[i])
|
||||
for i, input_interface in enumerate(app.interface.input_interfaces)]
|
||||
sample_embedding.append(app.interface.embed(preprocessed_input))
|
||||
sample_embedding_2d = transform_with_pca(app.pca_model, sample_embedding)
|
||||
|
||||
return jsonify({"sample_embedding_2d": sample_embedding_2d})
|
||||
|
||||
|
||||
@app.route("/api/predict_examples/", methods=["POST"])
|
||||
def predict_examples():
|
||||
example_ids = request.json["data"]
|
||||
|
@ -83,16 +83,19 @@ def decode_base64_to_file(encoding):
|
||||
# AUDIO FILES
|
||||
##################
|
||||
|
||||
def generate_mfcc_features_from_audio_file(wav_filename,
|
||||
def generate_mfcc_features_from_audio_file(wav_filename=None,
|
||||
pre_emphasis=0.95,
|
||||
frame_size= 0.025,
|
||||
frame_stride=0.01,
|
||||
NFFT=512,
|
||||
nfilt=40,
|
||||
num_ceps=12,
|
||||
cep_lifter=22):
|
||||
cep_lifter=22,
|
||||
sample_rate=None,
|
||||
signal=None,
|
||||
downsample_to=None):
|
||||
"""
|
||||
Loads and preprocesses a .wav audio file into mfcc coefficients, the typical inputs to models.
|
||||
Loads and preprocesses a .wav audio file (or alternatively, a sample rate & signal) into mfcc coefficients, the typical inputs to models.
|
||||
Adapted from: https://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
|
||||
:param wav_filename: string name of audio file to process.
|
||||
:param pre_emphasis: a float factor, typically 0.95 or 0.97, which amplifies high frequencies.
|
||||
@ -102,9 +105,21 @@ def generate_mfcc_features_from_audio_file(wav_filename,
|
||||
:param nfilt: The number of filters on the Mel-scale to extract frequency bands.
|
||||
:param num_ceps: the number of cepstral coefficients to retrain.
|
||||
:param cep_lifter: the int factor, by which to de-emphasize higher-frequency.
|
||||
:return: a numpy array of mfcc coefficients.
|
||||
:param sample_rate: optional param represnting sample rate that is used if `wav_filename` is not provided
|
||||
:param signal: optional param representing sample data that is used if `wav_filename` is not provided
|
||||
:param downsample_to: optional param. If provided, audio file is downsampled to this many frames.
|
||||
:return: a 3D numpy array of mfcc coefficients, of the shape 1 x num_frames x num_coeffs.
|
||||
"""
|
||||
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
|
||||
if (wav_filename is None) and (sample_rate is None or signal is None):
|
||||
raise ValueError("Either a wav_filename must be provdied or a sample_rate and signal")
|
||||
elif wav_filename is None:
|
||||
pass
|
||||
else:
|
||||
sample_rate, signal = scipy.io.wavfile.read(wav_filename)
|
||||
|
||||
if not(downsample_to is None):
|
||||
signal = scipy.signal.resample(signal, downsample_to)
|
||||
|
||||
emphasized_signal = np.append(signal[0], signal[1:] - pre_emphasis * signal[:-1])
|
||||
|
||||
frame_length, frame_step = frame_size * sample_rate, frame_stride * sample_rate # Convert from seconds to samples
|
||||
|
@ -52,6 +52,30 @@ var io_master_template = {
|
||||
callback();
|
||||
})
|
||||
},
|
||||
view_embeddings: function(callback) {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
this.target.find(".loading_failed").hide();
|
||||
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||
|
||||
this.fn(this.last_input, "view_embeddings").then((output) => {
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
callback(output)
|
||||
})
|
||||
},
|
||||
update_embeddings: function(callback) {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
this.target.find(".loading_failed").hide();
|
||||
this.target.find(".output_interfaces").css("opacity", 0.5);
|
||||
|
||||
this.fn(this.last_input, "update_embeddings").then((output) => {
|
||||
this.target.find(".loading").addClass("invisible");
|
||||
this.target.find(".output_interfaces").css("opacity", 1);
|
||||
callback(output)
|
||||
})
|
||||
},
|
||||
submit_examples: function(callback) {
|
||||
this.target.find(".loading").removeClass("invisible");
|
||||
this.target.find(".loading_in_progress").show();
|
||||
|
@ -48,13 +48,17 @@ function gradio(config, fn, target, example_file_path) {
|
||||
</div>
|
||||
<div class="examples invisible">
|
||||
<h4>Examples</small></h4>
|
||||
<button class="run_examples">Run All</button>
|
||||
<button class="load_prev">Load Previous <em>(CTRL + ←)</em></button>
|
||||
<button class="load_next">Load Next <em>(CTRL + →)</em></button>
|
||||
<button class="order_similar">Order by Similarity</em></button>
|
||||
<button class="run_examples examples-content">Run All</button>
|
||||
<button class="load_prev examples-content">Load Previous <em>(CTRL + ←)</em></button>
|
||||
<button class="load_next examples-content">Load Next <em>(CTRL + →)</em></button>
|
||||
<button class="order_similar examples-content">Order by Similarity</button>
|
||||
<button class="view_embeddings examples-content">View Embeddings</button>
|
||||
<button class="update_embeddings embeddings-content invisible">Update Embeddings</button>
|
||||
<button class="view_examples embeddings-content invisible">View Examples</button>
|
||||
<div class="pages invisible">Page:</div>
|
||||
<table>
|
||||
<table class="examples-content">
|
||||
</table>
|
||||
<div class="plot embeddings-content invisible"><canvas id="canvas" width="400px" height="300px"></canvas></div>
|
||||
</div>`);
|
||||
let io_master = Object.create(io_master_template);
|
||||
io_master.fn = fn
|
||||
@ -93,7 +97,8 @@ function gradio(config, fn, target, example_file_path) {
|
||||
"dataframe" : dataframe_output,
|
||||
}
|
||||
let id_to_interface_map = {}
|
||||
|
||||
let embedding_chart;
|
||||
|
||||
function set_interface_id(interface, id) {
|
||||
interface.id = id;
|
||||
id_to_interface_map[id] = interface;
|
||||
@ -265,7 +270,6 @@ function gradio(config, fn, target, example_file_path) {
|
||||
let html = "";
|
||||
for (let i = page_start; i < page_start + config["examples_per_page"] && i < config.examples.length; i++) {
|
||||
let example_id = io_master.order_mapping[i];
|
||||
console.log(example_id)
|
||||
let example = config["examples"][example_id];
|
||||
html += "<tr row=" + example_id + ">";
|
||||
for (let [j, col] of example.entries()) {
|
||||
@ -330,6 +334,59 @@ function gradio(config, fn, target, example_file_path) {
|
||||
load_page();
|
||||
})
|
||||
});
|
||||
target.find(".view_examples").click(function() {
|
||||
target.find(".examples-content").removeClass("invisible");
|
||||
target.find(".embeddings-content").addClass("invisible");
|
||||
});
|
||||
target.find(".update_embeddings").click(function() {
|
||||
io_master.update_embeddings(function(output) {
|
||||
embedding_chart.data.datasets[0].data.push(output["sample_embedding_2d"][0]);
|
||||
console.log(output["sample_embedding_2d"][0])
|
||||
embedding_chart.update();
|
||||
})
|
||||
});
|
||||
target.find(".view_embeddings").click(function() {
|
||||
io_master.view_embeddings(function(output) {
|
||||
let ctx = $('#canvas')[0].getContext('2d');
|
||||
let backgroundColors = getBackgroundColors(io_master);
|
||||
embedding_chart = new Chart(ctx, {
|
||||
type: 'scatter',
|
||||
data: {
|
||||
datasets: [{
|
||||
label: 'Sample Embedding',
|
||||
data: output["sample_embedding_2d"],
|
||||
backgroundColor: 'rgb(0, 0, 0)',
|
||||
borderColor: 'rgb(0, 0, 0)',
|
||||
pointRadius: 13,
|
||||
pointHoverRadius: 13,
|
||||
pointStyle: 'rectRot',
|
||||
showLine: true,
|
||||
fill: false,
|
||||
}, {
|
||||
label: 'Dataset Embeddings',
|
||||
data: output["example_embeddings_2d"],
|
||||
backgroundColor: backgroundColors,
|
||||
borderColor: backgroundColors,
|
||||
pointRadius: 7,
|
||||
pointHoverRadius: 7
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
legend: {display: false}
|
||||
}
|
||||
});
|
||||
$("#canvas")[0].onclick = function(evt){
|
||||
var activePoints = embedding_chart.getElementsAtEvent(evt);
|
||||
var firstPoint = activePoints[0];
|
||||
if (firstPoint._datasetIndex==1) { // if it's from the sample embeddings dataset
|
||||
load_example(firstPoint._index)
|
||||
}
|
||||
};
|
||||
|
||||
target.find(".examples-content").addClass("invisible");
|
||||
target.find(".embeddings-content").removeClass("invisible");
|
||||
})
|
||||
});
|
||||
$("body").keydown(function(e) {
|
||||
if ($(document.activeElement).attr("type") == "text" || $(document.activeElement).attr("type") == "textarea") {
|
||||
return;
|
||||
@ -345,8 +402,9 @@ function gradio(config, fn, target, example_file_path) {
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
target.find(".screenshot").click(function() {
|
||||
$(".screenshot, .record").hide();
|
||||
$(".screenshot_logo").removeClass("invisible");
|
||||
|
@ -84,6 +84,67 @@ function paintSaliency(data, ctx, width, height) {
|
||||
})
|
||||
}
|
||||
|
||||
function getBackgroundColors(io_master){
|
||||
//Gets the background colors for the embedding plot
|
||||
|
||||
// If labels aren't loaded, or it's not a label output interface:
|
||||
if (!io_master.loaded_examples || io_master["config"]["output_interfaces"][0][0]!="label") {
|
||||
return 'rgb(54, 162, 235)'
|
||||
}
|
||||
|
||||
// If it is a label interface, get the labels
|
||||
let labels = []
|
||||
let isConfidencesPresent = false;
|
||||
for (let i=0; i<Object.keys(io_master.loaded_examples).length; i++) {
|
||||
let label = io_master.loaded_examples[i][0]["label"];
|
||||
if ("confidences" in io_master.loaded_examples[i][0]){
|
||||
isConfidencesPresent = true;
|
||||
}
|
||||
labels.push(label);
|
||||
}
|
||||
|
||||
// If they are all numbers, and there are no confidences, then it's a regression
|
||||
const isNumeric = (currentValue) => !isNaN(currentValue);
|
||||
let isNumericArray = labels.every(isNumeric);
|
||||
if (isNumericArray && !isConfidencesPresent) {
|
||||
let backgroundColors = [];
|
||||
labels = labels.map(Number);
|
||||
let max = Math.max(...labels);
|
||||
let min = Math.min(...labels);
|
||||
let rgb_max = [255, 178, 102]
|
||||
let rgb_min = [204, 255, 255]
|
||||
for (let i=0; i<labels.length; i++) {
|
||||
let frac = (Number(labels[i])-min)/(max-min)
|
||||
let color = [rgb_min[0]+frac*(rgb_max[0]-rgb_min[0]),
|
||||
rgb_min[1]+frac*(rgb_max[1]-rgb_min[1]),
|
||||
rgb_min[2]+frac*(rgb_max[2]-rgb_min[2])]
|
||||
backgroundColors.push(color);
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, it's a classification
|
||||
let colorArray = ['#FF6633', '#FFB399', '#FF33FF', '#00B3E6',
|
||||
'#E6B333', '#3366E6', '#999966', '#99FF99', '#B34D4D',
|
||||
'#80B300', '#809900', '#E6B3B3', '#6680B3', '#66991A',
|
||||
'#FF99E6', '#CCFF1A', '#FF1A66', '#E6331A', '#33FFCC',
|
||||
'#66994D', '#B366CC', '#4D8000', '#B33300', '#CC80CC',
|
||||
'#66664D', '#991AFF', '#E666FF', '#4DB3FF', '#1AB399',
|
||||
'#E666B3', '#33991A', '#CC9999', '#B3B31A', '#00E680',
|
||||
'#4D8066', '#809980', '#E6FF80', '#1AFF33', '#999933',
|
||||
'#FF3380', '#CCCC00', '#66E64D', '#4D80CC', '#9900B3',
|
||||
'#E64D66', '#4DB380', '#FF4D4D', '#99E6E6', '#6666FF'];
|
||||
let backgroundColors = [];
|
||||
let label_list = [];
|
||||
for (let i=0; i<labels.length; i++) {
|
||||
if (!(label_list.includes(labels[i]))){
|
||||
label_list.push(labels[i]);
|
||||
}
|
||||
backgroundColors.push(colorArray[label_list.indexOf(labels[i]) % colorArray.length]);
|
||||
}
|
||||
return backgroundColors
|
||||
}
|
||||
|
||||
|
||||
function getSaliencyColor(value) {
|
||||
if (value < 0) {
|
||||
var color = [52, 152, 219];
|
||||
|
7
build/lib/gradio/static/js/vendor/Chart.min.js
vendored
Normal file
7
build/lib/gradio/static/js/vendor/Chart.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -91,6 +91,7 @@
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/sketchpad.js"></script>
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/webcam.min.js"></script>
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/gifcap/gifencoder.js"></script>
|
||||
<script src="{{ vendor_prefix }}/static/js/vendor/Chart.min.js"></script>
|
||||
|
||||
<script src="{{ url_for('static', filename='js/utils.js') }}"></script>
|
||||
<script src="{{ url_for('static', filename='js/all_io.js') }}"></script>
|
||||
|
@ -98,6 +98,7 @@ gradio/static/js/interfaces/output/json.js
|
||||
gradio/static/js/interfaces/output/key_values.js
|
||||
gradio/static/js/interfaces/output/label.js
|
||||
gradio/static/js/interfaces/output/textbox.js
|
||||
gradio/static/js/vendor/Chart.min.js
|
||||
gradio/static/js/vendor/FileSaver.min.js
|
||||
gradio/static/js/vendor/black-theme.js
|
||||
gradio/static/js/vendor/cropper.min.js
|
||||
|
Loading…
Reference in New Issue
Block a user