labeled image segments

This commit is contained in:
Ali Abid 2020-11-23 15:11:37 -08:00
parent a19842884c
commit 37028bcf25
15 changed files with 138 additions and 26 deletions

View File

@ -127,6 +127,7 @@ You can provide example data that a user can easily load into the model. This ca
```python
import gradio as gr
import random
def calculator(num1, operation, num2):
if operation == "add":
@ -138,19 +139,28 @@ def calculator(num1, operation, num2):
elif operation == "divide":
return num1 / num2
iface = gr.Interface(calculator,
["number", gr.inputs.Radio(["add", "subtract", "multiply", "divide"]), "number"],
"number",
examples=[
[5, "add", 3],
[12, "divide", -2]
[4, "divide", 2],
[-4, "multiply", 2.5],
[0, "subtract", 1.2],
]
)
iface.launch()
```
![calculator interface](demo/screenshots/calculator/1.gif)
### Exploring Similar Examples with Embeddings
When you provide input to the function, you may wish to see if there are similar samples in the example dataset that could explain the behaviour of the function. For example, if an image model returns a peculiar output for a given input, you may load the training data into the examples dataset and see what training data samples are similar to the input you provided. If you enable this feature, you can click the *Order by Similarity* button to show the most similar samples from the example dataset.
Gradio supports exploring similar data samples through embeddings. An embedding is a list of floats that numerically represents any input. You must pass to the `embedding` keyword argument of Interface a function that takes the same inputs as the main `fn` argument, but instead returns an embedding that represents all the input values as a single list of floats. You can also pass the "default" string to `embedding` and Gradio will automatically generate embeddings for each sample in the examples dataset.
### Flagging
Underneath the output interfaces, there is a button marked "Flag". When a user testing your model sees input with interesting output, such as erroneous or unexpected model behaviour, they can flag the input for review. Within the directory provided by the `flagging_dir=` argument to the Interface constructor, a CSV file will log the flagged inputs. If the interface involves file inputs, such as for Image and Audio interfaces, folders will be created to store those flagged inputs as well.

View File

@ -528,7 +528,7 @@ class Image(InputComponent):
Input type: Union[numpy.array, PIL.Image, str]
"""
def __init__(self, shape=None, image_mode='RGB', invert_colors=False, source="upload", tool="editor", type="numpy", label=None):
def __init__(self, shape=None, image_mode='RGB', invert_colors=False, source="upload", tool="editor", labeled_segments=False, type="numpy", label=None):
'''
Parameters:
shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.

View File

@ -128,16 +128,18 @@ class Label(OutputComponent):
class Image(OutputComponent):
'''
Component displays an output image.
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot]
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
'''
def __init__(self, type="auto", plot=False, label=None):
def __init__(self, type="auto", labeled_segments=False, plot=False, label=None):
'''
Parameters:
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
labeled_segments (bool): If True, expects a two-element tuple to be returned. The first element of the tuple is the image of format specified by type. The second element is a list of tuples, where each tuple represents a labeled segment within the image. The first element of the tuple is the string label of the segment, followed by 4 floats that represent the left-x, top-y, right-x, and bottom-y coordinates of the bounding box.
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
label (str): component name in interface.
'''
self.labeled_segments = labeled_segments
if plot:
warnings.warn("The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.", DeprecationWarning)
self.type = "plot"
@ -149,11 +151,16 @@ class Image(OutputComponent):
def get_shortcut_implementations(cls):
return {
"image": {},
"segmented_image": {"labeled_segments": True},
"plot": {"type": "plot"},
"pil": {"type": "pil"}
}
def postprocess(self, y):
if self.labeled_segments:
y, coordinates = y
else:
coordinates = []
if self.type == "auto":
if isinstance(y, np.ndarray):
dtype = "numpy"
@ -168,13 +175,14 @@ class Image(OutputComponent):
if dtype in ["numpy", "pil"]:
if dtype == "pil":
y = np.array(y)
return processing_utils.encode_array_to_base64(y)
out_y = processing_utils.encode_array_to_base64(y)
elif dtype == "file":
return processing_utils.encode_file_to_base64(y)
out_y = processing_utils.encode_file_to_base64(y)
elif dtype == "plot":
return processing_utils.encode_plot_to_base64(y)
out_y = processing_utils.encode_plot_to_base64(y)
else:
raise ValueError("Unknown type: " + dtype + ". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
return out_y, coordinates
def rebuild(self, dir, data):
"""

View File

