added detailed explanations of what each output expects along with relevant tests

This commit is contained in:
Abubakar Abid 2020-07-07 00:51:07 -05:00
parent 89fcb2fa93
commit 2d3681cdad
4 changed files with 71 additions and 83 deletions

View File

@ -86,6 +86,7 @@ class Interface:
self.thumbnail = thumbnail
self.examples = examples
self.server_port = None
self.simple_server = None
Interface.instances.add(self)
def get_config_file(self):
@ -210,7 +211,7 @@ class Interface:
raise RuntimeError("Validation did not pass")
def close(self):
if self.server_port:
if self.simple_server and not(self.simple_server.fileno() == -1): # checks to see if server is running
print("Closing Gradio server on port {}...".format(self.server_port))
networking.close_server(self.simple_server)

View File

@ -44,6 +44,10 @@ class AbstractOutput(ABC):
class Label(AbstractOutput):
LABEL_KEY = "label"
CONFIDENCE_KEY = "confidence"
CONFIDENCES_KEY = "confidences"
def __init__(self, num_top_classes=None, label=None):
self.num_top_classes = num_top_classes
super().__init__(label)
@ -60,16 +64,19 @@ class Label(AbstractOutput):
if self.num_top_classes is not None:
sorted_pred = sorted_pred[:self.num_top_classes]
return {
"label": sorted_pred[0][0],
"confidences": [
self.LABEL_KEY: sorted_pred[0][0],
self.CONFIDENCES_KEY: [
{
"label": pred[0],
"confidence" : pred[1]
self.LABEL_KEY: pred[0],
self.CONFIDENCE_KEY: pred[1]
} for pred in sorted_pred
]
}
elif isinstance(prediction, int) or isinstance(prediction, float):
return {self.LABEL_KEY: str(prediction)}
else:
raise ValueError("Function output should be string or dict")
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
"float label, or a dictionary whose keys are labels and values are confidences.")
@classmethod
def get_shortcut_implementations(cls):
@ -82,6 +89,13 @@ class KeyValues(AbstractOutput):
def __init__(self, label=None):
super().__init__(label)
def postprocess(self, prediction):
if isinstance(prediction, dict):
return prediction
else:
raise ValueError("The `KeyValues` output interface expects an output that is a dictionary whose keys are "
"labels and values are corresponding values.")
@classmethod
def get_shortcut_implementations(cls):
return {
@ -111,9 +125,11 @@ class Textbox(AbstractOutput):
}
def postprocess(self, prediction):
"""
"""
return prediction
if isinstance(prediction, str) or isinstance(prediction, int) or isinstance(prediction, float):
return str(prediction)
else:
raise ValueError("The `Textbox` output interface expects an output that is one of: a string, or"
"an int/float that can be converted to a string.")
class Image(AbstractOutput):
@ -132,9 +148,16 @@ class Image(AbstractOutput):
"""
"""
if self.plot:
return preprocessing_utils.encode_plot_to_base64(prediction)
try:
return preprocessing_utils.encode_plot_to_base64(prediction)
except:
raise ValueError("The `Image` output interface expects a `matplotlib.pyplot` object"
"if plt=True.")
else:
return preprocessing_utils.encode_array_to_base64(prediction)
try:
return preprocessing_utils.encode_array_to_base64(prediction)
except:
raise ValueError("The `Image` output interface (with plt=False) expects a numpy array.")
def rebuild_flagged(self, dir, msg):
"""

View File

@ -7,8 +7,7 @@ import gradio.outputs
class TestInterface(unittest.TestCase):
def test_input_output_mapping(self):
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=lambda
x: x)
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=lambda x: x)
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad)
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
@ -18,52 +17,15 @@ class TestInterface(unittest.TestCase):
self.assertEqual(io.input_interfaces[0], inp)
def test_output_interface_is_instance(self):
# out = gradio.outputs.Label(show_confidences=False)
out = gradio.outputs.Label()
io = gr.Interface(inputs='SketCHPad', outputs=out, fn=lambda x: x)
self.assertEqual(io.output_interfaces[0], out)
def test_keras_model(self):
try:
import tensorflow as tf
except:
raise unittest.SkipTest("Need tensorflow installed to do keras-based tests")
inputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
# pred = io.predict(np.ones(shape=(1, 3), ))
# self.assertEqual(pred.shape, (1, 5))
def test_func_model(self):
def test_prediction(self):
def model(x):
return 2*x
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
# pred = io.predict(np.ones(shape=(1, 3)))
# self.assertEqual(pred.shape, (1, 3))
def test_pytorch_model(self):
try:
import torch
except:
raise unittest.SkipTest("Need torch installed to do pytorch-based tests")
class TwoLayerNet(torch.nn.Module):
def __init__(self):
super(TwoLayerNet, self).__init__()
self.linear1 = torch.nn.Linear(3, 4)
self.linear2 = torch.nn.Linear(4, 5)
def forward(self, x):
h_relu = torch.nn.functional.relu(self.linear1(x))
y_pred = self.linear2(h_relu)
return y_pred
model = TwoLayerNet()
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
# pred = io.predict(np.ones(shape=(1, 3), dtype=np.float32))
# self.assertEqual(pred.shape, (1, 5))
io = gr.Interface(inputs='textbox', outputs='textBOX', fn=model)
self.assertEqual(io.predict[0](11), 22)
if __name__ == '__main__':

View File

@ -14,37 +14,39 @@ class TestLabel(unittest.TestCase):
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
# def test_postprocessing_string(self):
# string = 'happy'
# out = outputs.Label()
# label = json.loads(out.postprocess(string))
# self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
#
# def test_postprocessing_1D_array(self):
# array = np.array([0.1, 0.2, 0, 0.7, 0])
# true_label = {outputs.Label.LABEL_KEY: 3,
# outputs.Label.CONFIDENCES_KEY: [
# {outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
# {outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
# {outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
# ]}
# out = outputs.Label()
# label = json.loads(out.postprocess(array))
# self.assertDictEqual(label, true_label)
def test_postprocessing_string(self):
string = 'happy'
out = outputs.Label()
label = out.postprocess(string)
self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
# def test_postprocessing_1D_array_no_confidences(self):
# array = np.array([0.1, 0.2, 0, 0.7, 0])
# true_label = {outputs.Label.LABEL_KEY: 3}
# out = outputs.Label(show_confidences=False)
# label = json.loads(out.postprocess(array))
# self.assertDictEqual(label, true_label)
#
# def test_postprocessing_int(self):
# true_label_array = np.array([[[3]]])
# true_label = {outputs.Label.LABEL_KEY: 3}
# out = outputs.Label()
# label = json.loads(out.postprocess(true_label_array))
# self.assertDictEqual(label, true_label)
def test_postprocessing_dict(self):
orig_label = {
3: 0.7,
1: 0.2,
0: 0.1
}
true_label = {outputs.Label.LABEL_KEY: 3,
outputs.Label.CONFIDENCES_KEY: [
{outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
{outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
{outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
]}
out = outputs.Label()
label = out.postprocess(orig_label)
self.assertDictEqual(label, true_label)
def test_postprocessing_array(self):
array = np.array([0.1, 0.2, 0, 0.7, 0])
out = outputs.Label()
self.assertRaises(ValueError, out.postprocess, array)
def test_postprocessing_int(self):
label = 3
true_label = {outputs.Label.LABEL_KEY: '3'}
out = outputs.Label()
label = out.postprocess(label)
self.assertDictEqual(label, true_label)
class TestTextbox(unittest.TestCase):