mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
added embeddings for almost all interface types
This commit is contained in:
parent
02ab39b3ad
commit
3e740b9076
18
gradio/embeddings.py
Normal file
18
gradio/embeddings.py
Normal file
File diff suppressed because one or more lines are too long
@ -10,7 +10,7 @@ import os
|
||||
import time
|
||||
import warnings
|
||||
from gradio.component import Component
|
||||
|
||||
from gradio.embeddings import embed_text
|
||||
import base64
|
||||
import numpy as np
|
||||
import PIL
|
||||
@ -72,6 +72,16 @@ class InputComponent(Component):
|
||||
Returns:
|
||||
(List[Any]): Arrangement of interpretation scores for interfaces to render.
|
||||
'''
|
||||
pass
|
||||
|
||||
def embed(self, x):
|
||||
"""
|
||||
Return a default embedding for the *preprocessed* input to the interface. Used to compute similar inputs.
|
||||
x (Any): Input to interface
|
||||
Returns:
|
||||
(List[Float]): An embedding vector as a list or numpy array of floats
|
||||
"""
|
||||
pass
|
||||
|
||||
class Textbox(InputComponent):
|
||||
"""
|
||||
@ -170,6 +180,17 @@ class Textbox(InputComponent):
|
||||
result.append((self.interpretation_separator, 0))
|
||||
return result
|
||||
|
||||
def embed(self, x):
|
||||
"""
|
||||
Embeds an arbitrary text based on word frequency
|
||||
"""
|
||||
if self.type == "str":
|
||||
return embed_text(x)
|
||||
elif self.type == "number":
|
||||
return [float(x)]
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'str', 'number'.")
|
||||
|
||||
|
||||
class Number(InputComponent):
|
||||
"""
|
||||
@ -238,6 +259,9 @@ class Number(InputComponent):
|
||||
interpretation.insert(int(len(interpretation) / 2), [x, None])
|
||||
return interpretation
|
||||
|
||||
def embed(self, x):
|
||||
return [float(x)]
|
||||
|
||||
|
||||
class Slider(InputComponent):
|
||||
"""
|
||||
@ -306,6 +330,10 @@ class Slider(InputComponent):
|
||||
"""
|
||||
return scores
|
||||
|
||||
def embed(self, x):
|
||||
return [float(x)]
|
||||
|
||||
|
||||
|
||||
class Checkbox(InputComponent):
|
||||
"""
|
||||
@ -353,6 +381,10 @@ class Checkbox(InputComponent):
|
||||
else:
|
||||
return None, scores[0]
|
||||
|
||||
def embed(self, x):
|
||||
return [float(x)]
|
||||
|
||||
|
||||
|
||||
class CheckboxGroup(InputComponent):
|
||||
"""
|
||||
@ -417,6 +449,15 @@ class CheckboxGroup(InputComponent):
|
||||
final_scores.append(score_set)
|
||||
return final_scores
|
||||
|
||||
def embed(self, x):
|
||||
if self.type == "value":
|
||||
return [choice in x for choice in self.choices]
|
||||
elif self.type == "index":
|
||||
return [index in x for index in range(len(choices))]
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
|
||||
|
||||
class Radio(InputComponent):
|
||||
"""
|
||||
@ -469,6 +510,14 @@ class Radio(InputComponent):
|
||||
scores.insert(self.choices.index(x), None)
|
||||
return scores
|
||||
|
||||
def embed(self, x):
|
||||
if self.type == "value":
|
||||
return [choice==x for choice in self.choices]
|
||||
elif self.type == "index":
|
||||
return [index==x for index in range(len(choices))]
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
|
||||
class Dropdown(InputComponent):
|
||||
"""
|
||||
@ -521,6 +570,14 @@ class Dropdown(InputComponent):
|
||||
scores.insert(self.choices.index(x), None)
|
||||
return scores
|
||||
|
||||
def embed(self, x):
|
||||
if self.type == "value":
|
||||
return [choice==x for choice in self.choices]
|
||||
elif self.type == "index":
|
||||
return [index==x for index in range(len(choices))]
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.")
|
||||
|
||||
|
||||
class Image(InputComponent):
|
||||
"""
|
||||
@ -641,7 +698,17 @@ class Image(InputComponent):
|
||||
return output_scores.tolist()
|
||||
|
||||
def embed(self, x):
|
||||
return x.flatten()
|
||||
shape = (100, 100) if self.shape is None else self.shape
|
||||
if self.type == "pil":
|
||||
im = x
|
||||
elif self.type == "numpy":
|
||||
im = PIL.Image.fromarray(x)
|
||||
elif self.type == "file":
|
||||
im = PIL.Image.open(x)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'numpy', 'pil', 'file'.")
|
||||
im = processing_utils.resize_and_crop(im, (shape[0], shape[1]))
|
||||
return np.asarray(im).flatten()
|
||||
|
||||
class Audio(InputComponent):
|
||||
"""
|
||||
@ -723,6 +790,9 @@ class Audio(InputComponent):
|
||||
"""
|
||||
return scores
|
||||
|
||||
def embed(self, x):
|
||||
raise NotImplementedError("Audio doesn't currently support embeddings")
|
||||
|
||||
|
||||
class File(InputComponent):
|
||||
"""
|
||||
@ -761,6 +831,8 @@ class File(InputComponent):
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'file', 'bytes'.")
|
||||
|
||||
def embed(self, x):
|
||||
raise NotImplementedError("File doesn't currently support embeddings")
|
||||
|
||||
|
||||
class Dataframe(InputComponent):
|
||||
@ -853,6 +925,9 @@ class Dataframe(InputComponent):
|
||||
"""
|
||||
return np.array(scores).reshape((shape)).tolist()
|
||||
|
||||
def embed(self, x):
|
||||
raise NotImplementedError("DataFrame doesn't currently support embeddings")
|
||||
|
||||
|
||||
#######################
|
||||
# DEPRECATED COMPONENTS
|
||||
|
@ -20,7 +20,7 @@ import sys
|
||||
import csv
|
||||
import logging
|
||||
import gradio as gr
|
||||
from gradio.similarity import calculate_similarity
|
||||
from gradio.embeddings import calculate_similarity
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
INITIAL_PORT_VALUE = int(os.getenv(
|
||||
|
@ -1,8 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
def calculate_similarity(embedding1, embedding2):
|
||||
"""
|
||||
Scores the similarity between two embeddings by taking the L2 distance
|
||||
"""
|
||||
return np.linalg.norm(np.array(embedding1) - np.array(embedding2))
|
||||
|
Loading…
Reference in New Issue
Block a user