mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
add support for passing in interface instances
This commit is contained in:
parent
a3516c331b
commit
9e64d86039
@ -69,7 +69,6 @@ class Webcam(AbstractInput):
|
||||
"""
|
||||
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
|
||||
"""
|
||||
print('>>>>>>>>>in preprocess')
|
||||
content = inp.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
|
||||
|
@ -31,16 +31,26 @@ class Interface:
|
||||
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
|
||||
verbose=True):
|
||||
"""
|
||||
:param inputs: a string representing the input interface.
|
||||
:param outputs: a string representing the output interface.
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
:param outputs: a string or `AbstractOutput` representing the output interface.
|
||||
:param model_obj: the model object, such as a sklearn classifier or keras model.
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
|
||||
provided.
|
||||
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
|
||||
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
|
||||
"""
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
|
||||
if isinstance(inputs, str):
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
|
||||
elif isinstance(inputs, gradio.inputs.AbstractInput):
|
||||
self.input_interface = inputs
|
||||
else:
|
||||
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
|
||||
if isinstance(outputs, str):
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
|
||||
elif isinstance(outputs, gradio.outputs.AbstractOutput):
|
||||
self.output_interface = outputs
|
||||
else:
|
||||
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
|
||||
self.model_obj = model
|
||||
if model_type is None:
|
||||
model_type = self._infer_model_type(model)
|
||||
|
@ -10,6 +10,16 @@ class TestInterface(unittest.TestCase):
|
||||
self.assertIsInstance(io.input_interface, gradio.inputs.Sketchpad)
|
||||
self.assertIsInstance(io.output_interface, gradio.outputs.Textbox)
|
||||
|
||||
def test_input_interface_is_instance(self):
|
||||
inp = gradio.inputs.ImageUpload()
|
||||
io = Interface(inputs=inp, outputs='textBOX', model=lambda x: x, model_type='function')
|
||||
self.assertEqual(io.input_interface, inp)
|
||||
|
||||
def test_output_interface_is_instance(self):
|
||||
out = gradio.outputs.Label(show_confidences=False)
|
||||
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
|
||||
self.assertEqual(io.input_interface, inp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user