mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Merge branch 'master' of https://github.com/abidlabs/gradio
This commit is contained in:
commit
ff7587671c
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,3 +9,4 @@ models/*
|
||||
.models/*
|
||||
gradio_files/*
|
||||
ngrok*
|
||||
examples/ngrok*
|
||||
|
@ -2,9 +2,18 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The autoreload extension is already loaded. To reload it, use:\n",
|
||||
" %reload_ext autoreload\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2\n",
|
||||
@ -16,26 +25,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# model = tf.keras.applications.inception_v3.InceptionV3()"
|
||||
"model = tf.keras.applications.inception_v3.InceptionV3()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299)\n",
|
||||
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=8, num_top_classes=8)\n",
|
||||
"inp = gradio.inputs.ImageUpload(image_width=299, image_height=299, num_channels=3)\n",
|
||||
"out = gradio.outputs.Label(label_names='imagenet1000', max_label_length=12, num_top_classes=5)\n",
|
||||
"\n",
|
||||
"iface = gradio.Interface(inputs=inp, \n",
|
||||
" outputs=out,\n",
|
||||
" model=lambda x: np.array(1), \n",
|
||||
" model_type='function')\n",
|
||||
" model=model, \n",
|
||||
" model_type='keras')\n",
|
||||
"\n",
|
||||
"# iface = gradio.Interface(inputs=inp, \n",
|
||||
"# outputs=out,\n",
|
||||
@ -45,7 +54,26 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Validating samples: 6/6 [======]\n",
|
||||
"\n",
|
||||
"Validation passed successfully!\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"iface.validate()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
@ -55,7 +83,7 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
|
||||
"Model is running locally at: http://localhost:7860/interface.html\n",
|
||||
"Model is running locally at: http://localhost:7866/interface.html\n",
|
||||
"To create a public link, set `share=True` in the argument to `launch()`\n"
|
||||
]
|
||||
},
|
||||
@ -66,14 +94,14 @@
|
||||
" <iframe\n",
|
||||
" width=\"1000\"\n",
|
||||
" height=\"500\"\n",
|
||||
" src=\"http://localhost:7860/interface.html\"\n",
|
||||
" src=\"http://localhost:7866/interface.html\"\n",
|
||||
" frameborder=\"0\"\n",
|
||||
" allowfullscreen\n",
|
||||
" ></iframe>\n",
|
||||
" "
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.lib.display.IFrame at 0x26c5900fba8>"
|
||||
"<IPython.lib.display.IFrame at 0x1b08a51f748>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
|
@ -6,11 +6,12 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import base64
|
||||
from gradio import preprocessing_utils
|
||||
from gradio import preprocessing_utils, validation_data
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
|
||||
class AbstractInput(ABC):
|
||||
"""
|
||||
An abstract class for defining the methods that all gradio inputs should have.
|
||||
@ -27,6 +28,9 @@ class AbstractInput(ABC):
|
||||
self.preprocess = preprocessing_fn
|
||||
super().__init__()
|
||||
|
||||
def get_validation_inputs(self):
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def get_template_path(self):
|
||||
"""
|
||||
@ -74,6 +78,9 @@ class Webcam(AbstractInput):
|
||||
self.num_channels = num_channels
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
return validation_data.BASE64_COLOR_IMAGES
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/webcam.html'
|
||||
|
||||
@ -90,6 +97,8 @@ class Webcam(AbstractInput):
|
||||
|
||||
|
||||
class Textbox(AbstractInput):
|
||||
def get_validation_inputs(self):
|
||||
return validation_data.ENGLISH_TEXTS
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/textbox.html'
|
||||
@ -112,6 +121,9 @@ class ImageUpload(AbstractInput):
|
||||
self.shift = shift
|
||||
super().__init__(preprocessing_fn=preprocessing_fn)
|
||||
|
||||
def get_validation_inputs(self):
|
||||
return validation_data.BASE64_COLOR_IMAGES
|
||||
|
||||
def get_template_path(self):
|
||||
return 'templates/input/image_upload.html'
|
||||
|
||||
|
@ -12,6 +12,7 @@ import gradio.outputs
|
||||
from gradio import networking
|
||||
import tempfile
|
||||
import threading
|
||||
import traceback
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
@ -63,6 +64,8 @@ class Interface:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
self.model_type = model_type
|
||||
self.verbose = verbose
|
||||
self.launch_flag = False
|
||||
self.validate_flag = False
|
||||
|
||||
@staticmethod
|
||||
def _infer_model_type(model):
|
||||
@ -133,11 +136,55 @@ class Interface:
|
||||
else:
|
||||
ValueError('model_type must be one of: {}'.format(self.VALID_MODEL_TYPES))
|
||||
|
||||
def launch(self, inline=None, browser=None, share=False):
|
||||
def validate(self):
|
||||
if self.validate_flag:
|
||||
if self.verbose:
|
||||
print("Interface already validated")
|
||||
return
|
||||
validation_inputs = self.input_interface.get_validation_inputs()
|
||||
n = len(validation_inputs)
|
||||
if n == 0:
|
||||
self.validate_flag = True
|
||||
if self.verbose:
|
||||
print("No validation samples for this interface... skipping validation.")
|
||||
return
|
||||
for m, msg in enumerate(validation_inputs):
|
||||
if self.verbose:
|
||||
print(f"Validating samples: {m+1}/{n} [" + "="*(m+1) + "."*(n-m-1) + "]", end='\r')
|
||||
try:
|
||||
processed_input = self.input_interface.preprocess(msg)
|
||||
prediction = self.predict(processed_input)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print("\n----------")
|
||||
print("Validation failed, likely due to incompatible pre-processing and model input. See below:\n")
|
||||
print(traceback.format_exc())
|
||||
break
|
||||
try:
|
||||
_ = self.output_interface.postprocess(prediction)
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print("\n----------")
|
||||
print("Validation failed, likely due to incompatible model output and post-processing."
|
||||
"See below:\n")
|
||||
print(traceback.format_exc())
|
||||
break
|
||||
else: # This means if a break was not explicitly called
|
||||
self.validate_flag = True
|
||||
if self.verbose:
|
||||
print("\n\nValidation passed successfully!")
|
||||
return
|
||||
raise RuntimeError("Validation did not pass")
|
||||
|
||||
def launch(self, inline=None, browser=None, share=False, validate=True):
|
||||
"""
|
||||
Standard method shared by interfaces that creates the interface and sets up a websocket to communicate with it.
|
||||
:param share: boolean. If True, then a share link is generated using ngrok is displayed to the user.
|
||||
"""
|
||||
if validate and not self.validate_flag:
|
||||
self.validate()
|
||||
|
||||
self.launch_flag = True
|
||||
output_directory = tempfile.mkdtemp()
|
||||
|
||||
# Set up a port to serve the directory containing the static files with interface.
|
||||
|
16
gradio/validation_data.py
Normal file
16
gradio/validation_data.py
Normal file
File diff suppressed because one or more lines are too long
Loading…
x
Reference in New Issue
Block a user