diff --git a/build/lib/gradio/inputs.py b/build/lib/gradio/inputs.py index 4f9141aefb..1395f76f1a 100644 --- a/build/lib/gradio/inputs.py +++ b/build/lib/gradio/inputs.py @@ -7,7 +7,7 @@ automatically added to a registry, which allows them to be easily referenced in from abc import ABC, abstractmethod from gradio import preprocessing_utils, validation_data import numpy as np -from PIL import Image, ImageOps +import PIL.Image, PIL.ImageOps import time import warnings import json @@ -58,11 +58,12 @@ class AbstractInput(ABC): """ return {} - def rebuild_flagged(self, dir, msg): + @classmethod + def process_example(self, example): """ - All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64) + Proprocess example for UI """ - pass + return example class Sketchpad(AbstractInput): @@ -84,11 +85,11 @@ class Sketchpad(AbstractInput): Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28 """ im_transparent = preprocessing_utils.decode_base64_to_image(inp) - im = Image.new("RGBA", im_transparent.size, "WHITE") # Create a white background for the alpha channel + im = PIL.Image.new("RGBA", im_transparent.size, "WHITE") # Create a white background for the alpha channel im.paste(im_transparent, (0, 0), im_transparent) im = im.convert('L') if self.invert_colors: - im = ImageOps.invert(im) + im = PIL.ImageOps.invert(im) im = im.resize((self.image_width, self.image_height)) if self.flatten: array = np.array(im).flatten().reshape(1, self.image_width * self.image_height) @@ -98,30 +99,6 @@ class Sketchpad(AbstractInput): array = array.astype(self.dtype) return array - # TODO(abidlabs): clean this up - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method to decode a base64 image - """ - - im = preprocessing_utils.decode_base64_to_image(msg) - - timestamp = datetime.datetime.now() - filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' - im.save(f'{dir}/{filename}', 'PNG') - return filename - - def get_sample_inputs(self): - encoded_images = [] - if self.sample_inputs is not None: - for input in self.sample_inputs: - if self.flatten: - input = input.reshape((self.image_width, self.image_height)) - if self.invert_colors: - input = 1 - input - encoded_images.append(preprocessing_utils.encode_array_to_base64(input)) - return encoded_images - class Webcam(AbstractInput): def __init__(self, image_width=224, image_height=224, num_channels=3, label=None): @@ -149,17 +126,6 @@ class Webcam(AbstractInput): array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels) return array - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method to decode a base64 image - """ - inp = msg['data']['input'] - im = preprocessing_utils.decode_base64_to_image(inp) - timestamp = datetime.datetime.now() - filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' - im.save(f'{dir}/{filename}', 'PNG') - return filename - class Textbox(AbstractInput): def __init__(self, sample_inputs=None, lines=1, placeholder=None, label=None, numeric=False): @@ -196,15 +162,6 @@ class Textbox(AbstractInput): else: return inp - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for text saves it .txt file - """ - return json.loads(msg) - - def get_sample_inputs(self): - return self.sample_inputs - class Radio(AbstractInput): def __init__(self, choices, label=None): @@ -261,6 +218,7 @@ class Slider(AbstractInput): "checkbox": {}, } + class Checkbox(AbstractInput): def __init__(self, label=None): super().__init__(label) @@ -272,7 +230,7 @@ class Checkbox(AbstractInput): } -class ImageIn(AbstractInput): +class Image(AbstractInput): def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB', scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None): self.cast_to = cast_to @@ -338,43 +296,6 @@ class ImageIn(AbstractInput): self.num_channels) return array - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method to decode a base64 image - """ - im = preprocessing_utils.decode_base64_to_image(msg) - timestamp = datetime.datetime.now() - filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' - im.save(f'{dir}/{filename}', 'PNG') - return filename - - # TODO(abidlabs): clean this up - def save_to_file(self, dir, img): - """ - """ - timestamp = time.time()*1000 - filename = 'input_{}.png'.format(timestamp) - img.save('{}/{}'.format(dir, filename), 'PNG') - return filename - - -class CSV(AbstractInput): - - def get_name(self): - return 'csv' - - def preprocess(self, inp): - """ - By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this) - """ - return inp - - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for csv - """ - return json.loads(msg) - class Microphone(AbstractInput): @@ -386,12 +307,6 @@ class Microphone(AbstractInput): mfcc_array = preprocessing_utils.generate_mfcc_features_from_audio_file(file_obj.name) return mfcc_array - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for csv - """ - return json.loads(msg) - # Automatically adds all shortcut implementations in AbstractInput into a dictionary. shortcuts = {} diff --git a/build/lib/gradio/interface.py b/build/lib/gradio/interface.py index d70c301fcb..7aab25df28 100644 --- a/build/lib/gradio/interface.py +++ b/build/lib/gradio/interface.py @@ -16,7 +16,6 @@ import requests import random import time from IPython import get_ipython -import tensorflow as tf LOCALHOST_IP = "0.0.0.0" TRY_NUM_PORTS = 100 @@ -29,9 +28,9 @@ class Interface: the appropriate inputs and outputs """ - def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, + def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None, live=False, show_input=True, show_output=True, - load_fn=None, capture_session=False, + load_fn=None, capture_session=False, title=None, description=None, server_name=LOCALHOST_IP): """ :param fn: a function that will process the input panel data from the interface and return the output panel data. @@ -81,6 +80,9 @@ class Interface: self.capture_session = capture_session self.session = None self.server_name = server_name + self.title = title + self.description = description + self.examples = examples def get_config_file(self): return { @@ -93,7 +95,9 @@ class Interface: "function_count": len(self.predict), "live": self.live, "show_input": self.show_input, - "show_output": self.show_output, + "show_output": self.show_output, + "title": self.title, + "description": self.description, } def process(self, raw_input): @@ -109,8 +113,15 @@ class Interface: prediction = predict_fn(*processed_input, self.context) else: - prediction = predict_fn(*processed_input, - self.context) + try: + prediction = predict_fn(*processed_input, self.context) + except ValueError: + print("It looks like you might be " + "using tensorflow < 2.0. Please pass " + "capture_session=True in Interface to avoid " + "a 'Tensor is not an element of this graph.' " + "error.") + prediction = predict_fn(*processed_input, self.context) else: if self.capture_session: graph, sess = self.session @@ -118,7 +129,16 @@ class Interface: with sess.as_default(): prediction = predict_fn(*processed_input) else: - prediction = predict_fn(*processed_input) + try: + prediction = predict_fn(*processed_input) + except ValueError: + print("It looks like you might be " + "using tensorflow < 2.0. Please pass " + "capture_session=True in Interface to avoid " + "a 'Tensor is not an element of this graph.' " + "error.") + prediction = predict_fn(*processed_input) + if len(self.output_interfaces) / \ len(self.predict) == 1: prediction = [prediction] @@ -127,7 +147,6 @@ class Interface: predictions[i]) for i, output_interface in enumerate(self.output_interfaces)] return processed_output - def validate(self): if self.validate_flag: if self.verbose: @@ -180,11 +199,7 @@ class Interface: return raise RuntimeError("Validation did not pass") -<<<<<<< HEAD - def launch(self, inline=None, inbrowser=None, share=False, validate=True, title=None, description=None): -======= def launch(self, inline=None, inbrowser=None, share=False, validate=True): ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 """ Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it. :param inline: boolean. If True, then a gradio interface is created inline (e.g. in jupyter or colab notebook) @@ -198,6 +213,7 @@ class Interface: self.context = context if self.capture_session: + import tensorflow as tf self.session = tf.get_default_graph(), \ tf.keras.backend.get_session() @@ -294,11 +310,7 @@ class Interface: config = self.get_config_file() config["share_url"] = share_url -<<<<<<< HEAD - config["title"] = title - config["description"] = description -======= ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 + config["examples"] = self.examples networking.set_config(config, output_directory) return httpd, path_to_local_server, share_url diff --git a/build/lib/gradio/outputs.py b/build/lib/gradio/outputs.py index cd7d6bb168..af96d0de1b 100644 --- a/build/lib/gradio/outputs.py +++ b/build/lib/gradio/outputs.py @@ -76,12 +76,6 @@ class Label(AbstractOutput): "label": {}, } - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for label - """ - return json.loads(msg) - class KeyValues(AbstractOutput): def __init__(self, label=None): @@ -120,12 +114,6 @@ class Textbox(AbstractOutput): """ return prediction - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for label - """ - return json.loads(msg) - class Image(AbstractOutput): def __init__(self, label=None, plot=False): diff --git a/build/lib/gradio/static/css/style.css b/build/lib/gradio/static/css/style.css index 7cf66a3a7b..dfbfbfdaf9 100644 --- a/build/lib/gradio/static/css/style.css +++ b/build/lib/gradio/static/css/style.css @@ -31,7 +31,6 @@ nav img { padding: 4px; border-radius: 2px; } -<<<<<<< HEAD #title { text-align: center; } @@ -40,19 +39,11 @@ nav img { width: 100%; margin: 0 auto; } -======= ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 .panels { display: flex; flex-flow: row; flex-wrap: wrap; justify-content: center; -<<<<<<< HEAD -======= - max-width: 1028px; - width: 100%; - margin: 0 auto; ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 } button.primary { color: white; diff --git a/build/lib/gradio/static/js/gradio.js b/build/lib/gradio/static/js/gradio.js index 912ab57b89..bf6cb9e87e 100644 --- a/build/lib/gradio/static/js/gradio.js +++ b/build/lib/gradio/static/js/gradio.js @@ -1,11 +1,7 @@ function gradio(config, fn, target) { target = $(target); target.html(` -<<<<<<< HEAD
-======= -
->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
@@ -30,7 +26,7 @@ function gradio(config, fn, target) { let input_to_object_map = { "csv" : {}, - "imagein" : image_input, + "image" : image_input, "sketchpad" : sketchpad_input, "textbox" : textbox_input, "webcam" : webcam, diff --git a/build/lib/gradio/static/js/interfaces/input/webcam.js b/build/lib/gradio/static/js/interfaces/input/webcam.js index f551a3855a..ec081b2601 100644 --- a/build/lib/gradio/static/js/interfaces/input/webcam.js +++ b/build/lib/gradio/static/js/interfaces/input/webcam.js @@ -20,27 +20,16 @@ const webcam = { }, submit: function() { var io = this; -<<<<<<< HEAD Webcam.snap(function(image_data) { io.io_master.input(io.id, image_data); }); // Webcam.freeze(); -======= - Webcam.freeze(); - Webcam.snap(function(image_data) { - io.io_master.input(io.id, image_data); - }); ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 this.state = "SNAPPED"; }, clear: function() { if (this.state == "SNAPPED") { this.state = "CAMERA_ON"; -<<<<<<< HEAD // Webcam.unfreeze(); -======= - Webcam.unfreeze(); ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 } }, state: "NOT_STARTED", diff --git a/build/lib/gradio/templates/index.html b/build/lib/gradio/templates/index.html index 929a785bb1..a5fbe7ab04 100644 --- a/build/lib/gradio/templates/index.html +++ b/build/lib/gradio/templates/index.html @@ -34,14 +34,16 @@ Live at .
-<<<<<<< HEAD

-======= ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
+ @@ -89,27 +91,40 @@ }); }); }, "#interface_target"); -<<<<<<< HEAD if (config["title"]) { $("#title").text(config["title"]); } if (config["description"]) { $("#description").text(config["description"]); } -======= ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 if (config["share_url"]) { let share_url = config["share_url"]; $("#share").removeClass("invisible"); $("#share-link").text(share_url).attr("href", share_url); $("#share-copy").click(function() { copyToClipboard(share_url); -<<<<<<< HEAD $("#share-copy").text("Copied!"); -======= ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 }) - } + }; + if (config["examples"]) { + $("#examples").removeClass("invisible"); + let html = "" + for (let i = 0; i < config["input_interfaces"].length; i++) { + label = config["input_interfaces"][i][1]["label"]; + html += "" + label + ""; + } + html += ""; + html += ""; + for (let example of config["examples"]) { + html += ""; + for (let col of example) { + html += "" + col + ""; + } + html += ""; + } + html += ""; + $("#examples table").html(html); + }; }); const copyToClipboard = str => { const el = document.createElement('textarea'); diff --git a/demo/basic_text.py b/demo/basic_text.py index 73bd98b9fd..fc76f68673 100644 --- a/demo/basic_text.py +++ b/demo/basic_text.py @@ -17,9 +17,8 @@ gr.Interface(answer_question, ], [ gr.outputs.Textbox(label="out", lines=8), "key_values" + ], examples=[ + ["things1", "things2"], + ["things10", "things20"], ] -<<<<<<< HEAD - ).launch(title="Demo", description="Trying out a funky model!") -======= - ).launch(share=True) ->>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217 + ).launch() diff --git a/gradio/inputs.py b/gradio/inputs.py index c65b57e978..1395f76f1a 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -58,11 +58,12 @@ class AbstractInput(ABC): """ return {} - def rebuild_flagged(self, dir, msg): + @classmethod + def process_example(self, example): """ - All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64) + Proprocess example for UI """ - pass + return example class Sketchpad(AbstractInput): @@ -98,30 +99,6 @@ class Sketchpad(AbstractInput): array = array.astype(self.dtype) return array - # TODO(abidlabs): clean this up - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method to decode a base64 image - """ - - im = preprocessing_utils.decode_base64_to_image(msg) - - timestamp = datetime.datetime.now() - filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' - im.save(f'{dir}/{filename}', 'PNG') - return filename - - def get_sample_inputs(self): - encoded_images = [] - if self.sample_inputs is not None: - for input in self.sample_inputs: - if self.flatten: - input = input.reshape((self.image_width, self.image_height)) - if self.invert_colors: - input = 1 - input - encoded_images.append(preprocessing_utils.encode_array_to_base64(input)) - return encoded_images - class Webcam(AbstractInput): def __init__(self, image_width=224, image_height=224, num_channels=3, label=None): @@ -149,17 +126,6 @@ class Webcam(AbstractInput): array = np.array(im).flatten().reshape(self.image_width, self.image_height, self.num_channels) return array - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method to decode a base64 image - """ - inp = msg['data']['input'] - im = preprocessing_utils.decode_base64_to_image(inp) - timestamp = datetime.datetime.now() - filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' - im.save(f'{dir}/{filename}', 'PNG') - return filename - class Textbox(AbstractInput): def __init__(self, sample_inputs=None, lines=1, placeholder=None, label=None, numeric=False): @@ -196,15 +162,6 @@ class Textbox(AbstractInput): else: return inp - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for text saves it .txt file - """ - return json.loads(msg) - - def get_sample_inputs(self): - return self.sample_inputs - class Radio(AbstractInput): def __init__(self, choices, label=None): @@ -339,43 +296,6 @@ class Image(AbstractInput): self.num_channels) return array - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method to decode a base64 image - """ - im = preprocessing_utils.decode_base64_to_image(msg) - timestamp = datetime.datetime.now() - filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' - im.save(f'{dir}/{filename}', 'PNG') - return filename - - # TODO(abidlabs): clean this up - def save_to_file(self, dir, img): - """ - """ - timestamp = time.time()*1000 - filename = 'input_{}.png'.format(timestamp) - img.save('{}/{}'.format(dir, filename), 'PNG') - return filename - - -class CSV(AbstractInput): - - def get_name(self): - return 'csv' - - def preprocess(self, inp): - """ - By default, no pre-processing is applied to a CSV file (TODO:aliabid94 fix this) - """ - return inp - - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for csv - """ - return json.loads(msg) - class Microphone(AbstractInput): @@ -387,12 +307,6 @@ class Microphone(AbstractInput): mfcc_array = preprocessing_utils.generate_mfcc_features_from_audio_file(file_obj.name) return mfcc_array - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for csv - """ - return json.loads(msg) - # Automatically adds all shortcut implementations in AbstractInput into a dictionary. shortcuts = {} diff --git a/gradio/interface.py b/gradio/interface.py index fbd1640f57..7aab25df28 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -28,7 +28,7 @@ class Interface: the appropriate inputs and outputs """ - def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, + def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None, live=False, show_input=True, show_output=True, load_fn=None, capture_session=False, title=None, description=None, server_name=LOCALHOST_IP): @@ -82,6 +82,7 @@ class Interface: self.server_name = server_name self.title = title self.description = description + self.examples = examples def get_config_file(self): return { @@ -309,6 +310,7 @@ class Interface: config = self.get_config_file() config["share_url"] = share_url + config["examples"] = self.examples networking.set_config(config, output_directory) return httpd, path_to_local_server, share_url diff --git a/gradio/outputs.py b/gradio/outputs.py index cd7d6bb168..af96d0de1b 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -76,12 +76,6 @@ class Label(AbstractOutput): "label": {}, } - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for label - """ - return json.loads(msg) - class KeyValues(AbstractOutput): def __init__(self, label=None): @@ -120,12 +114,6 @@ class Textbox(AbstractOutput): """ return prediction - def rebuild_flagged(self, dir, msg): - """ - Default rebuild method for label - """ - return json.loads(msg) - class Image(AbstractOutput): def __init__(self, label=None, plot=False): diff --git a/gradio/templates/index.html b/gradio/templates/index.html index 328669b63a..a5fbe7ab04 100644 --- a/gradio/templates/index.html +++ b/gradio/templates/index.html @@ -39,6 +39,11 @@

+ @@ -100,7 +105,26 @@ copyToClipboard(share_url); $("#share-copy").text("Copied!"); }) - } + }; + if (config["examples"]) { + $("#examples").removeClass("invisible"); + let html = "" + for (let i = 0; i < config["input_interfaces"].length; i++) { + label = config["input_interfaces"][i][1]["label"]; + html += "" + label + ""; + } + html += ""; + html += ""; + for (let example of config["examples"]) { + html += ""; + for (let col of example) { + html += "" + col + ""; + } + html += ""; + } + html += ""; + $("#examples table").html(html); + }; }); const copyToClipboard = str => { const el = document.createElement('textarea');