Blocks-Components

- fixes
- format
This commit is contained in:
Ömer Faruk Özdemir 2022-03-15 12:44:02 +03:00
parent 9e45418227
commit 7fa8e45b67
3 changed files with 251 additions and 252 deletions

View File

@ -409,7 +409,7 @@ class Number(Component):
return None
return float(x)
def preprocess_example(self, x: float | None) -> float:
def preprocess_example(self, x: float | None) -> float | None:
"""
Returns:
(float): Number representing function input
@ -508,7 +508,7 @@ class Slider(Component):
if step is None:
difference = maximum - minimum
power = math.floor(math.log10(difference) - 2)
step = 10 ** power
step = 10**power
self.step = step
self.default = minimum if default is None else default
self.test_input = self.default
@ -965,253 +965,252 @@ class Image(Component):
label=label, requires_permissions=requires_permissions, **kwargs
)
@classmethod
def get_shortcut_implementations(cls):
return {
"image": {},
"webcam": {"source": "webcam"},
"sketchpad": {
"image_mode": "L",
"source": "canvas",
"shape": (28, 28),
"invert_colors": True,
},
"plot": {"type": "plot"},
"pil": {"type": "pil"},
}
@classmethod
def get_shortcut_implementations(cls):
return {
"image": {},
"webcam": {"source": "webcam"},
"sketchpad": {
"image_mode": "L",
"source": "canvas",
"shape": (28, 28),
"invert_colors": True,
},
"plot": {"type": "plot"},
"pil": {"type": "pil"},
}
def get_template_context(self):
return {
"image_mode": self.image_mode,
"shape": self.shape,
"source": self.source,
"tool": self.tool,
"optional": self.optional,
**super().get_template_context(),
}
def get_template_context(self):
return {
"image_mode": self.image_mode,
"shape": self.shape,
"source": self.source,
"tool": self.tool,
**super().get_template_context(),
}
def preprocess(self, x: Optional[str]) -> np.array | PIL.Image | str | None:
"""
Parameters:
x (str): base64 url data
Returns:
(Union[numpy.array, PIL.Image, filepath]): image in requested format
"""
if x is None:
return x
im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
if self.shape is not None:
im = processing_utils.resize_and_crop(im, self.shape)
if self.invert_colors:
im = PIL.ImageOps.invert(im)
if self.type == "pil":
return im
elif self.type == "numpy":
return np.array(im)
elif self.type == "file" or self.type == "filepath":
file_obj = tempfile.NamedTemporaryFile(
delete=False,
suffix=("." + fmt.lower() if fmt is not None else ".png"),
def preprocess(self, x: Optional[str]) -> np.array | PIL.Image | str | None:
"""
Parameters:
x (str): base64 url data
Returns:
(Union[numpy.array, PIL.Image, filepath]): image in requested format
"""
if x is None:
return x
im = processing_utils.decode_base64_to_image(x)
fmt = im.format
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
if self.shape is not None:
im = processing_utils.resize_and_crop(im, self.shape)
if self.invert_colors:
im = PIL.ImageOps.invert(im)
if self.type == "pil":
return im
elif self.type == "numpy":
return np.array(im)
elif self.type == "file" or self.type == "filepath":
file_obj = tempfile.NamedTemporaryFile(
delete=False,
suffix=("." + fmt.lower() if fmt is not None else ".png"),
)
im.save(file_obj.name)
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
DeprecationWarning,
)
im.save(file_obj.name)
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.",
DeprecationWarning,
)
return file_obj
else:
return file_obj.name
return file_obj
else:
return file_obj.name
else:
raise ValueError(
"Unknown type: "
+ str(self.type)
+ ". Please choose from: 'numpy', 'pil', 'filepath'."
)
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x)
def serialize(self, x, called_directly=False):
# if called directly, can assume it's a URL or filepath
if self.type == "filepath" or called_directly:
return processing_utils.encode_url_or_file_to_base64(x)
elif self.type == "file":
return processing_utils.encode_url_or_file_to_base64(x.name)
elif self.type in ("numpy", "pil"):
if self.type == "numpy":
x = PIL.Image.fromarray(np.uint8(x)).convert("RGB")
fmt = x.format
file_obj = tempfile.NamedTemporaryFile(
delete=False,
suffix=("." + fmt.lower() if fmt is not None else ".png"),
)
x.save(file_obj.name)
return processing_utils.encode_url_or_file_to_base64(file_obj.name)
else:
raise ValueError(
"Unknown type: "
+ str(self.type)
+ ". Please choose from: 'numpy', 'pil', 'filepath'."
)
def set_interpret_parameters(self, segments=16):
"""
Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
Parameters:
segments (int): Number of interpretation segments to split image into.
"""
self.interpretation_segments = segments
return self
def _segment_by_slic(self, x):
"""
Helper method that segments an image into superpixels using slic.
Parameters:
x: base64 representation of an image
"""
x = processing_utils.decode_base64_to_image(x)
if self.shape is not None:
x = processing_utils.resize_and_crop(x, self.shape)
resized_and_cropped_image = np.array(x)
try:
from skimage.segmentation import slic
except (ImportError, ModuleNotFoundError):
raise ValueError(
"Error: running this interpretation for images requires scikit-image, please install it first."
)
try:
segments_slic = slic(
resized_and_cropped_image,
self.interpretation_segments,
compactness=10,
sigma=1,
start_label=1,
)
except TypeError: # For skimage 0.16 and older
segments_slic = slic(
resized_and_cropped_image,
self.interpretation_segments,
compactness=10,
sigma=1,
)
return segments_slic, resized_and_cropped_image
def tokenize(self, x):
"""
Segments image into tokens, masks, and leave-one-out-tokens
Parameters:
x: base64 representation of an image
Returns:
tokens: list of tokens, used by the get_masked_input() method
leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
masks: list of masks, used by the get_interpretation_neighbors() method
"""
segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
tokens, masks, leave_one_out_tokens = [], [], []
replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
for (i, segment_value) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segment_value
image_screen = np.copy(resized_and_cropped_image)
image_screen[segments_slic == segment_value] = replace_color
leave_one_out_tokens.append(
processing_utils.encode_array_to_base64(image_screen)
)
token = np.copy(resized_and_cropped_image)
token[segments_slic != segment_value] = 0
tokens.append(token)
masks.append(mask)
return tokens, leave_one_out_tokens, masks
def get_masked_inputs(self, tokens, binary_mask_matrix):
masked_inputs = []
for binary_mask_vector in binary_mask_matrix:
masked_input = np.zeros_like(tokens[0], dtype=int)
for token, b in zip(tokens, binary_mask_vector):
masked_input = masked_input + token * int(b)
masked_inputs.append(processing_utils.encode_array_to_base64(masked_input))
return masked_inputs
def get_interpretation_scores(
self, x, neighbors, scores, masks, tokens=None, **kwargs
):
"""
Returns:
(List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
"""
x = processing_utils.decode_base64_to_image(x)
if self.shape is not None:
x = processing_utils.resize_and_crop(x, self.shape)
x = np.array(x)
output_scores = np.zeros((x.shape[0], x.shape[1]))
for score, mask in zip(scores, masks):
output_scores += score * mask
max_val, min_val = np.max(output_scores), np.min(output_scores)
if max_val > 0:
output_scores = (output_scores - min_val) / (max_val - min_val)
return output_scores.tolist()
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
"""
return self.save_flagged_file(dir, label, data, encryption_key)
def generate_sample(self):
return test_data.BASE64_IMAGE
# Output functions
def postprocess(self, y):
"""
Parameters:
y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
Returns:
(str): base64 url data
"""
if self.type == "auto":
if isinstance(y, np.ndarray):
dtype = "numpy"
elif isinstance(y, PIL.Image.Image):
dtype = "pil"
elif isinstance(y, str):
dtype = "file"
elif isinstance(y, ModuleType):
dtype = "plot"
else:
raise ValueError(
"Unknown type: "
+ str(self.type)
+ ". Please choose from: 'numpy', 'pil', 'filepath'."
"Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'."
)
else:
dtype = self.type
if dtype in ["numpy", "pil"]:
if dtype == "pil":
y = np.array(y)
out_y = processing_utils.encode_array_to_base64(y)
elif dtype == "file":
out_y = processing_utils.encode_url_or_file_to_base64(y)
elif dtype == "plot":
out_y = processing_utils.encode_plot_to_base64(y)
else:
raise ValueError(
"Unknown type: "
+ dtype
+ ". Please choose from: 'numpy', 'pil', 'file', 'plot'."
)
return out_y
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x)
def deserialize(self, x):
y = processing_utils.decode_base64_to_file(x).name
return y
def serialize(self, x, called_directly=False):
# if called directly, can assume it's a URL or filepath
if self.type == "filepath" or called_directly:
return processing_utils.encode_url_or_file_to_base64(x)
elif self.type == "file":
return processing_utils.encode_url_or_file_to_base64(x.name)
elif self.type in ("numpy", "pil"):
if self.type == "numpy":
x = PIL.Image.fromarray(np.uint8(x)).convert("RGB")
fmt = x.format
file_obj = tempfile.NamedTemporaryFile(
delete=False,
suffix=("." + fmt.lower() if fmt is not None else ".png"),
)
x.save(file_obj.name)
return processing_utils.encode_url_or_file_to_base64(file_obj.name)
else:
raise ValueError(
"Unknown type: "
+ str(self.type)
+ ". Please choose from: 'numpy', 'pil', 'filepath'."
)
def set_interpret_parameters(self, segments=16):
"""
Calculates interpretation score of image subsections by splitting the image into subsections, then using a "leave one out" method to calculate the score of each subsection by whiting out the subsection and measuring the delta of the output value.
Parameters:
segments (int): Number of interpretation segments to split image into.
"""
self.interpretation_segments = segments
return self
def _segment_by_slic(self, x):
"""
Helper method that segments an image into superpixels using slic.
Parameters:
x: base64 representation of an image
"""
x = processing_utils.decode_base64_to_image(x)
if self.shape is not None:
x = processing_utils.resize_and_crop(x, self.shape)
resized_and_cropped_image = np.array(x)
try:
from skimage.segmentation import slic
except (ImportError, ModuleNotFoundError):
raise ValueError(
"Error: running this interpretation for images requires scikit-image, please install it first."
)
try:
segments_slic = slic(
resized_and_cropped_image,
self.interpretation_segments,
compactness=10,
sigma=1,
start_label=1,
)
except TypeError: # For skimage 0.16 and older
segments_slic = slic(
resized_and_cropped_image,
self.interpretation_segments,
compactness=10,
sigma=1,
)
return segments_slic, resized_and_cropped_image
def tokenize(self, x):
"""
Segments image into tokens, masks, and leave-one-out-tokens
Parameters:
x: base64 representation of an image
Returns:
tokens: list of tokens, used by the get_masked_input() method
leave_one_out_tokens: list of left-out tokens, used by the get_interpretation_neighbors() method
masks: list of masks, used by the get_interpretation_neighbors() method
"""
segments_slic, resized_and_cropped_image = self._segment_by_slic(x)
tokens, masks, leave_one_out_tokens = [], [], []
replace_color = np.mean(resized_and_cropped_image, axis=(0, 1))
for (i, segment_value) in enumerate(np.unique(segments_slic)):
mask = segments_slic == segment_value
image_screen = np.copy(resized_and_cropped_image)
image_screen[segments_slic == segment_value] = replace_color
leave_one_out_tokens.append(
processing_utils.encode_array_to_base64(image_screen)
)
token = np.copy(resized_and_cropped_image)
token[segments_slic != segment_value] = 0
tokens.append(token)
masks.append(mask)
return tokens, leave_one_out_tokens, masks
def get_masked_inputs(self, tokens, binary_mask_matrix):
masked_inputs = []
for binary_mask_vector in binary_mask_matrix:
masked_input = np.zeros_like(tokens[0], dtype=int)
for token, b in zip(tokens, binary_mask_vector):
masked_input = masked_input + token * int(b)
masked_inputs.append(
processing_utils.encode_array_to_base64(masked_input)
)
return masked_inputs
def get_interpretation_scores(self, x, neighbors, scores, masks, tokens=None):
"""
Returns:
(List[List[float]]): A 2D array representing the interpretation score of each pixel of the image.
"""
x = processing_utils.decode_base64_to_image(x)
if self.shape is not None:
x = processing_utils.resize_and_crop(x, self.shape)
x = np.array(x)
output_scores = np.zeros((x.shape[0], x.shape[1]))
for score, mask in zip(scores, masks):
output_scores += score * mask
max_val, min_val = np.max(output_scores), np.min(output_scores)
if max_val > 0:
output_scores = (output_scores - min_val) / (max_val - min_val)
return output_scores.tolist()
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
"""
return self.save_flagged_file(dir, label, data, encryption_key)
def generate_sample(self):
return test_data.BASE64_IMAGE
# Output functions
def postprocess(self, y):
"""
Parameters:
y (Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]): image in specified format
Returns:
(str): base64 url data
"""
if self.type == "auto":
if isinstance(y, np.ndarray):
dtype = "numpy"
elif isinstance(y, PIL.Image.Image):
dtype = "pil"
elif isinstance(y, str):
dtype = "file"
elif isinstance(y, ModuleType):
dtype = "plot"
else:
raise ValueError(
"Unknown type. Please choose from: 'numpy', 'pil', 'file', 'plot'."
)
else:
dtype = self.type
if dtype in ["numpy", "pil"]:
if dtype == "pil":
y = np.array(y)
out_y = processing_utils.encode_array_to_base64(y)
elif dtype == "file":
out_y = processing_utils.encode_url_or_file_to_base64(y)
elif dtype == "plot":
out_y = processing_utils.encode_plot_to_base64(y)
else:
raise ValueError(
"Unknown type: "
+ dtype
+ ". Please choose from: 'numpy', 'pil', 'file', 'plot'."
)
return out_y
def deserialize(self, x):
y = processing_utils.decode_base64_to_file(x).name
return y
def restore_flagged(self, dir, data, encryption_key):
return self.restore_flagged_file(dir, data, encryption_key)["data"]
def restore_flagged(self, dir, data, encryption_key):
return self.restore_flagged_file(dir, data, encryption_key)["data"]
class Video(Component):
@ -1230,7 +1229,7 @@ class Video(Component):
type: Optional[str] = None,
source: str = "upload",
label: Optional[str] = None,
**kwargs
**kwargs,
):
"""
Parameters:
@ -1254,7 +1253,6 @@ class Video(Component):
def get_template_context(self):
return {
"source": self.source,
"optional": self.optional,
**super().get_template_context(),
}
@ -1284,7 +1282,7 @@ class Video(Component):
file_name = file.name
uploaded_format = file_name.split(".")[-1].lower()
if self.type is not None and uploaded_format != self.type:
output_file_name = file_name[0: file_name.rindex(".") + 1] + self.type
output_file_name = file_name[0 : file_name.rindex(".") + 1] + self.type
ff = FFmpeg(inputs={file_name: None}, outputs={output_file_name: None})
ff.run()
return output_file_name
@ -1314,7 +1312,7 @@ class Video(Component):
"""
returned_format = y.split(".")[-1].lower()
if self.type is not None and returned_format != self.type:
output_file_name = y[0: y.rindex(".") + 1] + self.type
output_file_name = y[0 : y.rindex(".") + 1] + self.type
ff = FFmpeg(inputs={y: None}, outputs={output_file_name: None})
ff.run()
y = output_file_name
@ -1347,7 +1345,7 @@ class Audio(Component):
source: str = "upload",
type: str = "numpy",
label: str = None,
**kwargs
**kwargs,
):
"""
Parameters:
@ -1360,7 +1358,9 @@ class Audio(Component):
self.type = type
self.test_input = test_data.BASE64_AUDIO
self.interpret_by_tokens = True
super().__init__(label=label, requires_permissions=requires_permissions, **kwargs)
super().__init__(
label=label, requires_permissions=requires_permissions, **kwargs
)
def get_template_context(self):
return {

View File

@ -19,6 +19,7 @@ from ffmpy import FFmpeg
from gradio import processing_utils, test_data
from gradio.components import (
Audio,
Checkbox,
CheckboxGroup,
Component,
@ -28,7 +29,6 @@ from gradio.components import (
Radio,
Slider,
Textbox,
Audio
)
if TYPE_CHECKING: # Only import for type checking (is False at runtime).

View File

@ -21,7 +21,7 @@ import PIL
from ffmpy import FFmpeg
from gradio import processing_utils
from gradio.components import Component, Image, Textbox, Video, Audio
from gradio.components import Audio, Component, Image, Textbox, Video
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio import Interface
@ -83,6 +83,8 @@ class Video(Video):
DeprecationWarning,
)
super().__init__(label=label, type=type)
class Audio(Audio):
"""
Creates an audio player that plays the output audio.
@ -103,7 +105,6 @@ class Audio(Audio):
super().__init__(type=type, label=label)
class OutputComponent(Component):
"""
Output Component. All output components subclass this.
@ -328,8 +329,6 @@ class HighlightedText(OutputComponent):
return json.loads(data)
class JSON(OutputComponent):
"""
Used for JSON output. Expects a JSON string or a Python object that is JSON serializable.