@ -54,5 +54,5 @@
}
.view_holders canvas {
background-color: white;
border: solid 1px black;
border: solid 1px white;
}

View File

@ -9,5 +9,4 @@
width: 100%;
height: 100%;
object-fit: contain;
display: none;
}

View File

@ -1,17 +1,47 @@
const image_output = {
html: `
<div class="interface_box">
<div class="output_image_holder">
<img class="output_image" />
<div class="view_holders">
<div class="saliency_holder hide">
<canvas class="saliency"></canvas>
</div>
<div class="output_image_holder hide">
<img class="output_image">
</div>
</div>
</div>
`,
init: function(opts) {},
output: function(data) {
this.target.find(".output_image").attr('src', data).show();
let io = this;
let [img_data, coord] = data;
this.target.find(".output_image_holder").removeClass("hide");
img = this.target.find(".output_image").attr('src', img_data);
if (coord.length) {
img = img[0];
img.onload = function() {
var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight);
var width = size.width;
var height = size.height;
io.target.find(".saliency_holder").removeClass("hide").html(`
<canvas class="saliency" width=${width} height=${height}></canvas>`);
var ctx = io.target.find(".saliency")[0].getContext('2d');
ctx.lineWidth = 2;
ctx.strokeStyle = 'red';
ctx.font = '16px monospace';
ctx.textBaseline = 'top';
for (let [label, left_x, top_y, right_x, bottom_y] of coord) {
ctx.rect(left_x, top_y, right_x - left_x, bottom_y - top_y);
ctx.fillText(label, left_x + 2, top_y + 2)
}
ctx.stroke();
}
}
},
clear: function() {
this.target.find(".output_image").attr('src', "").hide();
this.target.find(".output_image_holder").addClass("hide");
this.target.find(".saliency_holder").addClass("hide");
this.target.find(".output_image").attr('src', "")
},
load_example_preview: function(data) {
return "<img src='"+data+"' height=100>"

BIN
demo/images/stop_1.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 324 KiB

BIN
demo/images/stop_2.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 192 KiB

View File

@ -0,0 +1,22 @@
# Demo: (Image) -> (Image)
import gradio as gr
def detect(image):
# return image
return (image, [("sign", 210, 80, 280, 150), ("train", 100, 100, 180, 150)])
iface = gr.Interface(detect,
gr.inputs.Image(type="pil"),
# "image",
"segmented_image",
examples=[
["images/stop_1.jpg"],
["images/stop_2.jpg"],
])
iface.test_launch()
if __name__ == "__main__":
iface.launch()

View File

@ -528,7 +528,7 @@ class Image(InputComponent):
Input type: Union[numpy.array, PIL.Image, str]
"""
def __init__(self, shape=None, image_mode='RGB', invert_colors=False, source="upload", tool="editor", type="numpy", label=None):
def __init__(self, shape=None, image_mode='RGB', invert_colors=False, source="upload", tool="editor", labeled_segments=False, type="numpy", label=None):
'''
Parameters:
shape (Tuple[int, int]): (width, height) shape to crop and resize image to; if None, matches input image size.

View File

@ -128,16 +128,18 @@ class Label(OutputComponent):
class Image(OutputComponent):
'''
Component displays an output image.
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot]
Output type: Union[numpy.array, PIL.Image, str, matplotlib.pyplot, Tuple[Union[numpy.array, PIL.Image, str], List[Tuple[str, float, float, float, float]]]]
'''
def __init__(self, type="auto", plot=False, label=None):
def __init__(self, type="auto", labeled_segments=False, plot=False, label=None):
'''
Parameters:
type (str): Type of value to be passed to component. "numpy" expects a numpy array with shape (width, height, 3), "pil" expects a PIL image object, "file" expects a file path to the saved image, "plot" expects a matplotlib.pyplot object, "auto" detects return type.
labeled_segments (bool): If True, expects a two-element tuple to be returned. The first element of the tuple is the image of format specified by type. The second element is a list of tuples, where each tuple represents a labeled segment within the image. The first element of the tuple is the string label of the segment, followed by 4 floats that represent the left-x, top-y, right-x, and bottom-y coordinates of the bounding box.
plot (bool): DEPRECATED. Whether to expect a plot to be returned by the function.
label (str): component name in interface.
'''
self.labeled_segments = labeled_segments
if plot:
warnings.warn("The 'plot' parameter has been deprecated. Set parameter 'type' to 'plot' instead.", DeprecationWarning)
self.type = "plot"
@ -149,11 +151,16 @@ class Image(OutputComponent):
def get_shortcut_implementations(cls):
return {
"image": {},
"segmented_image": {"labeled_segments": True},
"plot": {"type": "plot"},
"pil": {"type": "pil"}
}
def postprocess(self, y):
if self.labeled_segments:
y, coordinates = y
else:
coordinates = []
if self.type == "auto":
if isinstance(y, np.ndarray):
dtype = "numpy"
@ -168,13 +175,14 @@ class Image(OutputComponent):
if dtype in ["numpy", "pil"]:
if dtype == "pil":
y = np.array(y)
return processing_utils.encode_array_to_base64(y)
out_y = processing_utils.encode_array_to_base64(y)
elif dtype == "file":
return processing_utils.encode_file_to_base64(y)
out_y = processing_utils.encode_file_to_base64(y)
elif dtype == "plot":
return processing_utils.encode_plot_to_base64(y)
out_y = processing_utils.encode_plot_to_base64(y)
else:
raise ValueError("Unknown type: " + dtype + ". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
return out_y, coordinates
def rebuild(self, dir, data):
"""

View File

@ -54,5 +54,5 @@
}
.view_holders canvas {
background-color: white;
border: solid 1px black;
border: solid 1px white;
}

