
merge
Ali Abdalla 2019-03-06 00:35:14 -08:00
commit a812d5f555
18 changed files with 342 additions and 3622 deletions

View File

@@ -4,39 +4,43 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "invalid syntax (networking.py, line 200)",
"output_type": "error",
"traceback": [
"Traceback \u001b[1;36m(most recent call last)\u001b[0m:\n",
" File \u001b[0;32m\"C:\\Users\\islam\\Anaconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\"\u001b[0m, line \u001b[0;32m2961\u001b[0m, in \u001b[0;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\n",
" File \u001b[0;32m\"<ipython-input-1-ab911fe0150a>\"\u001b[0m, line \u001b[0;32m4\u001b[0m, in \u001b[0;35m<module>\u001b[0m\n import gradio\n",
" File \u001b[0;32m\"C:\\Users\\islam\\Repos\\gradio\\gradio\\__init__.py\"\u001b[0m, line \u001b[0;32m1\u001b[0m, in \u001b[0;35m<module>\u001b[0m\n from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.\n",
"\u001b[1;36m File \u001b[1;32m\"C:\\Users\\islam\\Repos\\gradio\\gradio\\interface.py\"\u001b[1;36m, line \u001b[1;32m12\u001b[1;36m, in \u001b[1;35m<module>\u001b[1;36m\u001b[0m\n\u001b[1;33m from gradio import networking\u001b[0m\n",
"\u001b[1;36m File \u001b[1;32m\"C:\\Users\\islam\\Repos\\gradio\\gradio\\networking.py\"\u001b[1;36m, line \u001b[1;32m200\u001b[0m\n\u001b[1;33m except AccessDenied, NoSuchProcess:\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid syntax\n"
]
}
],
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import tensorflow as tf\n",
"import gradio"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"(60000, 28, 28)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def test(x):\n",
" return x.upper()\n",
"\n",
"def test2(x):\n",
" return x.lower()"
"x_train.shape"
]
},
{
@@ -47,7 +51,16 @@
},
"outputs": [],
"source": [
"iface = gradio.Interface(inputs=\"textbox\", outputs=\"textbox\", model=test, model_type='function')"
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
" tf.keras.layers.Dropout(0.2),\n",
" tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n",
"])\n",
"\n",
"model.compile(optimizer='adam',\n",
" loss='sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])"
]
},
{
@@ -59,36 +72,163 @@
"name": "stdout",
"output_type": "stream",
"text": [
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model available locally at: http://localhost:7860/interface.html\n"
"Epoch 1/1\n",
"3/3 [==============================] - 8s 3s/step - loss: 2.0548 - acc: 0.3412\n"
]
},
{
"ename": "NoSuchProcess",
"evalue": "psutil.NoSuchProcess process no longer exists (pid=2744)",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNoSuchProcess\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-5-8be24579e72c>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0miface\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlaunch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mshare\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32m~\\Repos\\gradio\\gradio\\interface.py\u001b[0m in \u001b[0;36mlaunch\u001b[1;34m(self, share)\u001b[0m\n\u001b[0;32m 136\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 137\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mshare\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 138\u001b[1;33m \u001b[0msite_ngrok_url\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnetworking\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup_ngrok\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserver_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mwebsocket_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_directory\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 139\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 140\u001b[0m print(\"Model available publicly for 8 hours at: {}\".format(\n",
"\u001b[1;32m~\\Repos\\gradio\\gradio\\networking.py\u001b[0m in \u001b[0;36msetup_ngrok\u001b[1;34m(server_port, websocket_port, output_directory)\u001b[0m\n\u001b[0;32m 185\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 186\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0msetup_ngrok\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserver_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mwebsocket_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_directory\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 187\u001b[1;33m \u001b[0mkill_processes\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m4040\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4041\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m#TODO(abidlabs): better way to do this\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 188\u001b[0m \u001b[0msite_ngrok_url\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_ngrok_tunnel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserver_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mNGROK_TUNNELS_API_URL\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 189\u001b[0m \u001b[0msocket_ngrok_url\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_ngrok_tunnel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwebsocket_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mNGROK_TUNNELS_API_URL2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Repos\\gradio\\gradio\\networking.py\u001b[0m in \u001b[0;36mkill_processes\u001b[1;34m(process_ids)\u001b[0m\n\u001b[0;32m 197\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mconns\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mproc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconnections\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkind\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'inet'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 198\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mconns\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mladdr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mport\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mprocess_ids\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 199\u001b[1;33m \u001b[0mproc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msend_signal\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mSIGTERM\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# or SIGKILL\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 200\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mAccessDenied\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 201\u001b[0m \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\psutil\\__init__.py\u001b[0m in \u001b[0;36mwrapper\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 282\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 283\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_running\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 284\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mNoSuchProcess\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpid\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 285\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 286\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mNoSuchProcess\u001b[0m: psutil.NoSuchProcess process no longer exists (pid=2744)"
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x1fca491be48>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, epochs=1, steps_per_epoch=3)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"inp = gradio.inputs.ImageUpload(image_width=28, image_height=28, num_channels=None, image_mode='L')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"iface = gradio.Interface(inputs=inp, outputs=\"label\", model=model, model_type='keras')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n",
"Model available locally at: http://localhost:7860/interface.html\n",
"To create a public link, set `share=True` in the argument to `launch()`\n"
]
},
{
"data": {
"text/plain": [
"('http://localhost:7860/interface.html', None)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"127.0.0.1 - - [05/Mar/2019 23:13:12] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:13:12] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'label': 2, 'confidences': [{'label': 2, 'confidence': 1.0}, {'label': 0, 'confidence': 0.0}, {'label': 0, 'confidence': 0.0}]}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /static/css/style.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:47] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:14:48] code 404, message File not found\n",
"127.0.0.1 - - [05/Mar/2019 23:14:48] \"GET /favicon.ico HTTP/1.1\" 404 -\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'label': 2, 'confidences': [{'label': 2, 'confidence': 1.0}, {'label': 0, 'confidence': 0.0}, {'label': 0, 'confidence': 0.0}]}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"127.0.0.1 - - [05/Mar/2019 23:15:29] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:15:29] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:15:32] \"GET /interface.html HTTP/1.1\" 200 -\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'label': 2, 'confidences': [{'label': 2, 'confidence': 1.0}, {'label': 0, 'confidence': 0.0}, {'label': 0, 'confidence': 0.0}]}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"127.0.0.1 - - [05/Mar/2019 23:18:47] \"GET /interface.html HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:18:47] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:18:47] \"GET /static/js/utils.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:18:47] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:18:47] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [05/Mar/2019 23:18:48] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'label': 2, 'confidences': [{'label': 2, 'confidence': 1.0}, {'label': 0, 'confidence': 0.0}, {'label': 0, 'confidence': 0.0}]}\n"
]
}
],
"source": [
"iface.launch(share=True)"
"iface.launch(share=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.6 (tensorflow)",
"language": "python",
"name": "python3"
"name": "tensorflow"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,

View File

@@ -1,5 +1,6 @@
from argparse import ArgumentParser
import gradio
import numpy as np
parser = ArgumentParser(description='Arguments for Building Interface')
parser.add_argument('-i', '--inputs', type=str, help="name of input interface")
@@ -12,7 +13,7 @@ args = parser.parse_args()
def launch_interface(args):
io = gradio.Interface(inputs=args.inputs, outputs=args.outputs, model=lambda x:x, model_type='function')
io = gradio.Interface(inputs=args.inputs, outputs=args.outputs, model=lambda x:np.array(1), model_type='function')
io.launch(share=args.share)
# input_interface = gradio.inputs.registry[args.inputs.lower()]()
# output_interface = gradio.outputs.registry[args.outputs.lower()]()
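
The demo script's identity lambda is swapped for one returning np.array(1), presumably because the reworked Label.postprocess (shown further down) only accepts numpy arrays or strings. A hedged sketch of a 'function' model that works with the new Label output follows; fake_scores is illustrative and not part of the repository:

import numpy as np
import gradio

def fake_scores(text):
    # Illustrative only: pretend to score the input text across 5 classes.
    scores = np.zeros(5)
    scores[len(text) % 5] = 1.0
    return scores

io = gradio.Interface(inputs="textbox", outputs="label",
                      model=fake_scores, model_type='function')
# io.launch(share=False)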

View File

@@ -44,6 +44,10 @@ class AbstractInput(ABC):
class Sketchpad(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=28, image_height=28):
self.image_width = image_width
self.image_height = image_height
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/sketchpad_input.html'
@@ -55,12 +59,17 @@ class Sketchpad(AbstractInput):
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (28, 28))
array = np.array(im).flatten().reshape(1, 28, 28, 1)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1)
return array
class Webcam(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/webcam_input.html'
@@ -71,9 +80,9 @@ class Webcam(AbstractInput):
"""
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (48, 48))
array = np.array(im).flatten().reshape(1, 48, 48, 1)
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB')
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array
@@ -90,6 +99,12 @@ class Textbox(AbstractInput):
class ImageUpload(AbstractInput):
def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3, image_mode='RGB'):
self.image_width = image_width
self.image_height = image_height
self.num_channels = num_channels
self.image_mode = image_mode
super().__init__(preprocessing_fn=preprocessing_fn)
def get_template_path(self):
return 'templates/image_upload_input.html'
@@ -100,9 +115,12 @@ class ImageUpload(AbstractInput):
"""
content = inp.split(';')[1]
image_encoded = content.split(',')[1]
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L')
im = preprocessing_utils.resize_and_crop(im, (48, 48))
array = np.array(im).flatten().reshape(1, 48, 48, 1)
im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert(self.image_mode)
im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height))
if self.num_channels is None:
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height)
else:
array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels)
return array
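
With this change the Sketchpad, Webcam, and ImageUpload preprocessors respect the image_width, image_height, num_channels, and image_mode arguments instead of hard-coded 48x48 grayscale; passing num_channels=None drops the channel axis, which suits grayscale Keras models like the MNIST example. A rough sketch of the resulting shapes, assuming the defaults shown in the diff; BASE64_PNG is a stand-in name for a real data-URL upload:

import gradio

inp_rgb = gradio.inputs.ImageUpload()                # defaults: 224x224, 3 channels, RGB
inp_gray = gradio.inputs.ImageUpload(image_width=28, image_height=28,
                                     num_channels=None, image_mode='L')

# inp_rgb.preprocess(BASE64_PNG).shape   -> (1, 224, 224, 3)
# inp_gray.preprocess(BASE64_PNG).shape  -> (1, 28, 28)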

View File

@@ -32,16 +32,26 @@ class Interface:
def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None,
verbose=True):
"""
:param inputs: a string representing the input interface.
:param outputs: a string representing the output interface.
:param inputs: a string or `AbstractInput` representing the input interface.
:param outputs: a string or `AbstractOutput` representing the output interface.
:param model_obj: the model object, such as a sklearn classifier or keras model.
:param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not
provided.
:param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface.
:param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface.
"""
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
if isinstance(inputs, str):
self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns)
elif isinstance(inputs, gradio.inputs.AbstractInput):
self.input_interface = inputs
else:
raise ValueError('Input interface must be of type `str` or `AbstractInput`')
if isinstance(outputs, str):
self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns)
elif isinstance(outputs, gradio.outputs.AbstractOutput):
self.output_interface = outputs
else:
raise ValueError('Output interface must be of type `str` or `AbstractOutput`')
self.model_obj = model
if model_type is None:
model_type = self._infer_model_type(model)
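
Interface now accepts either a registry string or an already-constructed input/output object. A hedged sketch of both call styles, with a placeholder 'function' model:

import numpy as np
import gradio

def score(x):
    # Placeholder model: two fake class scores, just to satisfy the Label output.
    return np.array([0.1, 0.9])

# 1) Registry strings, case-insensitive, as before:
io1 = gradio.Interface(inputs="sketchpad", outputs="label",
                       model=score, model_type='function')

# 2) Pre-built input/output objects, new in this commit:
inp = gradio.inputs.ImageUpload(image_width=28, image_height=28,
                                num_channels=None, image_mode='L')
out = gradio.outputs.Label(num_top_classes=2, show_confidences=True)
io2 = gradio.Interface(inputs=inp, outputs=out, model=score, model_type='function')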

View File

@@ -18,7 +18,7 @@ from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
import pkg_resources
from bs4 import BeautifulSoup
import shutil
from distutils import dir_util
INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc.
TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception.
@@ -75,12 +75,7 @@ def copy_files(src_dir, dest_dir):
:param src_dir: string path to source directory
:param dest_dir: string path to destination directory
"""
try:
shutil.copytree(src_dir, dest_dir)
except OSError as exc: # python >2.5
if exc.errno == errno.ENOTDIR:
shutil.copy(src_dir, dest_dir)
else: raise
dir_util.copy_tree(src_dir, dest_dir)
#TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page)
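
The shutil.copytree try/except is replaced with distutils.dir_util.copy_tree, which merges into an existing destination directory instead of failing, so the output directory can be reused across launches. A small standalone illustration of the difference (not gradio code):

import os
import tempfile
from distutils import dir_util  # deprecated since Python 3.10, current at the time of this commit

src = tempfile.mkdtemp()
dst = tempfile.mkdtemp()  # destination already exists
with open(os.path.join(src, "a.txt"), "w") as f:
    f.write("hello")

# shutil.copytree(src, dst)     # would raise FileExistsError because dst already exists
dir_util.copy_tree(src, dst)    # merges src into the existing dst instead
print(sorted(os.listdir(dst)))  # ['a.txt']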

View File

@@ -6,7 +6,7 @@ automatically added to a registry, which allows them to be easily referenced in
from abc import ABC, abstractmethod
import numpy as np
import json
class AbstractOutput(ABC):
"""
@@ -38,6 +38,14 @@ class AbstractOutput(ABC):
class Label(AbstractOutput):
LABEL_KEY = 'label'
CONFIDENCES_KEY = 'confidences'
CONFIDENCE_KEY = 'confidence'
def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True):
self.num_top_classes = num_top_classes
self.show_confidences = show_confidences
super().__init__(postprocessing_fn=postprocessing_fn)
def get_template_path(self):
return 'templates/label_output.html'
@@ -45,16 +53,28 @@ class Label(AbstractOutput):
def postprocess(self, prediction):
"""
"""
response = dict()
# TODO(abidlabs): check if list, if so convert to numpy array
if isinstance(prediction, np.ndarray):
prediction = prediction.squeeze()
if prediction.size == 1:
return prediction
else:
return prediction.argmax()
if prediction.size == 1: # if it's single value
response[Label.LABEL_KEY] = np.asscalar(prediction)
elif len(prediction.shape) == 1: # if a 1D
response[Label.LABEL_KEY] = int(prediction.argmax())
if self.show_confidences:
response[Label.CONFIDENCES_KEY] = []
for i in range(self.num_top_classes):
response[Label.CONFIDENCES_KEY].append({
Label.LABEL_KEY: int(prediction.argmax()),
Label.CONFIDENCE_KEY: float(prediction.max()),
})
prediction[prediction.argmax()] = 0
elif isinstance(prediction, str):
return prediction
response[Label.LABEL_KEY] = prediction
else:
raise ValueError("Unable to post-process model prediction.")
print(response)
return json.dumps(response)
class Textbox(AbstractOutput):
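
The reworked Label.postprocess builds a response with a 'label' key and, when show_confidences is set, the top-k 'confidences' that class-output.js renders on the page. As shown in this diff it returns the response as a JSON string (the updated tests compare against plain dicts, so the exact return type may have shifted again later); the sketch below follows the outputs.py version above, with values mirroring the unit test:

import json
import numpy as np
from gradio import outputs

out = outputs.Label(num_top_classes=3)
result = json.loads(out.postprocess(np.array([0.1, 0.2, 0, 0.7, 0])))
# result == {'label': 3, 'confidences': [
#     {'label': 3, 'confidence': 0.7},
#     {'label': 1, 'confidence': 0.2},
#     {'label': 0, 'confidence': 0.1}]}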

View File

@@ -103,6 +103,7 @@
}
.confidence {
padding: 3px;
display: flex;
}
.level, .label {
display: inline-block;
@@ -112,12 +113,13 @@
text-align: right;
}
.confidence_intervals .level {
font-size: 12px;
font-size: 14px;
margin-left: 8px;
margin-right: 8px;
background-color: #AAA;
padding: 2px 4px;
text-align: right;
font-family: monospace;
color: white;
font-weight: bold;
}

View File

@@ -25,4 +25,4 @@ try {
const sleep = (milliseconds) => {
return new Promise(resolve => setTimeout(resolve, milliseconds))
}
}

View File

@@ -24,12 +24,19 @@ try {
};
ws.onmessage = function (event) {
console.log(event.data);
sleep(300).then(() => {
if (event.data.length == 1) {
$(".output_class").css({ 'font-size':'300px'});
// $(".output_class").text(event.data);
var data = JSON.parse(event.data)
$(".output_class").text(data["label"])
$(".confidence_intervals").empty()
if ("confidences" in data) {
data["confidences"].forEach(function (c) {
var confidence = c["confidence"]
$(".confidence_intervals").append(`<div class="confidence"><div class=
"label">${c["label"]}</div><div class="level" style="flex-grow:
${confidence}">${Math.round(confidence * 100)}%</div></div>`)
})
}
$(".output_class").text(event.data);
})
}
@@ -39,5 +46,5 @@ try {
$('body').on('click', '.clear', function(e) {
$(".output_class").text("")
$(".confidence_intervals").empty()
})

File diff suppressed because it is too large

View File

@@ -36,6 +36,7 @@ $(".hidden_upload").on("change", function() {
$('body').on('click', '.submit', function(e) {
var src = $('.input_image img').attr('src');
src = resizeImage(src)
ws.send(src, function(e) {
notifyError(e)
})

gradio/static/js/utils.js (new file, 27 additions)
View File

@@ -0,0 +1,27 @@
function resizeImage(base64Str) {
var img = new Image();
img.src = base64Str;
var canvas = document.createElement('canvas');
var MAX_WIDTH = 360;
var MAX_HEIGHT = 360;
var width = img.width;
var height = img.height;
if (width > height) {
if (width > MAX_WIDTH) {
height *= MAX_WIDTH / width;
width = MAX_WIDTH;
}
} else {
if (height > MAX_HEIGHT) {
width *= MAX_HEIGHT / height;
height = MAX_HEIGHT;
}
}
canvas.width = width;
canvas.height = height;
var ctx = canvas.getContext('2d');
ctx.drawImage(img, 0, 0, width, height);
return canvas.toDataURL();
}

View File

@@ -5,6 +5,7 @@
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<link rel="stylesheet" href="../static/css/style.css">
<link rel="stylesheet" href="../static/css/gradio.css">
<script src="../static/js/utils.js"></script>
<script src="../static/js/all-io.js"></script>
<script src="https://code.jquery.com/jquery-3.2.1.slim.min.js" integrity="sha384-KJ3o2DKtIkvYIK3UENzmM7KCkRr/rE9/Qpg6aAZGJwFDMVNA/GpGFF93hXpG5KkN" crossorigin="anonymous"></script>
</head>
@@ -16,12 +17,12 @@
<div id="panels">
<div class="panel">
<div id="input"></div>
<input class="submit" type="submit" value="Submit"/><!--
<input class="submit" type="submit" value="Submit"/><!--DO NOT DELETE
--><input class="clear" type="reset" value="Clear">
</div>
</div>
<div class="panel">
<div id="output"></div>
</div>
</div>
</div>
</div>
</body>
</html>

View File

@@ -6,6 +6,4 @@
</div>
<input class="hidden_upload" type="file" accept="image/x-png,image/gif,image/jpeg" />
</div>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/cropper/4.0.0/cropper.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/cropper/4.0.0/cropper.min.js" type="module"></script>
<script src="../static/js/image-upload-input.js"></script>

View File

@@ -1,5 +1,7 @@
<div class="gradio output classifier">
<div class="role">Output</div>
<div class="output_class"></div>
<div class="confidence_intervals">
</div>
</div>
<script src="../static/js/class-output.js"></script>

View File

@@ -11,7 +11,7 @@ class TestSketchpad(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Sketchpad()
path = inp.get_template_path()
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Sketchpad()
@@ -23,19 +23,19 @@ class TestWebcam(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Webcam()
path = inp.get_template_path()
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Webcam()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 1))
self.assertEqual(array.shape, (1, 224, 224, 3))
class TestTextbox(unittest.TestCase):
def test_path_exists(self):
inp = inputs.Textbox()
path = inp.get_template_path()
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_preprocessing(self):
inp = inputs.Textbox()
@@ -52,7 +52,12 @@ class TestImageUpload(unittest.TestCase):
def test_preprocessing(self):
inp = inputs.ImageUpload()
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 1))
self.assertEqual(array.shape, (1, 224, 224, 3))
def test_preprocessing(self):
inp = inputs.ImageUpload(image_height=48, image_width=48)
array = inp.preprocess(BASE64_IMG)
self.assertEqual(array.shape, (1, 48, 48, 3))
if __name__ == '__main__':

View File

@@ -10,6 +10,16 @@ class TestInterface(unittest.TestCase):
self.assertIsInstance(io.input_interface, gradio.inputs.Sketchpad)
self.assertIsInstance(io.output_interface, gradio.outputs.Textbox)
def test_input_interface_is_instance(self):
inp = gradio.inputs.ImageUpload()
io = Interface(inputs=inp, outputs='textBOX', model=lambda x: x, model_type='function')
self.assertEqual(io.input_interface, inp)
def test_output_interface_is_instance(self):
out = gradio.outputs.Label(show_confidences=False)
io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function')
self.assertEqual(io.output_interface, out)
if __name__ == '__main__':
unittest.main()

View File

@@ -2,6 +2,7 @@ import numpy as np
import unittest
import os
from gradio import outputs
import json
PACKAGE_NAME = 'gradio'
@@ -16,28 +17,40 @@ class TestLabel(unittest.TestCase):
string = 'happy'
out = outputs.Label()
label = out.postprocess(string)
self.assertEqual(label, string)
self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string})
def test_postprocessing_one_hot(self):
one_hot = np.array([0, 0, 0, 1, 0])
true_label = 3
def test_postprocessing_1D_array(self):
array = np.array([0.1, 0.2, 0, 0.7, 0])
true_label = {outputs.Label.LABEL_KEY: 3,
outputs.Label.CONFIDENCES_KEY: [
{outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7},
{outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2},
{outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1},
]}
out = outputs.Label()
label = out.postprocess(one_hot)
self.assertEqual(label, true_label)
label = out.postprocess(array)
self.assertDictEqual(label, true_label)
def test_postprocessing_1D_array_no_confidences(self):
array = np.array([0.1, 0.2, 0, 0.7, 0])
true_label = {outputs.Label.LABEL_KEY: 3}
out = outputs.Label(show_confidences=False)
label = out.postprocess(array)
self.assertDictEqual(label, true_label)
def test_postprocessing_int(self):
true_label_array = np.array([[[3]]])
true_label = 3
true_label = {outputs.Label.LABEL_KEY: 3}
out = outputs.Label()
label = out.postprocess(true_label_array)
self.assertEqual(label, true_label)
self.assertDictEqual(label, true_label)
class TestTextbox(unittest.TestCase):
def test_path_exists(self):
out = outputs.Textbox()
path = out.get_template_path()
self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
# self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path)))
def test_postprocessing(self):
string = 'happy'