mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-12 10:34:32 +08:00
Blocks-Components
- combine Image into one
This commit is contained in:
parent
c32d3aafa4
commit
e75dda2f6e
@ -4,12 +4,15 @@ import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
from gradio import processing_utils
|
||||
from gradio import processing_utils, test_data
|
||||
from gradio.blocks import Block
|
||||
|
||||
|
||||
@ -26,7 +29,10 @@ class Component(Block):
|
||||
**kwargs,
|
||||
):
|
||||
if "optional" in kwargs:
|
||||
warnings.warn("Usage of optional is deprecated, and it has no effect")
|
||||
warnings.warn(
|
||||
"Usage of optional is deprecated, and it has no effect",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.label = label
|
||||
self.requires_permissions = requires_permissions
|
||||
|
||||
@ -899,3 +905,301 @@ class Dropdown(Radio):
|
||||
super().__init__(
|
||||
default=default, choices=choices, type=type, label=label, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class Image(Component):
    """
    Component creates an image component with input and output capabilities.

    Input type: Union[numpy.array, PIL.Image, file-object]
    Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
    Demos: image_classifier, image_mod, webcam, digit_classifier
    """

    def __init__(
        self,
        default=None,
        *,
        shape: Optional[Tuple[int, int]] = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "upload",
        tool: str = "editor",
        type: str = "numpy",
        label: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters:
        default: accepted for signature compatibility; this constructor does not store or use it.
        shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.
        image_mode (str): "RGB" if color, or "L" if black and white.
        invert_colors (bool): whether to invert the image as a preprocessing step.
        source (str): Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
        tool (str): Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool.
        type (str): #TODO:(Faruk) combine the descriptions below
            input: Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
            output: Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
        label (str): component name in interface.
        """
        self.type = type
        self.shape = shape
        self.image_mode = image_mode
        self.source = source
        # Only the webcam source needs a browser permission prompt.
        requires_permissions = source == "webcam"
        self.tool = tool
        self.invert_colors = invert_colors
        self.test_input = test_data.BASE64_IMAGE
        self.interpret_by_tokens = True
        super().__init__(
            label=label, requires_permissions=requires_permissions, **kwargs
        )

    @classmethod
    def get_shortcut_implementations(cls):
        """
        Returns:
        (Dict[str, Dict]): shortcut name -> constructor keyword arguments.
        """
        # The class used to define this method twice (once for the input
        # shortcuts, once for the output shortcuts); the second definition
        # silently replaced the first and dropped "webcam" and "sketchpad".
        # Both sets are merged here.
        return {
            "image": {},
            "webcam": {"source": "webcam"},
            "sketchpad": {
                "image_mode": "L",
                "source": "canvas",
                "shape": (28, 28),
                "invert_colors": True,
            },
            "plot": {"type": "plot"},
            "pil": {"type": "pil"},
        }

    def get_template_context(self):
        """
        Returns:
        (Dict): image-specific settings merged with the parent's template context.
        """
        return {
            "image_mode": self.image_mode,
            "shape": self.shape,
            "source": self.source,
            "tool": self.tool,
            # This constructor never assigns `optional`; fall back to False
            # instead of raising AttributeError.
            # NOTE(review): confirm whether a parent class sets it.
            "optional": getattr(self, "optional", False),
            **super().get_template_context(),
        }

    def preprocess(
        self, x: Optional[str]
    ) -> "np.ndarray | PIL.Image.Image | str | None":
        """
        Parameters:
        x (str): base64 url data
        Returns:
        (Union[numpy.array, PIL.Image, filepath]): image in requested format; a temporary file object for the deprecated "file" type; None if x is None.
        Raises:
        ValueError: if self.type is not one of 'numpy', 'pil', 'file', 'filepath'.
        """
        if x is None:
            return x
        im = processing_utils.decode_base64_to_image(x)
        # Read the format before convert(), which returns a new image object.
        fmt = im.format
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            im = im.convert(self.image_mode)
        if self.shape is not None:
            im = processing_utils.resize_and_crop(im, self.shape)
        if self.invert_colors:
            # A bare `import PIL` does not load the ImageOps submodule.
            from PIL import ImageOps

            im = ImageOps.invert(im)
        if self.type == "pil":
            return im
        elif self.type == "numpy":
            return np.array(im)
        elif self.type == "file" or self.type == "filepath":
            file_obj = tempfile.NamedTemporaryFile(
                delete=False,
                suffix=("." + fmt.lower() if fmt is not None else ".png"),
            )
            im.save(file_obj.name)
            if self.type == "file":
                warnings.warn(
                    "The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
                    DeprecationWarning,
                )
                return file_obj
            # Only the path is handed out, so release the open handle.
            file_obj.close()
            return file_obj.name
        else:
            raise ValueError(
                "Unknown type: "
                + str(self.type)
                + ". Please choose from: 'numpy', 'pil', 'filepath'."
            )

    def preprocess_example(self, x):
        """
        Parameters:
        x: example file, passed to processing_utils.encode_file_to_base64
        Returns:
        the encoded example
        """
        return processing_utils.encode_file_to_base64(x)

    def serialize(self, x, called_directly=False):
        """
        Parameters:
        x: value in the format selected by self.type
        called_directly (bool): if True, x is assumed to be a URL or filepath
        Returns:
        base64 encoding produced by processing_utils.encode_url_or_file_to_base64
        Raises:
        ValueError: if self.type is not one of 'numpy', 'pil', 'file', 'filepath'.
        """
        # if called directly, can assume it's a URL or filepath
        if self.type == "filepath" or called_directly:
            return processing_utils.encode_url_or_file_to_base64(x)
        elif self.type == "file":
            return processing_utils.encode_url_or_file_to_base64(x.name)
        elif self.type in ("numpy", "pil"):
            if self.type == "numpy":
                # A bare `import PIL` does not load the Image submodule.
                from PIL import Image as PILImage

                x = PILImage.fromarray(np.uint8(x)).convert("RGB")
            fmt = x.format
            file_obj = tempfile.NamedTemporaryFile(
                delete=False,
                suffix=("." + fmt.lower() if fmt is not None else ".png"),
            )
            x.save(file_obj.name)
            # The file is re-read by path below; release the open handle.
            file_obj.close()
            return processing_utils.encode_url_or_file_to_base64(file_obj.name)
        else:
            raise ValueError(
                "Unknown type: "
                + str(self.type)
                + ". Please choose from: 'numpy', 'pil', 'filepath'."
            )

    def set_interpret_parameters(self, segments=16):
        """
        Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
        Parameters:
        segments (int): Number of interpretation segments to split image into.
        Returns:
        self, so that calls can be chained.
        """
        self.interpretation_segments = segments
        return self

    def _segment_by_slic(self, x):
        """
        Helper method that segments an image into superpixels using slic.
        Parameters:
        x: base64 representation of an image
        Returns:
        (segments_slic, resized_and_cropped_image): the slic label array and the image array it was computed from.
        Raises:
        ValueError: if scikit-image is not installed.
        """
        x = processing_utils.decode_base64_to_image(x)
        if self.shape is not None:
            x = processing_utils.resize_and_crop(x, self.shape)
        resized_and_cropped_image = np.array(x)
        try:
            from skimage.segmentation import slic
        except ImportError as err:  # ModuleNotFoundError is a subclass
            raise ValueError(
                "Error: running this interpretation for images requires scikit-image, please install it first."
            ) from err
        try:
            segments_slic = slic(
                resized_and_cropped_image,
                self.interpretation_segments,
                compactness=10,
                sigma=1,
                start_label=1,
            )
        except TypeError:  # For skimage 0.16 and older
            segments_slic = slic(
                resized_and_cropped_image,
                self.interpretation_segments,
                compactness=10,
                sigma=1,
            )
        return segments_slic, resized_and_cropped_image

    def tokenize(self, x):
        """
        Segments image into tokens, masks, and leave-one-out-tokens
        Parameters:
        x: base64 representation of an image
        Returns:
        tokens: list of tokens, used by the get_masked_input() method
        leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
        masks: list of masks, used by the get_interpretation_neighbors() method
        """
        segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
        tokens, masks, leave_one_out_tokens = [], [], []
        # Left-out segments are filled with the mean over the two spatial axes.
        replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
        for segment_value in np.unique(segments_slic):
            mask = segments_slic == segment_value
            image_screen = np.copy(resized_and_cropped_image)
            image_screen[mask] = replace_color
            leave_one_out_tokens.append(
                processing_utils.encode_array_to_base64(image_screen)
            )
            # A token keeps only this segment; everything else is zeroed.
            token = np.copy(resized_and_cropped_image)
            token[~mask] = 0
            tokens.append(token)
            masks.append(mask)
        return tokens, leave_one_out_tokens, masks

    def get_masked_inputs(self, tokens, binary_mask_matrix):
        """
        Parameters:
        tokens: list of token arrays as returned by tokenize()
        binary_mask_matrix: iterable of rows; entry b of a row selects (int(b) == 1) or drops (0) the matching token
        Returns:
        (List): one base64-encoded image per row, built by summing the selected tokens.
        """
        masked_inputs = []
        for binary_mask_vector in binary_mask_matrix:
            masked_input = np.zeros_like(tokens[0], dtype=int)
            for token, b in zip(tokens, binary_mask_vector):
                masked_input = masked_input + token * int(b)
            masked_inputs.append(
                processing_utils.encode_array_to_base64(masked_input)
            )
        return masked_inputs

    def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
        """
        Parameters:
        x: base64 representation of an image
        neighbors: unused by this implementation
        scores: one score per mask
        masks: masks as returned by tokenize()
        tokens: unused by this implementation
        Returns:
        (List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
        """
        x = processing_utils.decode_base64_to_image(x)
        if self.shape is not None:
            x = processing_utils.resize_and_crop(x, self.shape)
        x = np.array(x)
        output_scores = np.zeros((x.shape[0], x.shape[1]))

        for score, mask in zip(scores, masks):
            output_scores += score * mask

        max_val, min_val = np.max(output_scores), np.min(output_scores)
        # Min-max normalise; skip when the range is empty, which would
        # otherwise divide by zero and fill the result with NaN.
        if max_val > 0 and max_val > min_val:
            output_scores = (output_scores - min_val) / (max_val - min_val)
        return output_scores.tolist()

    def save_flagged(self, dir, label, data, encryption_key):
        """
        Returns: (str) path to image file
        """
        return self.save_flagged_file(dir, label, data, encryption_key)

    def generate_sample(self):
        """
        Returns: the bundled base64 test image (test_data.BASE64_IMAGE)
        """
        return test_data.BASE64_IMAGE

    # Output functions
    def postprocess(self, y):
        """
        Parameters:
        y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
        Returns:
        (str): base64 url data
        Raises:
        ValueError: if the type of y cannot be detected ("auto") or self.type is not supported for output.
        """
        if self.type == "auto":
            # A bare `import PIL` does not load the Image submodule.
            from PIL import Image as PILImage

            if isinstance(y, np.ndarray):
                dtype = "numpy"
            elif isinstance(y, PILImage.Image):
                dtype = "pil"
            elif isinstance(y, str):
                dtype = "file"
            elif isinstance(y, ModuleType):
                dtype = "plot"
            else:
                raise ValueError(
                    "Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'."
                )
        else:
            dtype = self.type
        if dtype in ["numpy", "pil"]:
            if dtype == "pil":
                y = np.array(y)
            out_y = processing_utils.encode_array_to_base64(y)
        elif dtype == "file":
            out_y = processing_utils.encode_url_or_file_to_base64(y)
        elif dtype == "plot":
            out_y = processing_utils.encode_plot_to_base64(y)
        else:
            # str() keeps this a ValueError even when type is not a string.
            raise ValueError(
                "Unknown type: "
                + str(dtype)
                + ". Please choose from: 'numpy', 'pil', 'file', 'plot'."
            )
        return out_y

    def deserialize(self, x):
        """
        Parameters:
        x: base64 data, decoded with processing_utils.decode_base64_to_file
        Returns:
        the `name` attribute of the decoded file object
        """
        y = processing_utils.decode_base64_to_file(x).name
        return y

    def restore_flagged(self, dir, data, encryption_key):
        """
        Returns: the "data" entry of the result of self.restore_flagged_file
        """
        return self.restore_flagged_file(dir, data, encryption_key)["data"]
|
||||
|
278
gradio/inputs.py
278
gradio/inputs.py
@ -23,6 +23,7 @@ from gradio.components import (
|
||||
CheckboxGroup,
|
||||
Component,
|
||||
Dropdown,
|
||||
Image,
|
||||
Number,
|
||||
Radio,
|
||||
Slider,
|
||||
@ -248,6 +249,47 @@ class Dropdown(Dropdown):
|
||||
)
|
||||
|
||||
|
||||
class Image(Image):
    """
    Component creates an image upload box with editing capabilities.
    Input type: Union[numpy.array, PIL.Image, file-object]
    Demos: image_classifier, image_mod, webcam, digit_classifier
    """

    def __init__(
        self,
        shape: Tuple[int, int] = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "upload",
        tool: str = "editor",
        type: str = "numpy",
        label: str = None,
        optional: bool = False,
    ):
        """
        Parameters:
        shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.
        image_mode (str): "RGB" if color, or "L" if black and white.
        invert_colors (bool): whether to invert the image as a preprocessing step.
        source (str): Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
        tool (str): Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool.
        type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
        label (str): component name in interface.
        optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
        """
        # Every argument is handed to the parent constructor by keyword,
        # unchanged; this subclass adds no state of its own.
        forwarded = {
            "shape": shape,
            "image_mode": image_mode,
            "invert_colors": invert_colors,
            "source": source,
            "tool": tool,
            "type": type,
            "label": label,
            "optional": optional,
        }
        super().__init__(**forwarded)
|
||||
|
||||
|
||||
class InputComponent(Component):
|
||||
"""
|
||||
Input Component. All input components subclass this.
|
||||
@ -331,242 +373,6 @@ class InputComponent(Component):
|
||||
}
|
||||
|
||||
|
||||
class Image(InputComponent):
    """
    Component creates an image upload box with editing capabilities.
    Input type: Union[numpy.array, PIL.Image, file-object]
    Demos: image_classifier, image_mod, webcam, digit_classifier
    """

    def __init__(
        self,
        shape: Optional[Tuple[int, int]] = None,
        image_mode: str = "RGB",
        invert_colors: bool = False,
        source: str = "upload",
        tool: str = "editor",
        type: str = "numpy",
        label: Optional[str] = None,
        optional: bool = False,
    ):
        """
        Parameters:
        shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.
        image_mode (str): "RGB" if color, or "L" if black and white.
        invert_colors (bool): whether to invert the image as a preprocessing step.
        source (str): Source of image. "upload" creates a box where user can drop an image file, "webcam" allows user to take snapshot from their webcam, "canvas" defaults to a white image that can be edited and drawn upon with tools.
        tool (str): Tools used for editing. "editor" allows a full screen editor, "select" provides a cropping and zoom tool.
        type (str): Type of value to be returned by component. "numpy" returns a numpy array with shape (width, height, 3) and values from 0 to 255, "pil" returns a PIL image object, "file" returns a temporary file object whose path can be retrieved by file_obj.name, "filepath" returns the path directly.
        label (str): component name in interface.
        optional (bool): If True, the interface can be submitted with no uploaded image, in which case the input value is None.
        """
        self.shape = shape
        self.image_mode = image_mode
        self.source = source
        # Only the webcam source needs a browser permission prompt.
        requires_permissions = source == "webcam"
        self.tool = tool
        self.type = type
        self.invert_colors = invert_colors
        self.test_input = test_data.BASE64_IMAGE
        self.interpret_by_tokens = True
        super().__init__(label, requires_permissions, optional=optional)

    @classmethod
    def get_shortcut_implementations(cls):
        """
        Returns:
        (Dict[str, Dict]): shortcut name -> constructor keyword arguments.
        """
        return {
            "image": {},
            "webcam": {"source": "webcam"},
            "sketchpad": {
                "image_mode": "L",
                "source": "canvas",
                "shape": (28, 28),
                "invert_colors": True,
            },
        }

    def get_template_context(self):
        """
        Returns:
        (Dict): image-specific settings merged with the parent's template context.
        """
        return {
            "image_mode": self.image_mode,
            "shape": self.shape,
            "source": self.source,
            "tool": self.tool,
            "optional": self.optional,
            **super().get_template_context(),
        }

    def preprocess(
        self, x: Optional[str]
    ) -> "np.ndarray | PIL.Image.Image | str | None":
        """
        Parameters:
        x (str): base64 url data
        Returns:
        (Union[numpy.array, PIL.Image, filepath]): image in requested format; a temporary file object for the deprecated "file" type; None if x is None.
        Raises:
        ValueError: if self.type is not one of 'numpy', 'pil', 'file', 'filepath'.
        """
        if x is None:
            return x
        im = processing_utils.decode_base64_to_image(x)
        # Read the format before convert(), which returns a new image object.
        fmt = im.format
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            im = im.convert(self.image_mode)
        if self.shape is not None:
            im = processing_utils.resize_and_crop(im, self.shape)
        if self.invert_colors:
            # A bare `import PIL` does not load the ImageOps submodule.
            from PIL import ImageOps

            im = ImageOps.invert(im)
        if self.type == "pil":
            return im
        elif self.type == "numpy":
            return np.array(im)
        elif self.type == "file" or self.type == "filepath":
            file_obj = tempfile.NamedTemporaryFile(
                delete=False, suffix=("." + fmt.lower() if fmt is not None else ".png")
            )
            im.save(file_obj.name)
            if self.type == "file":
                warnings.warn(
                    "The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
                    DeprecationWarning,
                )
                return file_obj
            # Only the path is handed out, so release the open handle.
            file_obj.close()
            return file_obj.name
        else:
            raise ValueError(
                "Unknown type: "
                + str(self.type)
                + ". Please choose from: 'numpy', 'pil', 'filepath'."
            )

    def preprocess_example(self, x):
        """
        Parameters:
        x: example file, passed to processing_utils.encode_file_to_base64
        Returns:
        the encoded example
        """
        return processing_utils.encode_file_to_base64(x)

    def serialize(self, x, called_directly=False):
        """
        Parameters:
        x: value in the format selected by self.type
        called_directly (bool): if True, x is assumed to be a URL or filepath
        Returns:
        base64 encoding produced by processing_utils.encode_url_or_file_to_base64
        Raises:
        ValueError: if self.type is not one of 'numpy', 'pil', 'file', 'filepath'.
        """
        # if called directly, can assume it's a URL or filepath
        if self.type == "filepath" or called_directly:
            return processing_utils.encode_url_or_file_to_base64(x)
        elif self.type == "file":
            return processing_utils.encode_url_or_file_to_base64(x.name)
        elif self.type in ("numpy", "pil"):
            if self.type == "numpy":
                # A bare `import PIL` does not load the Image submodule.
                from PIL import Image as PILImage

                x = PILImage.fromarray(np.uint8(x)).convert("RGB")
            fmt = x.format
            file_obj = tempfile.NamedTemporaryFile(
                delete=False, suffix=("." + fmt.lower() if fmt is not None else ".png")
            )
            x.save(file_obj.name)
            # The file is re-read by path below; release the open handle.
            file_obj.close()
            return processing_utils.encode_url_or_file_to_base64(file_obj.name)
        else:
            raise ValueError(
                "Unknown type: "
                + str(self.type)
                + ". Please choose from: 'numpy', 'pil', 'filepath'."
            )

    def set_interpret_parameters(self, segments=16):
        """
        Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
        Parameters:
        segments (int): Number of interpretation segments to split image into.
        Returns:
        self, so that calls can be chained.
        """
        self.interpretation_segments = segments
        return self

    def _segment_by_slic(self, x):
        """
        Helper method that segments an image into superpixels using slic.
        Parameters:
        x: base64 representation of an image
        Returns:
        (segments_slic, resized_and_cropped_image): the slic label array and the image array it was computed from.
        Raises:
        ValueError: if scikit-image is not installed.
        """
        x = processing_utils.decode_base64_to_image(x)
        if self.shape is not None:
            x = processing_utils.resize_and_crop(x, self.shape)
        resized_and_cropped_image = np.array(x)
        try:
            from skimage.segmentation import slic
        except ImportError as err:  # ModuleNotFoundError is a subclass
            raise ValueError(
                "Error: running this interpretation for images requires scikit-image, please install it first."
            ) from err
        try:
            segments_slic = slic(
                resized_and_cropped_image,
                self.interpretation_segments,
                compactness=10,
                sigma=1,
                start_label=1,
            )
        except TypeError:  # For skimage 0.16 and older
            segments_slic = slic(
                resized_and_cropped_image,
                self.interpretation_segments,
                compactness=10,
                sigma=1,
            )
        return segments_slic, resized_and_cropped_image

    def tokenize(self, x):
        """
        Segments image into tokens, masks, and leave-one-out-tokens
        Parameters:
        x: base64 representation of an image
        Returns:
        tokens: list of tokens, used by the get_masked_input() method
        leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
        masks: list of masks, used by the get_interpretation_neighbors() method
        """
        segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
        tokens, masks, leave_one_out_tokens = [], [], []
        # Left-out segments are filled with the mean over the two spatial axes.
        replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
        for segment_value in np.unique(segments_slic):
            mask = segments_slic == segment_value
            image_screen = np.copy(resized_and_cropped_image)
            image_screen[mask] = replace_color
            leave_one_out_tokens.append(
                processing_utils.encode_array_to_base64(image_screen)
            )
            # A token keeps only this segment; everything else is zeroed.
            token = np.copy(resized_and_cropped_image)
            token[~mask] = 0
            tokens.append(token)
            masks.append(mask)
        return tokens, leave_one_out_tokens, masks

    def get_masked_inputs(self, tokens, binary_mask_matrix):
        """
        Parameters:
        tokens: list of token arrays as returned by tokenize()
        binary_mask_matrix: iterable of rows; entry b of a row selects (int(b) == 1) or drops (0) the matching token
        Returns:
        (List): one base64-encoded image per row, built by summing the selected tokens.
        """
        masked_inputs = []
        for binary_mask_vector in binary_mask_matrix:
            masked_input = np.zeros_like(tokens[0], dtype=int)
            for token, b in zip(tokens, binary_mask_vector):
                masked_input = masked_input + token * int(b)
            masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
        return masked_inputs

    def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
        """
        Parameters:
        x: base64 representation of an image
        neighbors: unused by this implementation
        scores: one score per mask
        masks: masks as returned by tokenize()
        tokens: unused by this implementation
        Returns:
        (List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
        """
        x = processing_utils.decode_base64_to_image(x)
        if self.shape is not None:
            x = processing_utils.resize_and_crop(x, self.shape)
        x = np.array(x)
        output_scores = np.zeros((x.shape[0], x.shape[1]))

        for score, mask in zip(scores, masks):
            output_scores += score * mask

        max_val, min_val = np.max(output_scores), np.min(output_scores)
        # Min-max normalise; skip when the range is empty, which would
        # otherwise divide by zero and fill the result with NaN.
        if max_val > 0 and max_val > min_val:
            output_scores = (output_scores - min_val) / (max_val - min_val)
        return output_scores.tolist()

    def save_flagged(self, dir, label, data, encryption_key):
        """
        Returns: (str) path to image file
        """
        return self.save_flagged_file(dir, label, data, encryption_key)

    def generate_sample(self):
        """
        Returns: the bundled base64 test image (test_data.BASE64_IMAGE)
        """
        return test_data.BASE64_IMAGE
|
||||
|
||||
|
||||
class Video(InputComponent):
|
||||
"""
|
||||
Component creates a video file upload that is converted to a file path.
|
||||
|
@ -21,7 +21,7 @@ import PIL
|
||||
from ffmpy import FFmpeg
|
||||
|
||||
from gradio import processing_utils
|
||||
from gradio.components import Component, Textbox
|
||||
from gradio.components import Component, Image, Textbox
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
|
||||
from gradio import Interface
|
||||
@ -42,6 +42,33 @@ class Textbox(Textbox):
|
||||
super().__init__(type=type, label=label)
|
||||
|
||||
|
||||
class Image(Image):
    """
    Component displays an output image.
    Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
    Demos: image_mod, webcam
    """

    def __init__(
        self, type: str = "auto", plot: bool = False, label: Optional[str] = None
    ):
        """
        Parameters:
        type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
        plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
        label (str): component name in interface.
        """
        if plot:
            warnings.warn(
                "The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.",
                DeprecationWarning,
            )
            # The parent constructor assigns self.type from the `type` it
            # receives, so the override must be applied to the forwarded
            # value; setting self.type here first would be overwritten.
            type = "plot"
        # NOTE(review): `plot` is still forwarded as before; confirm the
        # parent constructor's **kwargs handling tolerates it.
        super().__init__(label=label, type=type, plot=plot)
|
||||
|
||||
|
||||
class OutputComponent(Component):
|
||||
"""
|
||||
Output Component. All output components subclass this.
|
||||
@ -168,85 +195,6 @@ class Label(OutputComponent):
|
||||
return data
|
||||
|
||||
|
||||
class Image(OutputComponent):
    """
    Component displays an output image.
    Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
    Demos: image_mod, webcam
    """

    def __init__(
        self, type: str = "auto", plot: bool = False, label: Optional[str] = None
    ):
        """
        Parameters:
        type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image or a remote URL, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
        plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
        label (str): component name in interface.
        """
        if plot:
            warnings.warn(
                "The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.",
                DeprecationWarning,
            )
            self.type = "plot"
        else:
            self.type = type
        super().__init__(label)

    @classmethod
    def get_shortcut_implementations(cls):
        """
        Returns:
        (Dict[str, Dict]): shortcut name -> constructor keyword arguments.
        """
        return {"image": {}, "plot": {"type": "plot"}, "pil": {"type": "pil"}}

    def postprocess(self, y):
        """
        Parameters:
        y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
        Returns:
        (str): base64 url data
        Raises:
        ValueError: if the type of y cannot be detected ("auto") or self.type is not one of 'numpy', 'pil', 'file', 'plot'.
        """
        if self.type == "auto":
            # A bare `import PIL` does not load the Image submodule.
            from PIL import Image as PILImage

            if isinstance(y, np.ndarray):
                dtype = "numpy"
            elif isinstance(y, PILImage.Image):
                dtype = "pil"
            elif isinstance(y, str):
                dtype = "file"
            elif isinstance(y, ModuleType):
                dtype = "plot"
            else:
                raise ValueError(
                    "Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'."
                )
        else:
            dtype = self.type
        if dtype in ["numpy", "pil"]:
            if dtype == "pil":
                y = np.array(y)
            out_y = processing_utils.encode_array_to_base64(y)
        elif dtype == "file":
            out_y = processing_utils.encode_url_or_file_to_base64(y)
        elif dtype == "plot":
            out_y = processing_utils.encode_plot_to_base64(y)
        else:
            # str() keeps this a ValueError even when type is not a string.
            raise ValueError(
                "Unknown type: "
                + str(dtype)
                + ". Please choose from: 'numpy', 'pil', 'file', 'plot'."
            )
        return out_y

    def deserialize(self, x):
        """
        Parameters:
        x: base64 data, decoded with processing_utils.decode_base64_to_file
        Returns:
        the `name` attribute of the decoded file object
        """
        y = processing_utils.decode_base64_to_file(x).name
        return y

    def save_flagged(self, dir, label, data, encryption_key):
        """
        Returns: the result of self.save_flagged_file for the given arguments
        """
        return self.save_flagged_file(dir, label, data, encryption_key)

    def restore_flagged(self, dir, data, encryption_key):
        """
        Returns: the "data" entry of the result of self.restore_flagged_file
        """
        return self.restore_flagged_file(dir, data, encryption_key)["data"]
|
||||
|
||||
|
||||
class Video(OutputComponent):
|
||||
"""
|
||||
Used for video output.
|
||||
|
@ -4,6 +4,7 @@ from gradio.components import Component
|
||||
|
||||
# TODO: (faruk) Remove this file in version 3.0
|
||||
|
||||
|
||||
class StaticComponent(Component):
|
||||
def __init__(self, label: str):
|
||||
self.component_type = "static"
|
||||
|
Loading…
Reference in New Issue
Block a user