added some static code for sorting (should be improved)

This commit is contained in:
Abubakar Abid 2020-11-12 07:11:01 -06:00
parent 2c9efd84d1
commit abda96bf30
3 changed files with 19 additions and 7 deletions

View File

@ -13,9 +13,9 @@ def embed_text(text):
def calculate_similarity(embedding1, embedding2):
"""
Scores the similarity between two embeddings by taking the cosine distance
Scores the similarity between two embeddings by taking the cosine similarity
"""
e1, e2 = np.array(embedding1), np.array(embedding2)
cosine_distance = np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2) + SMALL_CONST)
return 1 - cosine_distance
cosine_similarity = np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2) + SMALL_CONST)
return cosine_similarity

View File

@ -249,7 +249,7 @@ class Interface:
embedding = np.concatenate([input_interface.embed(processed_input[i])
for i, input_interface in enumerate(self.input_interfaces)])
else:
embedding = self.embedding_fn(processed_input)
embedding = self.embedding_fn(*processed_input)
return embedding
def interpret(self, raw_input):

View File

@ -1,3 +1,4 @@
var io_master_template = {
gather: function() {
this.clear();
@ -44,7 +45,6 @@ var io_master_template = {
this.target.find(".output_interfaces").css("opacity", 0.5);
this.fn(this.last_input, "score_similarity").then((output) => {
console.log(output.data)
this.target.find(".loading").addClass("invisible");
this.target.find(".output_interfaces").css("opacity", 1);
let html = "<th>SIMILARITY</th>"
@ -53,9 +53,21 @@ var io_master_template = {
let html = "<td>" + output["data"][i] + "</td>"
this.target.find(".examples_body tr[row='" + i + "']").append(html);
}
function getCellValue(row, index){ return $(row).children('td').eq(index).text() }
function comparer(index) {
return function(a, b) {
var valA = getCellValue(a, index), valB = getCellValue(b, index)
return $.isNumeric(valA) && $.isNumeric(valB) ? valA - valB : valA.toString().localeCompare(valB)
}
}
var table = $(".examples > table").eq(0)
var rows = table.find('tr:gt(0)').toArray().sort(comparer(-1)).reverse() // sort by last column
table.find("tr:gt(0)").remove()
for (var i = 0; i < rows.length; i++){table.append(rows[i]); console.log(i)}
})
},
submit_examples: function() {
this.target.find(".loading").removeClass("invisible");