add support for passing in interface instances

This commit is contained in:
Abubakar Abid 2019-03-05 22:45:08 -08:00
parent a3516c331b
commit 9e64d86039
3 changed files with 24 additions and 5 deletions

View File

@ -69,7 +69,6 @@ class Webcam(AbstractInput):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
print('>>>>>>>>>in preprocess')
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')

View File

@ -31,16 +31,26 @@ class Interface:
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
"""
:param inputs: a string representing the input interface.
:param outputs: a string representing the output interface.
:param inputs: a string or `AbstractInput` representing the input interface.
:param outputs: a string or `AbstractOutput` representing the output interface.
:param model_obj: the model object, such as a sklearn classifier or keras model.
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
provided.
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
"""
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
if isinstance(inputs, str):
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
elif isinstance(inputs, gradio.inputs.AbstractInput):
self.input_interface = inputs
else:
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
if isinstance(outputs, str):
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
elif isinstance(outputs, gradio.outputs.AbstractOutput):
self.output_interface = outputs
else:
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
self.model_obj = model
if model_type is None:
model_type = self._infer_model_type(model)

View File

@ -10,6 +10,16 @@ class TestInterface(unittest.TestCase):
self.assertIsInstance(io.input_interface, gradio.inputs.Sketchpad)
self.assertIsInstance(io.output_interface, gradio.outputs.Textbox)
def test_input_interface_is_instance(self):
inp = gradio.inputs.ImageUpload()
io = Interface(inputs=inp, outputs='textBOX', model=lambda x: x, model_type='function')
self.assertEqual(io.input_interface, inp)
def test_output_interface_is_instance(self):
out = gradio.outputs.Label(show_confidences=False)
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
self.assertEqual(io.input_interface, inp)
if __name__ == '__main__':
unittest.main()