mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
upgraded to 0.9.3
This commit is contained in:
parent
4472e27438
commit
48087003ee
@ -7,7 +7,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
from abc import ABC, abstractmethod
|
||||
from gradio import preprocessing_utils, validation_data
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
import PIL.Image, PIL.ImageOps
|
||||
import time
|
||||
import warnings
|
||||
import json
|
||||
@ -84,11 +84,11 @@ class Sketchpad(AbstractInput):
|
||||
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
|
||||
"""
|
||||
im_transparent = preprocessing_utils.decode_base64_to_image(inp)
|
||||
im = Image.new("RGBA", im_transparent.size, "WHITE") # Create a white background for the alpha channel
|
||||
im = PIL.Image.new("RGBA", im_transparent.size, "WHITE") # Create a white background for the alpha channel
|
||||
im.paste(im_transparent, (0, 0), im_transparent)
|
||||
im = im.convert('L')
|
||||
if self.invert_colors:
|
||||
im = ImageOps.invert(im)
|
||||
im = PIL.ImageOps.invert(im)
|
||||
im = im.resize((self.image_width, self.image_height))
|
||||
if self.flatten:
|
||||
array = np.array(im).flatten().reshape(1, self.image_width * self.image_height)
|
||||
@ -261,6 +261,7 @@ class Slider(AbstractInput):
|
||||
"checkbox": {},
|
||||
}
|
||||
|
||||
|
||||
class Checkbox(AbstractInput):
|
||||
def __init__(self, label=None):
|
||||
super().__init__(label)
|
||||
@ -272,7 +273,7 @@ class Checkbox(AbstractInput):
|
||||
}
|
||||
|
||||
|
||||
class ImageIn(AbstractInput):
|
||||
class Image(AbstractInput):
|
||||
def __init__(self, cast_to=None, shape=(224, 224, 3), image_mode='RGB',
|
||||
scale=1/127.5, shift=-1, cropper_aspect_ratio=None, label=None):
|
||||
self.cast_to = cast_to
|
||||
|
@ -16,7 +16,6 @@ import requests
|
||||
import random
|
||||
import time
|
||||
from IPython import get_ipython
|
||||
import tensorflow as tf
|
||||
|
||||
LOCALHOST_IP = "0.0.0.0"
|
||||
TRY_NUM_PORTS = 100
|
||||
@ -31,7 +30,7 @@ class Interface:
|
||||
|
||||
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False,
|
||||
live=False, show_input=True, show_output=True,
|
||||
load_fn=None, capture_session=False,
|
||||
load_fn=None, capture_session=False, title=None, description=None,
|
||||
server_name=LOCALHOST_IP):
|
||||
"""
|
||||
:param fn: a function that will process the input panel data from the interface and return the output panel data.
|
||||
@ -81,6 +80,8 @@ class Interface:
|
||||
self.capture_session = capture_session
|
||||
self.session = None
|
||||
self.server_name = server_name
|
||||
self.title = title
|
||||
self.description = description
|
||||
|
||||
def get_config_file(self):
|
||||
return {
|
||||
@ -93,7 +94,9 @@ class Interface:
|
||||
"function_count": len(self.predict),
|
||||
"live": self.live,
|
||||
"show_input": self.show_input,
|
||||
"show_output": self.show_output,
|
||||
"show_output": self.show_output,
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
}
|
||||
|
||||
def process(self, raw_input):
|
||||
@ -109,8 +112,15 @@ class Interface:
|
||||
prediction = predict_fn(*processed_input,
|
||||
self.context)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input,
|
||||
self.context)
|
||||
try:
|
||||
prediction = predict_fn(*processed_input, self.context)
|
||||
except ValueError:
|
||||
print("It looks like you might be "
|
||||
"using tensorflow < 2.0. Please pass "
|
||||
"capture_session=True in Interface to avoid "
|
||||
"a 'Tensor is not an element of this graph.' "
|
||||
"error.")
|
||||
prediction = predict_fn(*processed_input, self.context)
|
||||
else:
|
||||
if self.capture_session:
|
||||
graph, sess = self.session
|
||||
@ -118,7 +128,16 @@ class Interface:
|
||||
with sess.as_default():
|
||||
prediction = predict_fn(*processed_input)
|
||||
else:
|
||||
prediction = predict_fn(*processed_input)
|
||||
try:
|
||||
prediction = predict_fn(*processed_input)
|
||||
except ValueError:
|
||||
print("It looks like you might be "
|
||||
"using tensorflow < 2.0. Please pass "
|
||||
"capture_session=True in Interface to avoid "
|
||||
"a 'Tensor is not an element of this graph.' "
|
||||
"error.")
|
||||
prediction = predict_fn(*processed_input)
|
||||
|
||||
if len(self.output_interfaces) / \
|
||||
len(self.predict) == 1:
|
||||
prediction = [prediction]
|
||||
@ -127,7 +146,6 @@ class Interface:
|
||||
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
|
||||
return processed_output
|
||||
|
||||
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
if self.verbose:
|
||||
@ -180,11 +198,7 @@ class Interface:
|
||||
return
|
||||
raise RuntimeError("Validation did not pass")
|
||||
|
||||
<<<<<<< HEAD
|
||||
def launch(self, inline=None, inbrowser=None, share=False, validate=True, title=None, description=None):
|
||||
=======
|
||||
def launch(self, inline=None, inbrowser=None, share=False, validate=True):
|
||||
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
|
||||
"""
|
||||
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
|
||||
:param inline: boolean. If True, then a gradio interface is created inline (e.g. in jupyter or colab notebook)
|
||||
@ -198,6 +212,7 @@ class Interface:
|
||||
self.context = context
|
||||
|
||||
if self.capture_session:
|
||||
import tensorflow as tf
|
||||
self.session = tf.get_default_graph(), \
|
||||
tf.keras.backend.get_session()
|
||||
|
||||
@ -294,11 +309,6 @@ class Interface:
|
||||
|
||||
config = self.get_config_file()
|
||||
config["share_url"] = share_url
|
||||
<<<<<<< HEAD
|
||||
config["title"] = title
|
||||
config["description"] = description
|
||||
=======
|
||||
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
|
||||
networking.set_config(config, output_directory)
|
||||
|
||||
return httpd, path_to_local_server, share_url
|
||||
|
@ -1,11 +1,7 @@
|
||||
function gradio(config, fn, target) {
|
||||
target = $(target);
|
||||
target.html(`
|
||||
<<<<<<< HEAD
|
||||
<div class="panels container">
|
||||
=======
|
||||
<div class="panels">
|
||||
>>>>>>> 2bd16c2f9c360c98583b94e2f6a6ea7259a98217
|
||||
<div class="panel input_panel">
|
||||
<div class="input_interfaces">
|
||||
</div>
|
||||
@ -30,7 +26,7 @@ function gradio(config, fn, target) {
|
||||
|
||||
let input_to_object_map = {
|
||||
"csv" : {},
|
||||
"imagein" : image_input,
|
||||
"image" : image_input,
|
||||
"sketchpad" : sketchpad_input,
|
||||
"textbox" : textbox_input,
|
||||
"webcam" : webcam,
|
||||
|
BIN
dist/gradio-0.9.1-py3.6.egg
vendored
BIN
dist/gradio-0.9.1-py3.6.egg
vendored
Binary file not shown.
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 1.0
|
||||
Name: gradio
|
||||
Version: 0.9.2
|
||||
Version: 0.9.3
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
Home-page: https://github.com/abidlabs/gradio
|
||||
Author: Abubakar Abid
|
||||
|
Loading…
Reference in New Issue
Block a user