added embeddings for almost all interface types

This commit is contained in:
Abubakar Abid 2020-11-12 04:18:06 -06:00
parent 02ab39b3ad
commit 3e740b9076
4 changed files with 96 additions and 11 deletions

18
gradio/embeddings.py Normal file

File diff suppressed because one or more lines are too long

View File

@ -10,7 +10,7 @@ import os
import time
import warnings
from gradio.component import Component
from gradio.embeddings import embed_text
import base64
import numpy as np
import PIL
@ -72,6 +72,16 @@ class InputComponent(Component):
Returns:
(List[Any]): Arrangement of interpretation scores for interfaces to render.
'''
pass
def embed(self, x):
    """
    Return a default embedding for the *preprocessed* input to the interface. Used to compute similar inputs.
    This base implementation does nothing and returns None; components provide their own embedding by overriding it.
    Parameters:
    x (Any): Preprocessed input to the interface
    Returns:
    (List[Float]): An embedding vector as a list or numpy array of floats
    """
    return None
class Textbox(InputComponent):
"""
@ -170,6 +180,17 @@ class Textbox(InputComponent):
result.append((self.interpretation_separator, 0))
return result
def embed(self, x):
    """
    Embeds the preprocessed textbox value according to the component's type.
    Parameters:
    x (Any): Preprocessed textbox value
    Returns:
    For type "str", whatever embed_text (imported from gradio.embeddings) returns for x;
    for type "number", a one-element list holding x converted to float.
    Raises:
    ValueError: if the component's type is neither 'str' nor 'number'.
    """
    if self.type == "str":
        return embed_text(x)
    if self.type == "number":
        numeric = float(x)
        return [numeric]
    raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'str', 'number'.")
class Number(InputComponent):
"""
@ -238,6 +259,9 @@ class Number(InputComponent):
interpretation.insert(int(len(interpretation) / 2), [x, None])
return interpretation
def embed(self, x):
    """
    Embeds the preprocessed number as a one-element vector.
    Parameters:
    x (Any): Value accepted by float()
    Returns:
    (List[float]): A single-entry list containing x converted to float
    """
    as_float = float(x)
    return [as_float]
class Slider(InputComponent):
"""
@ -306,6 +330,10 @@ class Slider(InputComponent):
"""
return scores
def embed(self, x):
    """
    Embeds the preprocessed slider value as a one-element vector.
    Parameters:
    x (Any): Value accepted by float()
    Returns:
    (List[float]): A single-entry list containing x converted to float
    """
    vector = []
    vector.append(float(x))
    return vector
class Checkbox(InputComponent):
"""
@ -353,6 +381,10 @@ class Checkbox(InputComponent):
else:
return None, scores[0]
def embed(self, x):
    """
    Embeds the preprocessed checkbox value as a one-element vector.
    Parameters:
    x (Any): Value accepted by float(); True maps to 1.0 and False to 0.0
    Returns:
    (List[float]): A single-entry list containing x converted to float
    """
    state = float(x)
    return [state]
class CheckboxGroup(InputComponent):
"""
@ -417,6 +449,15 @@ class CheckboxGroup(InputComponent):
final_scores.append(score_set)
return final_scores
def embed(self, x):
    """
    Embeds the selected checkboxes as a multi-hot vector with one entry per choice.
    Parameters:
    x (Any): Container of the selected items, tested with `in`: choice values when type is
        "value", choice positions when type is "index".
    Returns:
    (List[float]): 1.0 at every selected position and 0.0 elsewhere, in the order of self.choices
    Raises:
    ValueError: if the component's type is neither 'value' nor 'index'.
    """
    # Entries are converted to float so the vector matches the List[Float] contract and can be
    # subtracted as a numpy array (numpy rejects `-` on boolean arrays).
    if self.type == "value":
        return [float(choice in x) for choice in self.choices]
    elif self.type == "index":
        # The choices live on the instance; a bare `choices` name is undefined in this scope.
        return [float(index in x) for index in range(len(self.choices))]
    else:
        raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
class Radio(InputComponent):
"""
@ -469,6 +510,14 @@ class Radio(InputComponent):
scores.insert(self.choices.index(x), None)
return scores
def embed(self, x):
    """
    Embeds the selected radio option as a one-hot vector with one entry per choice.
    Parameters:
    x (Any): The selected item, compared with `==`: a choice value when type is "value",
        a choice position when type is "index".
    Returns:
    (List[float]): 1.0 at the position equal to x and 0.0 elsewhere, in the order of self.choices
    Raises:
    ValueError: if the component's type is neither 'value' nor 'index'.
    """
    # Entries are converted to float so the vector matches the List[Float] contract and can be
    # subtracted as a numpy array (numpy rejects `-` on boolean arrays).
    if self.type == "value":
        return [float(choice == x) for choice in self.choices]
    elif self.type == "index":
        # The choices live on the instance; a bare `choices` name is undefined in this scope.
        return [float(index == x) for index in range(len(self.choices))]
    else:
        raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
class Dropdown(InputComponent):
"""
@ -521,6 +570,14 @@ class Dropdown(InputComponent):
scores.insert(self.choices.index(x), None)
return scores
def embed(self, x):
    """
    Embeds the selected dropdown option as a one-hot vector with one entry per choice.
    Parameters:
    x (Any): The selected item, compared with `==`: a choice value when type is "value",
        a choice position when type is "index".
    Returns:
    (List[float]): 1.0 at the position equal to x and 0.0 elsewhere, in the order of self.choices
    Raises:
    ValueError: if the component's type is neither 'value' nor 'index'.
    """
    # Entries are converted to float so the vector matches the List[Float] contract and can be
    # subtracted as a numpy array (numpy rejects `-` on boolean arrays).
    if self.type == "value":
        return [float(choice == x) for choice in self.choices]
    elif self.type == "index":
        # The choices live on the instance; a bare `choices` name is undefined in this scope.
        return [float(index == x) for index in range(len(self.choices))]
    else:
        raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
class Image(InputComponent):
"""
@ -641,7 +698,17 @@ class Image(InputComponent):
return output_scores.tolist()
def embed(self, x):
    """
    Embeds an image as the flattened pixel array of a resized-and-cropped copy.
    Parameters:
    x (Any): Preprocessed image in the form selected by self.type: used as-is for "pil",
        passed to PIL.Image.fromarray for "numpy", passed to PIL.Image.open for "file".
    Returns:
    (numpy.ndarray): 1-D array obtained by flattening the resized image
    Raises:
    ValueError: if the component's type is not one of 'numpy', 'pil', 'file'.
    """
    # NOTE(review): the previous body began with `return x.flatten()`, which made everything
    # below unreachable; it is removed here so the type-aware resize path actually runs.
    # Fall back to a 100x100 target when the component has no configured shape.
    shape = (100, 100) if self.shape is None else self.shape
    if self.type == "pil":
        im = x
    elif self.type == "numpy":
        im = PIL.Image.fromarray(x)
    elif self.type == "file":
        im = PIL.Image.open(x)
    else:
        raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'numpy', 'pil', 'file'.")
    # Resizing to one common shape presumably keeps embeddings of different images the same
    # length; confirm against processing_utils.resize_and_crop.
    im = processing_utils.resize_and_crop(im, (shape[0], shape[1]))
    return np.asarray(im).flatten()
class Audio(InputComponent):
"""
@ -723,6 +790,9 @@ class Audio(InputComponent):
"""
return scores
def embed(self, x):
    """
    Embeddings are not available for audio inputs; calling this always raises.
    Raises:
    NotImplementedError: unconditionally.
    """
    message = "Audio doesn't currently support embeddings"
    raise NotImplementedError(message)
class File(InputComponent):
"""
@ -761,6 +831,8 @@ class File(InputComponent):
else:
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'file', 'bytes'.")
def embed(self, x):
    """
    Embeddings are not available for file inputs; calling this always raises.
    Raises:
    NotImplementedError: unconditionally.
    """
    message = "File doesn't currently support embeddings"
    raise NotImplementedError(message)
class Dataframe(InputComponent):
@ -853,6 +925,9 @@ class Dataframe(InputComponent):
"""
return np.array(scores).reshape((shape)).tolist()
def embed(self, x):
    """
    Embeddings are not available for dataframe inputs; calling this always raises.
    Raises:
    NotImplementedError: unconditionally.
    """
    message = "DataFrame doesn't currently support embeddings"
    raise NotImplementedError(message)
#######################
# DEPRECATED COMPONENTS

View File

@ -20,7 +20,7 @@ import sys
import csv
import logging
import gradio as gr
from gradio.similarity import calculate_similarity
from gradio.embeddings import calculate_similarity
from gradio.tunneling import create_tunnel
INITIAL_PORT_VALUE = int(os.getenv(

View File

@ -1,8 +0,0 @@
import numpy as np
def calculate_similarity(embedding1, embedding2):
    """
    Scores the similarity between two embeddings by taking the L2 distance
    Parameters:
    embedding1 (List[float] | numpy.ndarray): first embedding vector
    embedding2 (List[float] | numpy.ndarray): second embedding vector, same length as the first
    Returns:
    (float): Euclidean distance between the two vectors; smaller means more similar
    """
    # Cast to float before subtracting: unsigned-integer arrays (e.g. uint8 pixel data) wrap
    # around on subtraction, and numpy refuses the `-` operator on boolean arrays.
    first = np.asarray(embedding1, dtype=float)
    second = np.asarray(embedding2, dtype=float)
    return np.linalg.norm(first - second)