View File

@ -9,5 +9,4 @@
width: 100%;
height: 100%;
object-fit: contain;
display: none;
}

View File

@ -1,17 +1,47 @@
const image_output = {
html: `
<div class="interface_box">
<div class="output_image_holder">
<img class="output_image" />
<div class="view_holders">
<div class="saliency_holder hide">
<canvas class="saliency"></canvas>
</div>
<div class="output_image_holder hide">
<img class="output_image">
</div>
</div>
</div>
`,
init: function(opts) {},
output: function(data) {
this.target.find(".output_image").attr('src', data).show();
let io = this;
let [img_data, coord] = data;
this.target.find(".output_image_holder").removeClass("hide");
img = this.target.find(".output_image").attr('src', img_data);
if (coord.length) {
img = img[0];
img.onload = function() {
var size = getObjectFitSize(true, img.width, img.height, img.naturalWidth, img.naturalHeight);
var width = size.width;
var height = size.height;
io.target.find(".saliency_holder").removeClass("hide").html(`
<canvas class="saliency" width=${width} height=${height}></canvas>`);
var ctx = io.target.find(".saliency")[0].getContext('2d');
ctx.lineWidth = 2;
ctx.strokeStyle = 'red';
ctx.font = '16px monospace';
ctx.textBaseline = 'top';
for (let [label, left_x, top_y, right_x, bottom_y] of coord) {
ctx.rect(left_x, top_y, right_x - left_x, bottom_y - top_y);
ctx.fillText(label, left_x + 2, top_y + 2)
}
ctx.stroke();
}
}
},
clear: function() {
this.target.find(".output_image").attr('src', "").hide();
this.target.find(".output_image_holder").addClass("hide");
this.target.find(".saliency_holder").addClass("hide");
this.target.find(".output_image").attr('src', "")
},
load_example_preview: function(data) {
return "<img src='"+data+"' height=100>"

View File

@ -80,6 +80,12 @@ You can provide example data that a user can easily load into the model. This ca
$code_calculator
$demo_calculator
### Exploring Similar Examples with Embeddings
When you provide input to the function, you may wish to see if there are similar samples in the example dataset that could explain the behaviour of the function. For example, if an image model returns a peculiar output for a given input, you may load the training data into the examples dataset and see what training data samples are similar to the input you provided. If you enable this feature, you can click the *Order by Similarity* button to show the most similar samples from the example dataset.
Gradio supports exploring similar data samples through embeddings. An embedding is a list of floats that numerically represents any input. You must pass to the `embedding` keyword argument of Interface a function that takes the same inputs as the main `fn` argument, but instead returns an embedding that represents all the input values as a single list of floats. You can also pass the "default" string to `embedding` and Gradio will automatically generate embeddings for each sample in the examples dataset.
### Flagging
Underneath the output interfaces, there is a button marked "Flag". When a user testing your model sees input with interesting output, such as erroneous or unexpected model behaviour, they can flag the input for review. Within the directory provided by the `flagging_dir=` argument to the Interface constructor, a CSV file will log the flagged inputs. If the interface involves file inputs, such as for Image and Audio interfaces, folders will be created to store those flagged inputs as well.