mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
sample-inputs for sketchpad
This commit is contained in:
parent
b43e76b148
commit
8ef628d203
@ -6,3 +6,4 @@ psutil
|
||||
paramiko
|
||||
scipy
|
||||
IPython
|
||||
scikit-image
|
||||
|
@ -53,6 +53,12 @@ class AbstractInput(ABC):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def sample_inputs(self):
|
||||
"""
|
||||
An interface can optionally implement a method that sends a list of sample inputs for inference.
|
||||
"""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self):
|
||||
"""
|
||||
@ -77,7 +83,7 @@ class AbstractInput(ABC):
|
||||
|
||||
class Sketchpad(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0,
|
||||
dtype='float64'):
|
||||
dtype='float64', sample_inputs=None):
|
||||
self.image_width = shape[0]
|
||||
self.image_height = shape[1]
|
||||
self.invert_colors = invert_colors
|
||||
@ -85,8 +91,10 @@ class Sketchpad(AbstractInput):
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
self.dtype = dtype
|
||||
self.sample_inputs = sample_inputs
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
|
||||
def get_name(self):
|
||||
return 'sketchpad'
|
||||
|
||||
@ -121,6 +129,13 @@ class Sketchpad(AbstractInput):
|
||||
im.save(f'{dir}/{filename}', 'PNG')
|
||||
return filename
|
||||
|
||||
def get_sample_inputs(self):
|
||||
encoded_images = []
|
||||
if self.sample_inputs is not None:
|
||||
for input in self.sample_inputs:
|
||||
encoded_images.append(preprocessing_utils.encode_array_to_base64(input))
|
||||
return encoded_images
|
||||
|
||||
|
||||
class Webcam(AbstractInput):
|
||||
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
|
||||
|
@ -130,6 +130,22 @@ class Interface:
|
||||
|
||||
raise ValueError("model_type could not be inferred, please specify parameter `model_type`")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_config_file(self, output_directory):
|
||||
networking.set_interface_types_in_config_file(
|
||||
output_directory,
|
||||
self.input_interface.__class__.__name__.lower(),
|
||||
self.output_interface.__class__.__name__.lower(),
|
||||
)
|
||||
|
||||
if self.input_interface.__class__.__name__.lower() == "sketchpad" or self.input_interface.__class__.__name__.lower() == "textbox":
|
||||
networking.set_sample_data_in_config_file(
|
||||
output_directory,
|
||||
self.input_interface.get_sample_inputs()
|
||||
)
|
||||
|
||||
|
||||
def predict(self, preprocessed_input):
|
||||
"""
|
||||
Method that calls the relevant method of the model object to make a prediction.
|
||||
@ -235,11 +251,8 @@ class Interface:
|
||||
output_directory, self.input_interface, self.output_interface
|
||||
)
|
||||
|
||||
networking.set_interface_types_in_config_file(
|
||||
output_directory,
|
||||
self.input_interface.__class__.__name__.lower(),
|
||||
self.output_interface.__class__.__name__.lower(),
|
||||
)
|
||||
self.update_config_file(self, output_directory)
|
||||
|
||||
self.status = self.STATUS_TYPES["RUNNING"]
|
||||
self.simple_server = httpd
|
||||
|
||||
|
@ -117,6 +117,16 @@ def set_share_url_in_config_file(temp_dir, share_url):
|
||||
)
|
||||
|
||||
|
||||
def set_sample_data_in_config_file(temp_dir, sample_inputs):
|
||||
config_file = os.path.join(temp_dir, CONFIG_FILE)
|
||||
render_template_with_tags(
|
||||
config_file,
|
||||
{
|
||||
"sample_inputs": sample_inputs
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_first_available_port(initial, final):
|
||||
"""
|
||||
Gets the first open port in a specified range of port numbers
|
||||
|
@ -5,6 +5,7 @@ import tempfile
|
||||
import scipy.io.wavfile
|
||||
from scipy.fftpack import dct
|
||||
import numpy as np
|
||||
import skimage
|
||||
|
||||
|
||||
#########################
|
||||
@ -16,6 +17,16 @@ def decode_base64_to_image(encoding):
|
||||
return Image.open(BytesIO(base64.b64decode(image_encoded)))
|
||||
|
||||
|
||||
def encode_array_to_base64(image_array):
|
||||
with BytesIO() as output_bytes:
|
||||
PIL_image = Image.fromarray(skimage.img_as_ubyte(image_array))
|
||||
PIL_image.save(output_bytes, 'PNG')
|
||||
bytes_data = output_bytes.getvalue()
|
||||
|
||||
base64_str = str(base64.b64encode(bytes_data), 'utf-8')
|
||||
return base64_str
|
||||
|
||||
|
||||
def resize_and_crop(img, size, crop_type='top'):
|
||||
"""
|
||||
Resize and crop an image to fit the specified size.
|
||||
@ -24,8 +35,8 @@ def resize_and_crop(img, size, crop_type='top'):
|
||||
modified_path: path to store the modified image.
|
||||
size: `(width, height)` tuple.
|
||||
crop_type: can be 'top', 'middle' or 'bottom', depending on this
|
||||
value, the image will cropped getting the 'top/left', 'midle' or
|
||||
'bottom/rigth' of the image to fit the size.
|
||||
value, the image will cropped getting the 'top/left', 'middle' or
|
||||
'bottom/right' of the image to fit the size.
|
||||
raises:
|
||||
Exception: if can not open the file in img_path of there is problems
|
||||
to save the image.
|
||||
|
@ -1,5 +1,6 @@
|
||||
{
|
||||
"input_interface_type": "{{input_interface_type}}",
|
||||
"output_interface_type": "{{output_interface_type}}",
|
||||
"share_url": "{{share_url}}"
|
||||
"share_url": "{{share_url}}",
|
||||
"sample_inputs": "{{sample_inputs}}"
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user