mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
confidence intervals, img upload bug fix
This commit is contained in:
commit
fc7e0c2c42
File diff suppressed because one or more lines are too long
@ -44,6 +44,10 @@ class AbstractInput(ABC):
|
||||
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/sketchpad_input.html'
|
||||
@ -55,12 +59,17 @@ class Sketchpad(AbstractInput):
|
||||
content = inp.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
|
||||
im = preprocessing_utils.resize_and_crop(im, (28, 28))
|
||||
array = np.array(im).flatten().reshape(1, 28, 28, 1)
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1)
|
||||
return array
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/webcam_input.html'
|
||||
@ -71,9 +80,9 @@ class Webcam(AbstractInput):
|
||||
"""
|
||||
content = inp.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
|
||||
im = preprocessing_utils.resize_and_crop(im, (48, 48))
|
||||
array = np.array(im).flatten().reshape(1, 48, 48, 1)
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
|
||||
@ -90,6 +99,11 @@ class Textbox(AbstractInput):
|
||||
|
||||
|
||||
class ImageUpload(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
|
||||
self.image_width = image_width
|
||||
self.image_height = image_height
|
||||
self.num_channels = num_channels
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/image_upload_input.html'
|
||||
@ -100,9 +114,9 @@ class ImageUpload(AbstractInput):
|
||||
"""
|
||||
content = inp.split(';')[1]
|
||||
image_encoded = content.split(',')[1]
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
|
||||
im = preprocessing_utils.resize_and_crop(im, (48, 48))
|
||||
array = np.array(im).flatten().reshape(1, 48, 48, 1)
|
||||
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
|
||||
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
|
||||
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
|
||||
return array
|
||||
|
||||
|
||||
|
@ -31,16 +31,26 @@ class Interface:
|
||||
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
|
||||
verbose=True):
|
||||
"""
|
||||
:param inputs: a string representing the input interface.
|
||||
:param outputs: a string representing the output interface.
|
||||
:param inputs: a string or `AbstractInput` representing the input interface.
|
||||
:param outputs: a string or `AbstractOutput` representing the output interface.
|
||||
:param model_obj: the model object, such as a sklearn classifier or keras model.
|
||||
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
|
||||
provided.
|
||||
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
|
||||
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
|
||||
"""
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
|
||||
if isinstance(inputs, str):
|
||||
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
|
||||
elif isinstance(inputs, gradio.inputs.AbstractInput):
|
||||
self.input_interface = inputs
|
||||
else:
|
||||
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
|
||||
if isinstance(outputs, str):
|
||||
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
|
||||
elif isinstance(outputs, gradio.outputs.AbstractOutput):
|
||||
self.output_interface = outputs
|
||||
else:
|
||||
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
|
||||
self.model_obj = model
|
||||
if model_type is None:
|
||||
model_type = self._infer_model_type(model)
|
||||
@ -92,6 +102,7 @@ class Interface:
|
||||
while True:
|
||||
try:
|
||||
msg = await websocket.recv()
|
||||
print('>>>>>>>>>msg', msg)
|
||||
processed_input = self.input_interface.preprocess(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
processed_output = self.output_interface.postprocess(prediction)
|
||||
|
@ -18,7 +18,7 @@ from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
import pkg_resources
|
||||
from bs4 import BeautifulSoup
|
||||
import shutil
|
||||
from distutils import dir_util
|
||||
|
||||
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
|
||||
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
|
||||
@ -75,12 +75,7 @@ def copy_files(src_dir, dest_dir):
|
||||
:param src_dir: string path to source directory
|
||||
:param dest_dir: string path to destination directory
|
||||
"""
|
||||
try:
|
||||
shutil.copytree(src_dir, dest_dir)
|
||||
except OSError as exc: # python >2.5
|
||||
if exc.errno == errno.ENOTDIR:
|
||||
shutil.copy(src_dir, dest_dir)
|
||||
else: raise
|
||||
dir_util.copy_tree(src_dir, dest_dir)
|
||||
|
||||
|
||||
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
|
||||
|
@ -38,6 +38,14 @@ class AbstractOutput(ABC):
|
||||
|
||||
|
||||
class Label(AbstractOutput):
|
||||
LABEL_KEY = 'label'
|
||||
CONFIDENCES_KEY = 'confidences'
|
||||
CONFIDENCE_KEY = 'confidence'
|
||||
|
||||
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True):
|
||||
self.num_top_classes = num_top_classes
|
||||
self.show_confidences = show_confidences
|
||||
super().__init__(postprocessing_fn=postprocessing_fn)
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/label_output.html'
|
||||
@ -45,16 +53,27 @@ class Label(AbstractOutput):
|
||||
def postprocess(self, prediction):
|
||||
"""
|
||||
"""
|
||||
response = dict()
|
||||
# TODO(abidlabs): check if list, if so convert to numpy array
|
||||
if isinstance(prediction, np.ndarray):
|
||||
prediction = prediction.squeeze()
|
||||
if prediction.size == 1:
|
||||
return prediction
|
||||
else:
|
||||
return prediction.argmax()
|
||||
if prediction.size == 1: # if it's single value
|
||||
response[Label.LABEL_KEY] = np.asscalar(prediction)
|
||||
elif len(prediction.shape) == 1: # if a 1D
|
||||
response[Label.LABEL_KEY] = prediction.argmax()
|
||||
if self.show_confidences:
|
||||
response[Label.CONFIDENCES_KEY] = []
|
||||
for i in range(self.num_top_classes):
|
||||
response[Label.CONFIDENCES_KEY].append({
|
||||
Label.LABEL_KEY: prediction.argmax(),
|
||||
Label.CONFIDENCE_KEY: prediction.max(),
|
||||
})
|
||||
prediction[prediction.argmax()] = 0
|
||||
elif isinstance(prediction, str):
|
||||
return prediction
|
||||
response[Label.LABEL_KEY] = prediction
|
||||
else:
|
||||
raise ValueError("Unable to post-process model prediction.")
|
||||
return response
|
||||
|
||||
|
||||
class Textbox(AbstractOutput):
|
||||
|
@ -17,7 +17,7 @@
|
||||
<div id="panels">
|
||||
<div class="panel">
|
||||
<div id="input"></div>
|
||||
<input class="submit" type="submit" value="Submit"/><!--
|
||||
<input class="submit" type="submit" value="Submit"/><!--DO NOT DELETE
|
||||
--><input class="clear" type="reset" value="Clear">
|
||||
</div>
|
||||
<div class="panel">
|
||||
|
@ -11,7 +11,7 @@ class TestSketchpad(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Sketchpad()
|
||||
path = inp.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Sketchpad()
|
||||
@ -23,19 +23,19 @@ class TestWebcam(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Webcam()
|
||||
path = inp.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Webcam()
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
self.assertEqual(array.shape, (1, 224, 224, 3))
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
inp = inputs.Textbox()
|
||||
path = inp.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.Textbox()
|
||||
@ -52,7 +52,12 @@ class TestImageUpload(unittest.TestCase):
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.ImageUpload()
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 1))
|
||||
self.assertEqual(array.shape, (1, 224, 224, 3))
|
||||
|
||||
def test_preprocessing(self):
|
||||
inp = inputs.ImageUpload(image_height=48, image_width=48)
|
||||
array = inp.preprocess(BASE64_IMG)
|
||||
self.assertEqual(array.shape, (1, 48, 48, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -10,6 +10,16 @@ class TestInterface(unittest.TestCase):
|
||||
self.assertIsInstance(io.input_interface, gradio.inputs.Sketchpad)
|
||||
self.assertIsInstance(io.output_interface, gradio.outputs.Textbox)
|
||||
|
||||
def test_input_interface_is_instance(self):
|
||||
inp = gradio.inputs.ImageUpload()
|
||||
io = Interface(inputs=inp, outputs='textBOX', model=lambda x: x, model_type='function')
|
||||
self.assertEqual(io.input_interface, inp)
|
||||
|
||||
def test_output_interface_is_instance(self):
|
||||
out = gradio.outputs.Label(show_confidences=False)
|
||||
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
|
||||
self.assertEqual(io.output_interface, out)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -16,28 +16,40 @@ class TestLabel(unittest.TestCase):
|
||||
string = 'happy'
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(string)
|
||||
self.assertEqual(label, string)
|
||||
self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
|
||||
|
||||
def test_postprocessing_one_hot(self):
|
||||
one_hot = np.array([0, 0, 0, 1, 0])
|
||||
true_label = 3
|
||||
def test_postprocessing_1D_array(self):
|
||||
array = np.array([0.1, 0.2, 0, 0.7, 0])
|
||||
true_label = {outputs.Label.LABEL_KEY: 3,
|
||||
outputs.Label.CONFIDENCES_KEY: [
|
||||
{outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
|
||||
{outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
|
||||
{outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
|
||||
]}
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(one_hot)
|
||||
self.assertEqual(label, true_label)
|
||||
label = out.postprocess(array)
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
def test_postprocessing_1D_array_no_confidences(self):
|
||||
array = np.array([0.1, 0.2, 0, 0.7, 0])
|
||||
true_label = {outputs.Label.LABEL_KEY: 3}
|
||||
out = outputs.Label(show_confidences=False)
|
||||
label = out.postprocess(array)
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
def test_postprocessing_int(self):
|
||||
true_label_array = np.array([[[3]]])
|
||||
true_label = 3
|
||||
true_label = {outputs.Label.LABEL_KEY: 3}
|
||||
out = outputs.Label()
|
||||
label = out.postprocess(true_label_array)
|
||||
self.assertEqual(label, true_label)
|
||||
self.assertDictEqual(label, true_label)
|
||||
|
||||
|
||||
class TestTextbox(unittest.TestCase):
|
||||
def test_path_exists(self):
|
||||
out = outputs.Textbox()
|
||||
path = out.get_template_path()
|
||||
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
|
||||
|
||||
def test_postprocessing(self):
|
||||
string = 'happy'
|
||||
|
Loading…
x
Reference in New Issue
Block a user