mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-06 10:25:17 +08:00
added detailed explanations of what each output expects along with relevant tests
This commit is contained in:
parent
89fcb2fa93
commit
2d3681cdad
@ -86,6 +86,7 @@ class Interface:
|
||||
self.thumbnail = thumbnail
|
||||
self.examples = examples
|
||||
self.server_port = None
|
||||
self.simple_server = None
|
||||
Interface.instances.add(self)
|
||||
|
||||
def get_config_file(self):
|
||||
@ -210,7 +211,7 @@ class Interface:
|
||||
raise RuntimeError("Validation did not pass")
|
||||
|
||||
def close(self):
|
||||
if self.server_port:
|
||||
if self.simple_server and not(self.simple_server.fileno() == -1): # checks to see if server is running
|
||||
print("Closing Gradio server on port {}...".format(self.server_port))
|
||||
networking.close_server(self.simple_server)
|
||||
|
||||
|
@ -44,6 +44,10 @@ class AbstractOutput(ABC):
|
||||
|
||||
|
||||
class Label(AbstractOutput):
|
||||
LABEL_KEY = "label"
|
||||
CONFIDENCE_KEY = "confidence"
|
||||
CONFIDENCES_KEY = "confidences"
|
||||
|
||||
def __init__(self, num_top_classes=None, label=None):
|
||||
self.num_top_classes = num_top_classes
|
||||
super().__init__(label)
|
||||
@ -60,16 +64,19 @@ class Label(AbstractOutput):
|
||||
if self.num_top_classes is not None:
|
||||
sorted_pred = sorted_pred[:self.num_top_classes]
|
||||
return {
|
||||
"label": sorted_pred[0][0],
|
||||
"confidences": [
|
||||
self.LABEL_KEY: sorted_pred[0][0],
|
||||
self.CONFIDENCES_KEY: [
|
||||
{
|
||||
"label": pred[0],
|
||||
"confidence" : pred[1]
|
||||
self.LABEL_KEY: pred[0],
|
||||
self.CONFIDENCE_KEY: pred[1]
|
||||
} for pred in sorted_pred
|
||||
]
|
||||
}
|
||||
elif isinstance(prediction, int) or isinstance(prediction, float):
|
||||
return {self.LABEL_KEY: str(prediction)}
|
||||
else:
|
||||
raise ValueError("Function output should be string or dict")
|
||||
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences.")
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
@ -82,6 +89,13 @@ class KeyValues(AbstractOutput):
|
||||
def __init__(self, label=None):
|
||||
super().__init__(label)
|
||||
|
||||
def postprocess(self, prediction):
|
||||
if isinstance(prediction, dict):
|
||||
return prediction
|
||||
else:
|
||||
raise ValueError("The `KeyValues` output interface expects an output that is a dictionary whose keys are "
|
||||
"labels and values are corresponding values.")
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
@ -111,9 +125,11 @@ class Textbox(AbstractOutput):
|
||||
}
|
||||
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
return prediction
|
||||
if isinstance(prediction, str) or isinstance(prediction, int) or isinstance(prediction, float):
|
||||
return str(prediction)
|
||||
else:
|
||||
raise ValueError("The `Textbox` output interface expects an output that is one of: a string, or"
|
||||
"an int/float that can be converted to a string.")
|
||||
|
||||
|
||||
class Image(AbstractOutput):
|
||||
@ -132,9 +148,16 @@ class Image(AbstractOutput):
|
||||
"""
|
||||
"""
|
||||
if self.plot:
|
||||
try:
|
||||
return preprocessing_utils.encode_plot_to_base64(prediction)
|
||||
except:
|
||||
raise ValueError("The `Image` output interface expects a `matplotlib.pyplot` object"
|
||||
"if plt=True.")
|
||||
else:
|
||||
try:
|
||||
return preprocessing_utils.encode_array_to_base64(prediction)
|
||||
except:
|
||||
raise ValueError("The `Image` output interface (with plt=False) expects a numpy array.")
|
||||
|
||||
def rebuild_flagged(self, dir, msg):
|
||||
"""
|
||||
|
@ -7,8 +7,7 @@ import gradio.outputs
|
||||
|
||||
class TestInterface(unittest.TestCase):
|
||||
def test_input_output_mapping(self):
|
||||
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=lambda
|
||||
x: x)
|
||||
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=lambda x: x)
|
||||
self.assertIsInstance(io.input_interfaces[0], gradio.inputs.Sketchpad)
|
||||
self.assertIsInstance(io.output_interfaces[0], gradio.outputs.Textbox)
|
||||
|
||||
@ -18,52 +17,15 @@ class TestInterface(unittest.TestCase):
|
||||
self.assertEqual(io.input_interfaces[0], inp)
|
||||
|
||||
def test_output_interface_is_instance(self):
|
||||
# out = gradio.outputs.Label(show_confidences=False)
|
||||
out = gradio.outputs.Label()
|
||||
io = gr.Interface(inputs='SketCHPad', outputs=out, fn=lambda x: x)
|
||||
self.assertEqual(io.output_interfaces[0], out)
|
||||
|
||||
def test_keras_model(self):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except:
|
||||
raise unittest.SkipTest("Need tensorflow installed to do keras-based tests")
|
||||
inputs = tf.keras.Input(shape=(3,))
|
||||
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
|
||||
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
|
||||
model = tf.keras.Model(inputs=inputs, outputs=outputs)
|
||||
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
|
||||
# pred = io.predict(np.ones(shape=(1, 3), ))
|
||||
# self.assertEqual(pred.shape, (1, 5))
|
||||
|
||||
def test_func_model(self):
|
||||
def test_prediction(self):
|
||||
def model(x):
|
||||
return 2*x
|
||||
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
|
||||
# pred = io.predict(np.ones(shape=(1, 3)))
|
||||
# self.assertEqual(pred.shape, (1, 3))
|
||||
|
||||
def test_pytorch_model(self):
|
||||
try:
|
||||
import torch
|
||||
except:
|
||||
raise unittest.SkipTest("Need torch installed to do pytorch-based tests")
|
||||
|
||||
class TwoLayerNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(TwoLayerNet, self).__init__()
|
||||
self.linear1 = torch.nn.Linear(3, 4)
|
||||
self.linear2 = torch.nn.Linear(4, 5)
|
||||
|
||||
def forward(self, x):
|
||||
h_relu = torch.nn.functional.relu(self.linear1(x))
|
||||
y_pred = self.linear2(h_relu)
|
||||
return y_pred
|
||||
|
||||
model = TwoLayerNet()
|
||||
io = gr.Interface(inputs='SketCHPad', outputs='textBOX', fn=model)
|
||||
# pred = io.predict(np.ones(shape=(1, 3), dtype=np.float32))
|
||||
# self.assertEqual(pred.shape, (1, 5))
|
||||
io = gr.Interface(inputs='textbox', outputs='textBOX', fn=model)
|
||||
self.assertEqual(io.predict[0](11), 22)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -14,37 +14,39 @@ class TestLabel(unittest.TestCase):
|
||||
path = outputs.BASE_OUTPUT_INTERFACE_JS_PATH.format(out.__class__.__name__.lower())
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
# def test_postprocessing_string(self):
|
||||
# string = 'happy'
|
||||
# out = outputs.Label()
|
||||
# label = json.loads(out.postprocess(string))
|
||||
# self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
|
||||
#
|
||||
# def test_postprocessing_1D_array(self):
|
||||
# array = np.array([0.1, 0.2, 0, 0.7, 0])
|
||||
# true_label = {outputs.Label.LABEL_KEY: 3,
|
||||
# outputs.Label.CONFIDENCES_KEY: [
|
||||
# {outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
|
||||
# {outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
|
||||
# {outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
|
||||
# ]}
|
||||
# out = outputs.Label()
|
||||
# label = json.loads(out.postprocess(array))
|
||||
# self.assertDictEqual(label, true_label)
|
||||
def test_postprocessing_string(self):
|
||||
string = 'happy'
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(string)
|
||||
self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
|
||||
|
||||
# def test_postprocessing_1D_array_no_confidences(self):
|
||||
# array = np.array([0.1, 0.2, 0, 0.7, 0])
|
||||
# true_label = {outputs.Label.LABEL_KEY: 3}
|
||||
# out = outputs.Label(show_confidences=False)
|
||||
# label = json.loads(out.postprocess(array))
|
||||
# self.assertDictEqual(label, true_label)
|
||||
#
|
||||
# def test_postprocessing_int(self):
|
||||
# true_label_array = np.array([[[3]]])
|
||||
# true_label = {outputs.Label.LABEL_KEY: 3}
|
||||
# out = outputs.Label()
|
||||
# label = json.loads(out.postprocess(true_label_array))
|
||||
# self.assertDictEqual(label, true_label)
|
||||
def test_postprocessing_dict(self):
|
||||
orig_label = {
|
||||
3: 0.7,
|
||||
1: 0.2,
|
||||
0: 0.1
|
||||
}
|
||||
true_label = {outputs.Label.LABEL_KEY: 3,
|
||||
outputs.Label.CONFIDENCES_KEY: [
|
||||
{outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
|
||||
{outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
|
||||
{outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
|
||||
]}
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(orig_label)
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
def test_postprocessing_array(self):
|
||||
array = np.array([0.1, 0.2, 0, 0.7, 0])
|
||||
out = outputs.Label()
|
||||
self.assertRaises(ValueError, out.postprocess, array)
|
||||
|
||||
def test_postprocessing_int(self):
|
||||
label = 3
|
||||
true_label = {outputs.Label.LABEL_KEY: '3'}
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(label)
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user