This commit is contained in:
Ali Abid 2020-08-10 14:22:39 -07:00
commit 7e7629c65d
11 changed files with 42 additions and 1423 deletions

View File

@ -23,4 +23,4 @@ gr.Interface(flip2,
["images/cheetah2.jpg"],
["images/lion.jpg"],
]
).launch()
).launch(share=False)

View File

@ -1,2 +0,0 @@
recursive-include gradio/templates *
recursive-include gradio/static *

View File

@ -1,170 +0,0 @@
[![CircleCI](https://circleci.com/gh/gradio-app/gradio.svg?style=svg)](https://circleci.com/gh/gradio-app/gradio) [![PyPI version](https://badge.fury.io/py/gradio.svg)](https://badge.fury.io/py/gradio)
# Welcome to `gradio` :rocket:
Quickly create customizable UI components around your TensorFlow or PyTorch models, or even arbitrary Python functions. Mix and match components to support any combination of inputs and outputs. Gradio makes it easy for you to "play around" with your model in your browser by dragging-and-dropping in your own images (or pasting your own text, recording your own voice, etc.) and seeing what the model outputs. You can also generate a share link which allows anyone, anywhere to use the interface as the model continues to run on your machine. Our core library is free and open-source! Take a look:
<p align="center">
<img src="https://i.ibb.co/m0skD0j/bert.gif" alt="drawing"/>
</p>
Gradio is useful for:
* Creating demos of your machine learning code for clients / collaborators / users
* Getting feedback on model performance from users
* Debugging your model interactively during development
To get a sense of `gradio`, take a look at a few of these examples, and find more on our website: www.gradio.app.
## Installation
```
pip install gradio
```
(you may need to replace `pip` with `pip3` if you're running `python3`).
## Usage
Gradio is very easy to use with your existing code. Here are a few working examples:
### 0. Hello World [![alt text](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/18ODkJvyxHutTN0P5APWyGFO_xwNcgHDZ?usp=sharing)
Let's start with a basic function (no machine learning yet!) that greets an input name. We'll wrap the function with a `Text` to `Text` interface.
```python
import gradio as gr
def greet(name):
return "Hello " + name + "!"
gr.Interface(fn=greet, inputs="text", outputs="text").launch()
```
The core Interface class is initialized with three parameters:
- `fn`: the function to wrap
- `inputs`: the name of the input interface
- `outputs`: the name of the output interface
Calling the `launch()` function of the `Interface` object produces the interface shown in image below. Click on the gif to go the live interface in our getting started page.
<a href="https://gradio.app/getting_started#interface_4">
<p align="center">
<img src="https://i.ibb.co/T4Rqs5y/hello-name.gif" alt="drawing"/>
</p>
</a>
### 1. Inception Net [![alt text](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1c6gQiW88wKBwWq96nqEwuQ1Kyt5LejiU?usp=sharing)
Now, let's do a machine learning example. We're going to wrap an
interface around the InceptionV3 image classifier, which we'll load
using Tensorflow! Since this is an image classification model, we will use the `Image` input interface.
We'll output a dictionary of labels and their corresponding confidence scores with the `Label` output
interface. (The original Inception Net architecture [can be found here](https://arxiv.org/abs/1409.4842))
```python
import gradio as gr
import tensorflow as tf
import numpy as np
import requests
inception_net = tf.keras.applications.InceptionV3() # load the model
# Download human-readable labels for ImageNet.
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")
def classify_image(inp):
inp = inp.reshape((-1, 299, 299, 3))
inp = tf.keras.applications.inception_v3.preprocess_input(inp)
prediction = inception_net.predict(inp).flatten()
return {labels[i]: float(prediction[i]) for i in range(1000)}
image = gr.inputs.Image(shape=(299, 299, 3))
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(fn=classify_image, inputs=image, outputs=label).launch()
```
This code will produce the interface below. The interface gives you a way to test
Inception Net by dragging and dropping images, and also allows you to use naturally modify the input image using image editing tools that
appear when you click EDIT. Notice here we provided actual `gradio.inputs` and `gradio.outputs` objects to the Interface
function instead of using string shortcuts. This lets us use built-in preprocessing (e.g. image resizing)
and postprocessing (e.g. choosing the number of labels to display) provided by these
interfaces.
<p align="center">
<img src="https://i.ibb.co/X8KGJqB/inception-net-2.gif" alt="drawing"/>
</p>
You can supply your own model instead of the pretrained model above, as well as use different kinds of models or functions. Here's a list of the interfaces we currently support, along with their preprocessing / postprocessing parameters:
**Input Interfaces**:
- `Sketchpad(shape=(28, 28), invert_colors=True, flatten=False, scale=1/255, shift=0, dtype='float64')`
- `Webcam(image_width=224, image_height=224, num_channels=3, label=None)`
- `Textbox(lines=1, placeholder=None, label=None, numeric=False)`
- `Radio(choices, label=None)`
- `Dropdown(choices, label=None)`
- `CheckboxGroup(choices, label=None)`
- `Slider(minimum=0, maximum=100, default=None, label=None)`
- `Image(shape=(224, 224, 3), image_mode='RGB', scale=1/127.5, shift=-1, label=None)`
- `Microphone()`
**Output Interfaces**:
- `Label(num_top_classes=None, label=None)`
- `KeyValues(label=None)`
- `Textbox(lines=1, placeholder=None, label=None)`
- `Image(label=None, plot=False)`
Interfaces can also be combined together, for multiple-input or multiple-output models.
### 2. Real-Time MNIST [![alt text](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LXJqwdkZNkt1J_yfLWQ3FLxbG2cAF8p4?usp=sharing)
Let's wrap a fun `Sketchpad`-to-`Label` UI around MNIST. For this example, we'll take advantage of the `live`
feature in the library. Set `live=True` inside `Interface()`> to have it run continuous predictions.
We've abstracted the model training from the code below, but you can see the full code on the colab link.
```python
import tensorflow as tf
import gradio as gr
from urllib.request import urlretrieve
urlretrieve("https://gr-models.s3-us-west-2.amazonaws.com/mnist-model.h5","mnist-model.h5")
model = tf.keras.models.load_model("mnist-model.h5")
def recognize_digit(inp):
prediction = model.predict(inp.reshape(1, 28, 28, 1)).tolist()[0]
return {str(i): prediction[i] for i in range(10)}
sketchpad = gr.inputs.Sketchpad()
label = gr.outputs.Label(num_top_classes=3)
gr.Interface(fn=recognize_digit, inputs=sketchpad,
outputs=label, live=True).launch()
```
This code will produce the interface below.
<p align="center">
<img src="https://i.ibb.co/vkgZLcH/gif6.gif" alt="drawing"/>
</p>
## Contributing:
If you would like to contribute and your contribution is small, you can directly open a pull request (PR). If you would like to contribute a larger feature, we recommend first creating an issue with a proposed design for discussion. Please see our contributing guidelines for more info.
## License:
Gradio is licensed under the Apache License 2.0
## See more:
You can find many more examples (like GPT-2, model comparison, multiple inputs, and numerical interfaces) as well as more info on usage on our website: www.gradio.app
See, also, the accompanying paper: ["Gradio: Hassle-Free Sharing and Testing of ML Models in the Wild"](https://arxiv.org/pdf/1906.02569.pdf), *ICML HILL 2019*, and please use the citation below.
```
@article{abid2019gradio,
title={Gradio: Hassle-Free Sharing and Testing of ML Models in the Wild},
author={Abid, Abubakar and Abdalla, Ali and Abid, Ali and Khan, Dawood and Alfozan, Abdulrahman and Zou, James},
journal={arXiv preprint arXiv:1906.02569},
year={2019}
}
```

View File

@ -1 +0,0 @@
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.

View File

@ -1,68 +0,0 @@
import json
from gradio.inputs import AbstractInput
from gradio.outputs import AbstractOutput
from gradio.interface import Interface
import inspect
def get_params(func):
params_str = inspect.getdoc(func)
params_doc = []
documented_params = {"self"}
for param_line in params_str.split("\n")[1:]:
space_index = param_line.index(" ")
colon_index = param_line.index(":")
name = param_line[:space_index]
documented_params.add(name)
params_doc.append((name, param_line[space_index+2:colon_index-1], param_line[colon_index+2:]))
params = inspect.getfullargspec(func)
param_set = []
for i in range(len(params.args)):
neg_index = -1 - i
if params.args[neg_index] not in documented_params:
continue
if i < len(params.defaults):
default = params.defaults[neg_index]
if type(default) == str:
default = '"' + default + '"'
else:
default = str(default)
param_set.insert(0, (params.args[neg_index], default))
else:
param_set.insert(0, (params.args[neg_index],))
return param_set, params_doc
def document(cls_set):
docset = []
for cls in cls_set:
inp = {}
inp["name"] = cls.__name__
doc = inspect.getdoc(cls)
inp["doc"] = "\n".join(doc.split("\n")[:-1])
inp["type"] = doc.split("\n")[-1].split("type: ")[-1]
inp["params"], inp["params_doc"] = get_params(cls.__init__)
inp["shortcuts"] = list(cls.get_shortcut_implementations().items())
docset.append(inp)
return docset
inputs = document(AbstractInput.__subclasses__())
outputs = document(AbstractOutput.__subclasses__())
interface_params = get_params(Interface.__init__)
interface = {
"doc": inspect.getdoc(Interface),
"params": interface_params[0],
"params_doc": interface_params[1],
}
launch_params = get_params(Interface.launch)
launch = {
"params": launch_params[0],
"params_doc": launch_params[1],
}
with open("docs.json", "w") as docs:
json.dump({
"inputs": inputs,
"outputs": outputs,
"interface": interface,
"launch": launch
}, docs)

View File

@ -1,422 +0,0 @@
"""
This module defines various classes that can serve as the `input` to an interface. Each class must inherit from
`AbstractInput`, and each class must define a path to its template. All of the subclasses of `AbstractInput` are
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
"""
import datetime
import json
import os
import time
import warnings
from abc import ABC, abstractmethod
import numpy as np
import PIL.Image
import PIL.ImageOps
import scipy.io.wavfile
from gradio import processing_utils, validation_data
# Where to find the static resources associated with each template.
# BASE_INPUT_INTERFACE_TEMPLATE_PATH = 'static/js/interfaces/input/{}.js'
BASE_INPUT_INTERFACE_JS_PATH = 'static/js/interfaces/input/{}.js'
class AbstractInput(ABC):
"""
An abstract class for defining the methods that all gradio inputs should have.
When this is subclassed, it is automatically added to the registry
"""
def __init__(self, label):
self.label = label
def get_validation_inputs(self):
"""
An interface can optionally implement a method that returns a list of examples inputs that it should be able to
accept and preprocess for validation purposes.
"""
return []
def get_template_context(self):
"""
:return: a dictionary with context variables for the javascript file associated with the context
"""
return {"label": self.label}
def preprocess(self, inp):
"""
By default, no pre-processing is applied to text.
"""
return inp
def process_example(self, example):
"""
Proprocess example for UI
"""
return example
@classmethod
def get_shortcut_implementations(cls):
"""
Return dictionary of shortcut implementations
"""
return {}
class Textbox(AbstractInput):
"""
Component creates a textbox for user to enter input. Provides a string (or number is `is_numeric` is true) as an argument to the wrapped function.
Input type: str
"""
def __init__(self, lines=1, placeholder=None, default=None, numeric=False, label=None):
'''
Parameters:
lines (int): number of line rows to provide in textarea.
placeholder (str): placeholder hint to provide behind textarea.
default (str): default text to provide in textarea.
numeric (bool): whether the input should be parsed as a number instead of a string.
label (str): component name in interface.
'''
self.lines = lines
self.placeholder = placeholder
self.default = default
self.numeric = numeric
super().__init__(label)
def get_template_context(self):
return {
"lines": self.lines,
"placeholder": self.placeholder,
"default": self.default,
**super().get_template_context()
}
@classmethod
def get_shortcut_implementations(cls):
return {
"text": {},
"textbox": {"lines": 7},
"number": {"numeric": True}
}
def preprocess(self, inp):
"""
Cast type of input
"""
if self.numeric:
return float(inp)
else:
return inp
class Slider(AbstractInput):
"""
Component creates a slider that ranges from `minimum` to `maximum`. Provides a number as an argument to the wrapped function.
Input type: float
"""
def __init__(self, minimum=0, maximum=100, step=None, default=None, label=None):
'''
Parameters:
minimum (float): minimum value for slider.
maximum (float): maximum value for slider.
step (float): increment between slider values.
default (float): default value.
label (str): component name in interface.
'''
self.minimum = minimum
self.maximum = maximum
self.default = minimum if default is None else default
super().__init__(label)
def get_template_context(self):
return {
"minimum": self.minimum,
"maximum": self.maximum,
"default": self.default,
**super().get_template_context()
}
@classmethod
def get_shortcut_implementations(cls):
return {
"slider": {},
}
class Checkbox(AbstractInput):
"""
Component creates a checkbox that can be set to `True` or `False`. Provides a boolean as an argument to the wrapped function.
Input type: bool
"""
def __init__(self, label=None):
'''
Parameters:
label (str): component name in interface.
'''
super().__init__(label)
@classmethod
def get_shortcut_implementations(cls):
return {
"checkbox": {},
}
class CheckboxGroup(AbstractInput):
"""
Component creates a set of checkboxes of which a subset can be selected. Provides a list of strings representing the selected choices as an argument to the wrapped function.
Input type: List[str]
"""
def __init__(self, choices, label=None):
'''
Parameters:
choices (List[str]): list of options to select from.
label (str): component name in interface.
'''
self.choices = choices
super().__init__(label)
def get_template_context(self):
return {
"choices": self.choices,
**super().get_template_context()
}
class Radio(AbstractInput):
"""
Component creates a set of radio buttons of which only one can be selected. Provides string representing selected choice as an argument to the wrapped function.
Input type: str
"""
def __init__(self, choices, label=None):
'''
Parameters:
choices (List[str]): list of options to select from.
label (str): component name in interface.
'''
self.choices = choices
super().__init__(label)
def get_template_context(self):
return {
"choices": self.choices,
**super().get_template_context()
}
class Dropdown(AbstractInput):
"""
Component creates a dropdown of which only one can be selected. Provides string representing selected choice as an argument to the wrapped function.
Input type: str
"""
def __init__(self, choices, label=None):
'''
Parameters:
choices (List[str]): list of options to select from.
label (str): component name in interface.
'''
self.choices = choices
super().__init__(label)
def get_template_context(self):
return {
"choices": self.choices,
**super().get_template_context()
}
class Image(AbstractInput):
"""
Component creates an image upload box with editing capabilities. Provides numpy array of shape `(width, height, 3)` if `image_mode` is "RGB" as an argument to the wrapped function. Provides numpy array of shape `(width, height)` if `image_mode` is "L" as an argument to the wrapped function.
Input type: numpy.array
"""
def __init__(self, shape=None, image_mode='RGB', label=None):
'''
Parameters:
shape (Tuple[int, int]): shape to crop and resize image to; if None, matches input image size.
image_mode (str): "RGB" if color, or "L" if black and white.
label (str): component name in interface.
'''
if shape is None:
self.image_width, self.image_height = None, None
else:
self.image_width = shape[0]
self.image_height = shape[1]
self.image_mode = image_mode
super().__init__(label)
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
@classmethod
def get_shortcut_implementations(cls):
return {
"image": {},
}
def get_template_context(self):
return {
**super().get_template_context()
}
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
im = processing_utils.decode_base64_to_image(inp)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
im = im.convert(self.image_mode)
image_width, image_height = self.image_width, self.image_height
if image_width is None:
image_width = im.size[0]
if image_height is None:
image_height = im.size[1]
im = processing_utils.resize_and_crop(
im, (image_width, image_height))
return np.array(im)
def process_example(self, example):
if os.path.exists(example):
return processing_utils.convert_file_to_base64(example)
else:
return example
class Sketchpad(AbstractInput):
"""
Component creates a sketchpad for black and white illustration. Provides numpy array of shape `(width, height)` as an argument to the wrapped function.
Input type: numpy.array
"""
def __init__(self, shape=(28, 28), invert_colors=True,
flatten=False, label=None):
'''
Parameters:
shape (Tuple[int, int]): shape to crop and resize image to.
invert_colors (bool): whether to represent black as 1 and white as 0 in the numpy array.
flatten (bool): whether to reshape the numpy array to a single dimension.
label (str): component name in interface.
'''
self.image_width = shape[0]
self.image_height = shape[1]
self.invert_colors = invert_colors
self.flatten = flatten
super().__init__(label)
@classmethod
def get_shortcut_implementations(cls):
return {
"sketchpad": {},
}
def preprocess(self, inp):
"""
Default preprocessing method for the SketchPad is to convert the sketch to black and white and resize 28x28
"""
im_transparent = processing_utils.decode_base64_to_image(inp)
# Create a white background for the alpha channel
im = PIL.Image.new("RGBA", im_transparent.size, "WHITE")
im.paste(im_transparent, (0, 0), im_transparent)
im = im.convert('L')
if self.invert_colors:
im = PIL.ImageOps.invert(im)
im = im.resize((self.image_width, self.image_height))
if self.flatten:
array = np.array(im).flatten().reshape(
1, self.image_width * self.image_height)
else:
array = np.array(im).flatten().reshape(
1, self.image_width, self.image_height)
return array
def process_example(self, example):
return processing_utils.convert_file_to_base64(example)
class Webcam(AbstractInput):
"""
Component creates a webcam for captured image input. Provides numpy array of shape `(width, height, 3)` as an argument to the wrapped function.
Input type: numpy.array
"""
def __init__(self, shape=(224, 224), label=None):
'''
Parameters:
shape (Tuple[int, int]): shape to crop and resize image to.
label (str): component name in interface.
'''
self.image_width = shape[0]
self.image_height = shape[1]
self.num_channels = 3
super().__init__(label)
def get_validation_inputs(self):
return validation_data.BASE64_COLOR_IMAGES
@classmethod
def get_shortcut_implementations(cls):
return {
"webcam": {},
}
def preprocess(self, inp):
"""
Default preprocessing method for is to convert the picture to black and white and resize to be 48x48
"""
im = processing_utils.decode_base64_to_image(inp)
im = im.convert('RGB')
im = processing_utils.resize_and_crop(
im, (self.image_width, self.image_height))
return np.array(im)
class Microphone(AbstractInput):
"""
Component creates a microphone element for audio inputs. Provides numpy array of shape `(samples, 2)` as an argument to the wrapped function.
Input type: numpy.array
"""
def __init__(self, preprocessing=None, label=None):
'''
Parameters:
preprocessing (Union[str, Callable]): preprocessing to apply to input
label (str): component name in interface.
'''
super().__init__(label)
if preprocessing is None or preprocessing == "mfcc":
self.preprocessing = preprocessing
else:
raise ValueError(
"unexpected value for preprocessing", preprocessing)
@classmethod
def get_shortcut_implementations(cls):
return {
"microphone": {},
}
def preprocess(self, inp):
"""
By default, no pre-processing is applied to a microphone input file
"""
file_obj = processing_utils.decode_base64_to_file(inp)
if self.preprocessing == "mfcc":
return processing_utils.generate_mfcc_features_from_audio_file(file_obj.name)
_, signal = scipy.io.wavfile.read(file_obj.name)
return signal
# Automatically adds all shortcut implementations in AbstractInput into a dictionary.
shortcuts = {}
for cls in AbstractInput.__subclasses__():
for shortcut, parameters in cls.get_shortcut_implementations().items():
shortcuts[shortcut] = cls(**parameters)

View File

@ -1,420 +0,0 @@
"""
This is the core file in the `gradio` package, and defines the Interface class, including methods for constructing the
interface using the input and output types.
"""
import tempfile
import traceback
import webbrowser
import gradio.inputs
import gradio.outputs
from gradio import networking, strings
from distutils.version import StrictVersion
import pkg_resources
import requests
import random
import time
import inspect
from IPython import get_ipython
import sys
import weakref
import analytics
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
try:
ip_address = requests.get('https://api.ipify.org').text
except requests.ConnectionError:
ip_address = "No internet connection"
class Interface:
"""
Interfaces are created with Gradio using the `gradio.Interface()` function.
"""
instances = weakref.WeakSet()
def __init__(self, fn, inputs, outputs, saliency=None, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
inputs (Union[str, List[Union[str, AbstractInput]]]): a single Gradio input component, or list of Gradio input components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn.
outputs (Union[str, List[Union[str, AbstractOutput]]]): a single Gradio output component, or list of Gradio output components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn.
live (bool): whether the interface should automatically reload on change.
capture_session (bool): if True, captures the default graph and session (needed for Tensorflow 1.x)
title (str): a title for the interface; if provided, appears above the input and output components.
description (str): a description for the interface; if provided, appears above the input and output components.
examples (List[List[Any]]): sample inputs for the function; if provided, appears below the UI components and can be used to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component.
"""
def get_input_instance(iface):
if isinstance(iface, str):
return gradio.inputs.shortcuts[iface.lower()]
elif isinstance(iface, gradio.inputs.AbstractInput):
return iface
else:
raise ValueError("Input interface must be of type `str` or "
"`AbstractInput`")
def get_output_instance(iface):
if isinstance(iface, str):
return gradio.outputs.shortcuts[iface.lower()]
elif isinstance(iface, gradio.outputs.AbstractOutput):
return iface
else:
raise ValueError(
"Output interface must be of type `str` or "
"`AbstractOutput`"
)
if isinstance(inputs, list):
self.input_interfaces = [get_input_instance(i) for i in inputs]
else:
self.input_interfaces = [get_input_instance(inputs)]
if isinstance(outputs, list):
self.output_interfaces = [get_output_instance(i) for i in outputs]
else:
self.output_interfaces = [get_output_instance(outputs)]
if not isinstance(fn, list):
fn = [fn]
self.output_interfaces *= len(fn)
self.predict = fn
self.verbose = verbose
self.status = "OFF"
self.saliency = saliency
self.live = live
self.show_input = show_input
self.show_output = show_output
self.flag_hash = random.getrandbits(32)
self.capture_session = capture_session
self.session = None
self.server_name = server_name
self.title = title
self.description = description
self.thumbnail = thumbnail
self.examples = examples
self.server_port = server_port
self.simple_server = None
self.allow_screenshot = allow_screenshot
Interface.instances.add(self)
data = {'fn': fn,
'inputs': inputs,
'outputs': outputs,
'saliency': saliency,
'live': live,
'capture_session': capture_session,
'ip_address': ip_address
}
if self.capture_session:
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
except (ImportError, AttributeError): # If they are using TF >= 2.0 or don't have TF, just ignore this.
pass
try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
def get_config_file(self):
config = {
"input_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.input_interfaces],
"output_interfaces": [
(iface.__class__.__name__.lower(), iface.get_template_context())
for iface in self.output_interfaces],
"function_count": len(self.predict),
"live": self.live,
"show_input": self.show_input,
"show_output": self.show_output,
"title": self.title,
"description": self.description,
"thumbnail": self.thumbnail,
"allow_screenshot": self.allow_screenshot
}
try:
param_names = inspect.getfullargspec(self.predict[0])[0]
for iface, param in zip(config["input_interfaces"], param_names):
if not iface[1]["label"]:
iface[1]["label"] = param.replace("_", " ")
for i, iface in enumerate(config["output_interfaces"]):
ret_name = "Output " + str(i + 1) if len(config["output_interfaces"]) > 1 else "Output"
if not iface[1]["label"]:
iface[1]["label"] = ret_name
except ValueError:
pass
return config
def process(self, raw_input):
processed_input = [input_interface.preprocess(
raw_input[i]) for i, input_interface in
enumerate(self.input_interfaces)]
predictions = []
durations = []
for predict_fn in self.predict:
start = time.time()
if self.capture_session and not(self.session is None):
graph, sess = self.session
with graph.as_default():
with sess.as_default():
prediction = predict_fn(*processed_input)
else:
try:
prediction = predict_fn(*processed_input)
except ValueError as exception:
if str(exception).endswith("is not an element of this "
"graph."):
raise ValueError("It looks like you might be using "
"tensorflow < 2.0. Please "
"pass capture_session=True in "
"Interface to avoid the 'Tensor is "
"not an element of this graph.' "
"error.")
else:
raise exception
duration = time.time() - start
if len(self.output_interfaces) == len(self.predict):
prediction = [prediction]
durations.append(duration)
predictions.extend(prediction)
processed_output = [output_interface.postprocess(
predictions[i]) for i, output_interface in enumerate(self.output_interfaces)]
return processed_output, durations
def validate(self):
if self.validate_flag:
if self.verbose:
print("Interface already validated")
return
validation_inputs = self.input_interface.get_validation_inputs()
n = len(validation_inputs)
if n == 0:
self.validate_flag = True
if self.verbose:
print(
"No validation samples for this interface... skipping validation."
)
return
for m, msg in enumerate(validation_inputs):
if self.verbose:
print(
"Validating samples: {}/{} [".format(m+1, n)
+ "=" * (m + 1)
+ "." * (n - m - 1)
+ "]",
end="\r",
)
try:
processed_input = self.input_interface.preprocess(msg)
prediction = self.predict(processed_input)
except Exception as e:
data = {'error': e}
try:
requests.post(analytics_url + 'gradio-error-analytics/',
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
if self.verbose:
print("\n----------")
print(
"Validation failed, likely due to incompatible pre-processing and model input. See below:\n"
)
print(traceback.format_exc())
break
try:
_ = self.output_interface.postprocess(prediction)
except Exception as e:
data = {'error': e}
try:
requests.post(analytics_url + 'gradio-error-analytics/',
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
if self.verbose:
print("\n----------")
print(
"Validation failed, likely due to incompatible model output and post-processing."
"See below:\n"
)
print(traceback.format_exc())
break
else: # This means if a break was not explicitly called
self.validate_flag = True
if self.verbose:
print("\n\nValidation passed successfully!")
return
raise RuntimeError("Validation did not pass")
def close(self):
if self.simple_server and not(self.simple_server.fileno() == -1): # checks to see if server is running
print("Closing Gradio server on port {}...".format(self.server_port))
networking.close_server(self.simple_server)
def launch(self, inline=None, inbrowser=None, share=False, validate=True, debug=False):
"""
Parameters
share (bool): whether to create a publicly shareable link from your computer for the interface.
"""
# if validate and not self.validate_flag:
# self.validate()
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name,
server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
self.server_port = server_port
self.status = "RUNNING"
self.simple_server = httpd
is_colab = False
try: # Check if running interactively using ipython.
from_ipynb = get_ipython()
if "google.colab" in str(from_ipynb):
is_colab = True
except NameError:
data = {'error': 'NameError in launch method'}
try:
requests.post(analytics_url + 'gradio-error-analytics/',
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
pass
try:
current_pkg_version = pkg_resources.require("gradio")[0].version
latest_pkg_version = requests.get(url=PKG_VERSION_URL).json()["version"]
if StrictVersion(latest_pkg_version) > StrictVersion(current_pkg_version):
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
if not is_colab:
print(strings.en["RUNNING_LOCALLY"].format(path_to_local_server))
else:
if debug:
print("Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. "
"To turn off, set debug=False in launch().")
else:
print("Colab notebook detected. To show errors in colab notebook, set debug=True in launch()")
if share:
try:
share_url = networking.setup_tunnel(server_port)
print("Running on External URL:", share_url)
except RuntimeError:
data = {'error': 'RuntimeError in launch method'}
try:
requests.post(analytics_url + 'gradio-error-analytics/',
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
share_url = None
if self.verbose:
print(strings.en["NGROK_NO_INTERNET"])
else:
if (
is_colab
): # For a colab notebook, create a public link even if share is False.
share_url = networking.setup_tunnel(server_port)
print("Running on External URL:", share_url)
if self.verbose:
print(strings.en["COLAB_NO_LOCAL"])
else: # If it's not a colab notebook and share=False, print a message telling them about the share option.
if self.verbose:
print(strings.en["PUBLIC_SHARE_TRUE"])
share_url = None
if inline is None:
try: # Check if running interactively using ipython.
get_ipython()
inline = True
if inbrowser is None:
inbrowser = False
except NameError:
inline = False
if inbrowser is None:
inbrowser = True
else:
if inbrowser is None:
inbrowser = False
if inbrowser and not is_colab:
webbrowser.open(
path_to_local_server
) # Open a browser tab with the interface.
if inline:
from IPython.display import IFrame, display
if (
is_colab
): # Embed the remote interface page if on google colab;
# otherwise, embed the local page.
print("Interface loading below...")
while not networking.url_ok(share_url):
time.sleep(1)
display(IFrame(share_url, width=1000, height=500))
else:
display(IFrame(path_to_local_server, width=1000, height=500))
config = self.get_config_file()
config["share_url"] = share_url
processed_examples = []
if self.examples is not None:
for example_set in self.examples:
processed_set = []
for iface, example in zip(self.input_interfaces, example_set):
processed_set.append(iface.process_example(example))
processed_examples.append(processed_set)
config["examples"] = processed_examples
networking.set_config(config, output_directory)
networking.set_meta_tags(output_directory, self.title, self.description, self.thumbnail)
if debug:
while True:
sys.stdout.flush()
time.sleep(0.1)
launch_method = 'browser' if inbrowser else 'inline'
data = {'launch_method': launch_method,
'is_google_colab': is_colab,
'is_sharing_on': share,
'share_url': share_url,
'ip_address': ip_address
}
try:
requests.post(analytics_url + 'gradio-launched-analytics/',
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
return httpd, path_to_local_server, share_url
@classmethod
def get_instances(cls):
return list(Interface.instances) # Returns list of all current instances.
def reset_all():
for io in Interface.get_instances():
io.close()

View File

@ -1,287 +0,0 @@
"""
Defines helper methods useful for setting up ports, launching servers, and handling `ngrok`
"""
import os
import socket
import threading
from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import pkg_resources
from distutils import dir_util
from gradio import inputs, outputs
import json
from gradio.tunneling import create_tunnel
import urllib.request
from shutil import copyfile
import requests
import sys
import analytics
INITIAL_PORT_VALUE = int(os.getenv(
'GRADIO_SERVER_PORT', "7860")) # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = int(os.getenv(
'GRADIO_NUM_PORTS', "100")) # Number of ports to try before giving up and throwing an exception.
LOCALHOST_NAME = os.getenv(
'GRADIO_SERVER_NAME', "127.0.0.1")
GRADIO_API_SERVER = "https://api.gradio.app/v1/tunnel-request"
STATIC_TEMPLATE_LIB = pkg_resources.resource_filename("gradio", "templates/")
STATIC_PATH_LIB = pkg_resources.resource_filename("gradio", "static/")
STATIC_PATH_TEMP = "static/"
TEMPLATE_TEMP = "index.html"
BASE_JS_FILE = "static/js/all_io.js"
CONFIG_FILE = "static/config.json"
ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association"
ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association"
FLAGGING_DIRECTORY = 'static/flagged/'
FLAGGING_FILENAME = 'data.txt'
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
def build_template(temp_dir):
"""
Create HTML file with supporting JS and CSS files in a given directory.
:param temp_dir: string with path to temp directory in which the html file should be built
"""
dir_util.copy_tree(STATIC_TEMPLATE_LIB, temp_dir)
dir_util.copy_tree(STATIC_PATH_LIB, os.path.join(
temp_dir, STATIC_PATH_TEMP))
# Move association file to root of temporary directory.
copyfile(os.path.join(temp_dir, ASSOCIATION_PATH_IN_STATIC),
os.path.join(temp_dir, ASSOCIATION_PATH_IN_ROOT))
def render_template_with_tags(template_path, context):
"""
Combines the given template with a given context dictionary by replacing all of the occurrences of tags (enclosed
in double curly braces) with corresponding values.
:param template_path: a string with the path to the template file
:param context: a dictionary whose string keys are the tags to replace and whose string values are the replacements.
"""
with open(template_path) as fin:
old_lines = fin.readlines()
new_lines = render_string_or_list_with_tags(old_lines, context)
with open(template_path, "w") as fout:
for line in new_lines:
fout.write(line)
def render_string_or_list_with_tags(old_lines, context):
# Handle string case
if isinstance(old_lines, str):
for key, value in context.items():
old_lines = old_lines.replace(r"{{" + key + r"}}", str(value))
return old_lines
# Handle list case
new_lines = []
for line in old_lines:
for key, value in context.items():
line = line.replace(r"{{" + key + r"}}", str(value))
new_lines.append(line)
return new_lines
def set_meta_tags(temp_dir, title, description, thumbnail):
title = "Gradio" if title is None else title
description = "Easy-to-use UI for your machine learning model" if description is None else description
thumbnail = "https://gradio.app/static/img/logo_only.png" if thumbnail is None else thumbnail
index_file = os.path.join(temp_dir, TEMPLATE_TEMP)
render_template_with_tags(index_file, {
"title": title,
"description": description,
"thumbnail": thumbnail
})
def set_config(config, temp_dir):
config_file = os.path.join(temp_dir, CONFIG_FILE)
with open(config_file, "w") as output:
json.dump(config, output)
def get_first_available_port(initial, final):
"""
Gets the first open port in a specified range of port numbers
:param initial: the initial value in the range of port numbers
:param final: final (exclusive) value in the range of port numbers, should be greater than `initial`
:return:
"""
for port in range(initial, final):
try:
s = socket.socket() # create a socket object
s.bind((LOCALHOST_NAME, port)) # Bind to the port
s.close()
return port
except OSError:
pass
raise OSError(
"All ports from {} to {} are in use. Please close a port.".format(
initial, final
)
)
def send_prediction_analytics(interface):
data = {'title': interface.title,
'description': interface.description,
'thumbnail': interface.thumbnail,
'input_interface': interface.input_interfaces,
'output_interface': interface.output_interfaces,
}
print(data)
try:
requests.post(
analytics_url + 'gradio-prediction-analytics/',
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
def serve_files_in_background(interface, port, directory_to_serve=None, server_name=LOCALHOST_NAME):
class HTTPHandler(SimpleHTTPRequestHandler):
"""This handler uses server.base_path instead of always using os.getcwd()"""
def _set_headers(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
def translate_path(self, path):
path = SimpleHTTPRequestHandler.translate_path(self, path)
relpath = os.path.relpath(path, os.getcwd())
fullpath = os.path.join(self.server.base_path, relpath)
return fullpath
def log_message(self, format, *args):
return
def do_POST(self):
# Read body of the request.
if self.path == "/api/predict/":
# Make the prediction.
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
raw_input = msg["data"]
prediction, durations = interface.process(raw_input)
output = {"data": prediction, "durations": durations}
if interface.saliency is not None:
saliency = interface.saliency(raw_input, prediction)
output['saliency'] = saliency.tolist()
# if interface.always_flag:
# msg = json.loads(data_string)
# flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.hash))
# os.makedirs(flag_dir, exist_ok=True)
# output_flag = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg['data']),
# 'output': interface.output_interface.rebuild_flagged(flag_dir, processed_output),
# }
# with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
# f.write(json.dumps(output_flag))
# f.write("\n")
self.wfile.write(json.dumps(output).encode())
analytics_thread = threading.Thread(
target=send_prediction_analytics, args=[interface])
analytics_thread.start()
elif self.path == "/api/flag/":
self._set_headers()
data_string = self.rfile.read(
int(self.headers["Content-Length"]))
msg = json.loads(data_string)
flag_dir = os.path.join(FLAGGING_DIRECTORY,
str(interface.flag_hash))
os.makedirs(flag_dir, exist_ok=True)
output = {'inputs': [interface.input_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['input_data']) for i
in range(len(interface.input_interfaces))],
'outputs': [interface.output_interfaces[
i].rebuild_flagged(
flag_dir, msg['data']['output_data']) for i
in range(len(interface.output_interfaces))],
'message': msg['data']['message']}
with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f:
f.write(json.dumps(output))
f.write("\n")
else:
self.send_error(404, 'Path not found: {}'.format(self.path))
class HTTPServer(BaseHTTPServer):
"""The main server, you pass in base_path which is the path you want to serve requests from"""
def __init__(self, base_path, server_address, RequestHandlerClass=HTTPHandler):
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
httpd = HTTPServer(directory_to_serve, (server_name, port))
# Now loop forever
def serve_forever():
try:
while True:
sys.stdout.flush()
httpd.serve_forever()
except (KeyboardInterrupt, OSError):
httpd.shutdown()
httpd.server_close()
thread = threading.Thread(target=serve_forever, daemon=False)
thread.start()
return httpd
def start_simple_server(interface, directory_to_serve=None, server_name=None, server_port=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd
def close_server(server):
server.server_close()
def url_request(url):
try:
req = urllib.request.Request(
url=url, headers={"content-type": "application/json"}
)
res = urllib.request.urlopen(req, timeout=10)
return res
except Exception as e:
raise RuntimeError(str(e))
def setup_tunnel(local_server_port):
response = url_request(GRADIO_API_SERVER)
if response and response.code == 200:
try:
payload = json.loads(response.read().decode("utf-8"))[0]
return create_tunnel(payload, LOCALHOST_NAME, local_server_port)
except Exception as e:
raise RuntimeError(str(e))
def url_ok(url):
try:
r = requests.head(url)
return r.status_code == 200
except ConnectionError:
return False

View File

@ -1,25 +0,0 @@
try:
from setuptools import setup
except ImportError:
from distutils.core import setup
setup(
name='gradio',
version='1.0.2',
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid',
author_email='a12d@stanford.edu',
url='https://github.com/gradio-app/gradio-UI',
packages=['gradio'],
keywords=['machine learning', 'visualization', 'reproducibility'],
install_requires=[
'numpy',
'requests',
'paramiko',
'scipy',
'IPython',
'scikit-image',
'analytics-python',
],
)

View File

@ -21,7 +21,6 @@ import weakref
import analytics
import os
PKG_VERSION_URL = "https://gradio.app/api/pkg-version"
analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy"
analytics_url = 'https://api.gradio.app/'
@ -48,7 +47,7 @@ class Interface:
def __init__(self, fn, inputs, outputs, verbose=False, examples=None,
live=False, show_input=True, show_output=True,
capture_session=False, title=None, description=None,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
thumbnail=None, server_port=None, server_name=networking.LOCALHOST_NAME,
allow_screenshot=True, allow_flagging=True,
flagging_dir="flagged"):
"""
@ -69,6 +68,7 @@ class Interface:
allow_flagging (bool): if False, users will not see a button to flag an input and output.
flagging_dir (str): what to name the dir where flagged data is stored.
"""
def get_input_instance(iface):
if isinstance(iface, str):
shortcut = InputComponent.get_all_shortcut_implementations()[iface]
@ -90,6 +90,7 @@ class Interface:
"Output interface must be of type `str` or "
"`OutputComponent`"
)
if isinstance(inputs, list):
self.input_interfaces = [get_input_instance(i) for i in inputs]
else:
@ -135,7 +136,7 @@ class Interface:
try:
import tensorflow as tf
self.session = tf.get_default_graph(), \
tf.keras.backend.get_session()
tf.keras.backend.get_session()
except (ImportError, AttributeError):
# If they are using TF >= 2.0 or don't have TF,
# just ignore this.
@ -151,7 +152,7 @@ class Interface:
"_{}".format(index)):
index += 1
self.flagging_dir = self.flagging_dir + "/" + dir_name + \
"_{}".format(index)
"_{}".format(index)
try:
requests.post(analytics_url + 'gradio-initiated-analytics/',
@ -188,8 +189,8 @@ class Interface:
iface[1]["label"] = ret_name
except ValueError:
pass
return config
return config
def process(self, raw_input):
"""
@ -208,7 +209,7 @@ class Interface:
durations = []
for predict_fn in self.predict:
start = time.time()
if self.capture_session and not(self.session is None):
if self.capture_session and not (self.session is None):
graph, sess = self.session
with graph.as_default():
with sess.as_default():
@ -238,10 +239,19 @@ class Interface:
return processed_output, durations
def close(self):
if self.simple_server and not(self.simple_server.fileno() == -1): # checks to see if server is running
if self.simple_server and not (self.simple_server.fileno() == -1): # checks to see if server is running
print("Closing Gradio server on port {}...".format(self.server_port))
networking.close_server(self.simple_server)
def run_until_interrupted(self, thread, path_to_local_server):
try:
while 1:
pass
except (KeyboardInterrupt, OSError):
print("Keyboard interruption in main thread... closing server.")
thread.keep_running = False
networking.url_ok(path_to_local_server)
def launch(self, inline=None, inbrowser=None, share=False, debug=False):
"""
Parameters
@ -258,11 +268,10 @@ class Interface:
path_to_local_server (str): Locally accessible link
share_url (str): Publicly accessible link (if share=True)
"""
output_directory = tempfile.mkdtemp()
# Set up a port to serve the directory containing the static files with interface.
server_port, httpd = networking.start_simple_server(self, output_directory, self.server_name,
server_port=self.server_port)
server_port, httpd, thread = networking.start_simple_server(
self, output_directory, self.server_name, server_port=self.server_port)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
networking.build_template(output_directory)
@ -277,7 +286,7 @@ class Interface:
print("IMPORTANT: You are using gradio version {}, "
"however version {} "
"is available, please upgrade.".format(
current_pkg_version, latest_pkg_version))
current_pkg_version, latest_pkg_version))
print('--------')
except: # TODO(abidlabs): don't catch all exceptions
pass
@ -370,6 +379,11 @@ class Interface:
data=data)
except requests.ConnectionError:
pass # do not push analytics if no network
is_in_interactive_mode = bool(getattr(sys, 'ps1', sys.flags.interactive))
if not is_in_interactive_mode:
self.run_until_interrupted(thread, path_to_local_server)
return httpd, path_to_local_server, share_url

View File

@ -9,6 +9,7 @@ from http.server import HTTPServer as BaseHTTPServer, SimpleHTTPRequestHandler
import pkg_resources
from distutils import dir_util
from gradio import inputs, outputs
import time
import json
from gradio.tunneling import create_tunnel
import urllib.request
@ -221,22 +222,21 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n
self.base_path = base_path
BaseHTTPServer.__init__(self, server_address, RequestHandlerClass)
class QuittableHTTPThread(threading.Thread):
def __init__(self, httpd):
super().__init__(daemon=False)
self.httpd = httpd
self.keep_running =True
def run(self):
while self.keep_running:
self.httpd.handle_request()
httpd = HTTPServer(directory_to_serve, (server_name, port))
# Now loop forever
def serve_forever():
try:
while True:
sys.stdout.flush()
httpd.serve_forever()
except (KeyboardInterrupt, OSError):
httpd.shutdown()
httpd.server_close()
thread = threading.Thread(target=serve_forever, daemon=False)
thread = QuittableHTTPThread(httpd=httpd)
thread.start()
return httpd
return httpd, thread
def start_simple_server(interface, directory_to_serve=None, server_name=None, server_port=None):
@ -245,8 +245,8 @@ def start_simple_server(interface, directory_to_serve=None, server_name=None, se
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
httpd = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd
httpd, thread = serve_files_in_background(interface, port, directory_to_serve, server_name)
return port, httpd, thread
def close_server(server):