mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-27 02:30:17 +08:00
fix
This commit is contained in:
parent
a03f9a4706
commit
4debb0f829
@ -67,7 +67,7 @@ class AbstractInput(ABC):
|
||||
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, shape=(28, 28), invert_colors=True,
|
||||
def __init__(self, cast_to="numpy", shape=(28, 28), invert_colors=True,
|
||||
flatten=False, scale=1/255, shift=0,
|
||||
dtype='float64', sample_inputs=None, label=None):
|
||||
self.image_width = shape[0]
|
||||
@ -110,10 +110,10 @@ class Sketchpad(AbstractInput):
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
def __init__(self, image_width=224, image_height=224, num_channels=3, label=None):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
def __init__(self, shape=(224, 224), label=None):
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
self.num_channels = 3
|
||||
super().__init__(label)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
@ -132,8 +132,7 @@ class Webcam(AbstractInput):
|
||||
im = preprocessing_utils.decode_base64_to_image(inp)
|
||||
im = im.convert('RGB')
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
return np.array(im)
|
||||
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
@ -244,9 +243,11 @@ class Checkbox(AbstractInput):
|
||||
|
||||
|
||||
class Image(AbstractInput):
|
||||
def __init__(self, cast_to=None, shape=(224, 224), label=None):
|
||||
def __init__(self, cast_to=None, shape=(224, 224), image_mode='RGB', label=None):
|
||||
self.cast_to = cast_to
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
self.image_mode = image_mode
|
||||
super().__init__(label)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
@ -264,24 +265,36 @@ class Image(AbstractInput):
|
||||
**super().get_template_context()
|
||||
}
|
||||
|
||||
def cast_to_base64(self, inp):
|
||||
return inp
|
||||
|
||||
def cast_to_im(self, inp):
|
||||
return preprocessing_utils.decode_base64_to_image(inp)
|
||||
|
||||
def cast_to_numpy(self, inp):
|
||||
im = self.cast_to_im(inp)
|
||||
arr = np.array(im).flatten()
|
||||
return arr
|
||||
|
||||
def preprocess(self, inp):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
cast_to_type = {
|
||||
"base64": self.cast_to_base64,
|
||||
"numpy": self.cast_to_numpy,
|
||||
"pillow": self.cast_to_im
|
||||
}
|
||||
if self.cast_to:
|
||||
return cast_to_type[self.cast_to](inp)
|
||||
|
||||
im = preprocessing_utils.decode_base64_to_image(inp)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
im = im.convert(self.image_mode)
|
||||
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
im = np.array(im).flatten()
|
||||
im = im * self.scale + self.shift
|
||||
if self.num_channels is None:
|
||||
array = im.reshape(self.image_width, self.image_height)
|
||||
else:
|
||||
array = im.reshape(self.image_width, self.image_height, \
|
||||
self.num_channels)
|
||||
return array
|
||||
return np.array(im)
|
||||
|
||||
def process_example(self, example):
|
||||
if os.path.exists(example):
|
||||
@ -305,4 +318,4 @@ class Microphone(AbstractInput):
|
||||
shortcuts = {}
|
||||
for cls in AbstractInput.__subclasses__():
|
||||
for shortcut, parameters in cls.get_shortcut_implementations().items():
|
||||
shortcuts[shortcut] = cls(**parameters)
|
||||
shortcuts[shortcut] = cls(**parameters)
|
Loading…
Reference in New Issue
Block a user