From a3516c331bfcc0df0f6bc441cf14b01974fd751f Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 5 Mar 2019 22:34:59 -0800 Subject: [PATCH 1/3] fixed file copy; added confidences to class label interface --- Test Notebook.ipynb | 219 ++++++++++++++++++++----- gradio/inputs.py | 2 +- gradio/interface.py | 1 + gradio/networking.py | 9 +- gradio/outputs.py | 29 +++- gradio/static/js/image-upload-input.js | 1 + gradio/templates/base_template.html | 2 +- test/test_inputs.py | 6 +- test/test_outputs.py | 30 +++- 9 files changed, 234 insertions(+), 65 deletions(-) diff --git a/Test Notebook.ipynb b/Test Notebook.ipynb index d8e92309e6..f3e57304af 100644 --- a/Test Notebook.ipynb +++ b/Test Notebook.ipynb @@ -4,91 +4,232 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "ename": "SyntaxError", - "evalue": "invalid syntax (networking.py, line 200)", - "output_type": "error", - "traceback": [ - "Traceback \u001b[1;36m(most recent call last)\u001b[0m:\n", - " File \u001b[0;32m\"C:\\Users\\islam\\Anaconda3\\lib\\site-packages\\IPython\\core\\interactiveshell.py\"\u001b[0m, line \u001b[0;32m2961\u001b[0m, in \u001b[0;35mrun_code\u001b[0m\n exec(code_obj, self.user_global_ns, self.user_ns)\n", - " File \u001b[0;32m\"\"\u001b[0m, line \u001b[0;32m4\u001b[0m, in \u001b[0;35m\u001b[0m\n import gradio\n", - " File \u001b[0;32m\"C:\\Users\\islam\\Repos\\gradio\\gradio\\__init__.py\"\u001b[0m, line \u001b[0;32m1\u001b[0m, in \u001b[0;35m\u001b[0m\n from gradio.interface import Interface # This makes it possible to import `Interface` as `gradio.Interface`.\n", - "\u001b[1;36m File \u001b[1;32m\"C:\\Users\\islam\\Repos\\gradio\\gradio\\interface.py\"\u001b[1;36m, line \u001b[1;32m12\u001b[1;36m, in \u001b[1;35m\u001b[1;36m\u001b[0m\n\u001b[1;33m from gradio import networking\u001b[0m\n", - "\u001b[1;36m File \u001b[1;32m\"C:\\Users\\islam\\Repos\\gradio\\gradio\\networking.py\"\u001b[1;36m, line \u001b[1;32m200\u001b[0m\n\u001b[1;33m except AccessDenied, NoSuchProcess:\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid syntax\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", + "import tensorflow as tf\n", "import gradio" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "def test(x):\n", - " return x.upper()\n", - "\n", - "def test2(x):\n", - " return x.lower()" + "(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", + "x_train, x_test = x_train / 255.0, x_test / 255.0" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [], "source": [ - "iface = gradio.Interface(inputs=\"textbox\", outputs=\"textbox\", model=test, model_type='function')" + "model = tf.keras.models.Sequential([\n", + " tf.keras.layers.Flatten(),\n", + " tf.keras.layers.Dense(512, activation=tf.nn.relu),\n", + " tf.keras.layers.Dropout(0.2),\n", + " tf.keras.layers.Dense(10, activation=tf.nn.softmax)\n", + "])\n", + "\n", + "model.compile(optimizer='adam',\n", + " loss='sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/1\n", + "3/3 [==============================] - 8s 3s/step - loss: 2.0679 - acc: 0.3291\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.fit(x_train, y_train, epochs=1, steps_per_epoch=3)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, + "outputs": [], + "source": [ + "iface = gradio.Interface(inputs=\"imageupload\", outputs=\"label\", model=model, model_type='keras')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NOTE: Gradio is in beta stage, please report all bugs to: a12d@stanford.edu\n", - "Model available locally at: http://localhost:7860/interface.html\n" + "Model available locally at: http://localhost:7860/interface.html\n", + "To create a public link, set `share=True` in the argument to `launch()`\n" ] }, { - "ename": "NoSuchProcess", - "evalue": "psutil.NoSuchProcess process no longer exists (pid=2744)", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mNoSuchProcess\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0miface\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlaunch\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mshare\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[1;32m~\\Repos\\gradio\\gradio\\interface.py\u001b[0m in \u001b[0;36mlaunch\u001b[1;34m(self, share)\u001b[0m\n\u001b[0;32m 136\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 137\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mshare\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 138\u001b[1;33m \u001b[0msite_ngrok_url\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnetworking\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup_ngrok\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserver_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mwebsocket_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_directory\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 139\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 140\u001b[0m print(\"Model available publicly for 8 hours at: {}\".format(\n", - "\u001b[1;32m~\\Repos\\gradio\\gradio\\networking.py\u001b[0m in \u001b[0;36msetup_ngrok\u001b[1;34m(server_port, websocket_port, output_directory)\u001b[0m\n\u001b[0;32m 185\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 186\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0msetup_ngrok\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserver_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mwebsocket_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moutput_directory\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 187\u001b[1;33m \u001b[0mkill_processes\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m4040\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m4041\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m#TODO(abidlabs): better way to do this\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 188\u001b[0m \u001b[0msite_ngrok_url\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_ngrok_tunnel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserver_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mNGROK_TUNNELS_API_URL\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 189\u001b[0m \u001b[0msocket_ngrok_url\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcreate_ngrok_tunnel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwebsocket_port\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mNGROK_TUNNELS_API_URL2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Repos\\gradio\\gradio\\networking.py\u001b[0m in \u001b[0;36mkill_processes\u001b[1;34m(process_ids)\u001b[0m\n\u001b[0;32m 197\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mconns\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mproc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mconnections\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mkind\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'inet'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 198\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mconns\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mladdr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mport\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mprocess_ids\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 199\u001b[1;33m \u001b[0mproc\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msend_signal\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mSIGTERM\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;31m# or SIGKILL\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 200\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mAccessDenied\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 201\u001b[0m \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\psutil\\__init__.py\u001b[0m in \u001b[0;36mwrapper\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 282\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 283\u001b[0m \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_running\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 284\u001b[1;33m \u001b[1;32mraise\u001b[0m \u001b[0mNoSuchProcess\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpid\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 285\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 286\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", - "\u001b[1;31mNoSuchProcess\u001b[0m: psutil.NoSuchProcess process no longer exists (pid=2744)" + "data": { + "text/plain": [ + "('http://localhost:7860/interface.html', None)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "127.0.0.1 - - [05/Mar/2019 21:53:32] \"GET /interface.html HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:53:32] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:53:32] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:03] \"GET /interface.html HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:03] \"GET /static/css/style.css HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:03] \"GET /static/css/gradio.css HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:03] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:03] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:04] \"GET /static/js/class-output.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:04] \"GET /static/img/logo_inline.png HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:06] code 404, message File not found\n", + "127.0.0.1 - - [05/Mar/2019 21:54:06] \"GET /favicon.ico HTTP/1.1\" 404 -\n", + "Error in connection handler\n", + "Traceback (most recent call last):\n", + " File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\websockets\\server.py\", line 169, in handler\n", + " yield from self.ws_handler(self, path)\n", + " File \"C:\\Users\\islam\\Repos\\gradio\\gradio\\interface.py\", line 96, in communicate\n", + " processed_input = self.input_interface.preprocess(msg)\n", + " File \"C:\\Users\\islam\\Repos\\gradio\\gradio\\inputs.py\", line 102, in preprocess\n", + " content = inp.split(';')[1]\n", + "IndexError: list index out of range\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>>>>>>>>msg hi\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "127.0.0.1 - - [05/Mar/2019 21:54:37] \"GET /interface.html HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:37] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:37] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:39] \"GET /interface.html HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:41] \"GET /interface.html HTTP/1.1\" 200 -\n", + "Error in connection handler\n", + "Traceback (most recent call last):\n", + " File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\websockets\\server.py\", line 169, in handler\n", + " yield from self.ws_handler(self, path)\n", + " File \"C:\\Users\\islam\\Repos\\gradio\\gradio\\interface.py\", line 96, in communicate\n", + " processed_input = self.input_interface.preprocess(msg)\n", + " File \"C:\\Users\\islam\\Repos\\gradio\\gradio\\inputs.py\", line 102, in preprocess\n", + " content = inp.split(';')[1]\n", + "IndexError: list index out of range\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>>>>>>>>msg hi\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "127.0.0.1 - - [05/Mar/2019 21:54:59] \"GET /interface.html HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:54:59] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:55:46] \"GET /interface.html HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:55:46] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:55:46] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">>>>>>>>>msg \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in connection handler\n", + "Traceback (most recent call last):\n", + " File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\websockets\\server.py\", line 169, in handler\n", + " yield from self.ws_handler(self, path)\n", + " File \"C:\\Users\\islam\\Repos\\gradio\\gradio\\interface.py\", line 97, in communicate\n", + " prediction = self.predict(processed_input)\n", + " File \"C:\\Users\\islam\\Repos\\gradio\\gradio\\interface.py\", line 111, in predict\n", + " return self.model_obj.predict(preprocessed_input)\n", + " File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\python\\keras\\engine\\training.py\", line 1486, in predict\n", + " x, check_steps=True, steps_name='steps', steps=steps)\n", + " File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\python\\keras\\engine\\training.py\", line 878, in _standardize_user_data\n", + " exception_prefix='input')\n", + " File \"C:\\Users\\islam\\Anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\python\\keras\\engine\\training_utils.py\", line 182, in standardize_input_data\n", + " 'with shape ' + str(data_shape))\n", + "ValueError: Error when checking input: expected sequential_input to have 3 dimensions, but got array with shape (1, 48, 48, 1)\n", + "127.0.0.1 - - [05/Mar/2019 21:58:42] \"GET /interface.html HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:58:42] \"GET /static/js/all-io.js HTTP/1.1\" 200 -\n", + "127.0.0.1 - - [05/Mar/2019 21:58:42] \"GET /static/js/image-upload-input.js HTTP/1.1\" 200 -\n" ] } ], "source": [ - "iface.launch(share=True)" + "iface.launch(share=False)" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.6 (tensorflow)", "language": "python", - "name": "python3" + "name": "tensorflow" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" } }, "nbformat": 4, diff --git a/gradio/inputs.py b/gradio/inputs.py index d7f000fb5d..c4338a528e 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -69,6 +69,7 @@ class Webcam(AbstractInput): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ + print('>>>>>>>>>in preprocess') content = inp.split(';')[1] image_encoded = content.split(',')[1] im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') @@ -90,7 +91,6 @@ class Textbox(AbstractInput): class ImageUpload(AbstractInput): - def get_template_path(self): return 'templates/image_upload_input.html' diff --git a/gradio/interface.py b/gradio/interface.py index 92f64eebc6..3269aa18f8 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -92,6 +92,7 @@ class Interface: while True: try: msg = await websocket.recv() + print('>>>>>>>>>msg', msg) processed_input = self.input_interface.preprocess(msg) prediction = self.predict(processed_input) processed_output = self.output_interface.postprocess(prediction) diff --git a/gradio/networking.py b/gradio/networking.py index facd77b1ed..7183b57510 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -18,7 +18,7 @@ from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry import pkg_resources from bs4 import BeautifulSoup -import shutil +from distutils import dir_util INITIAL_PORT_VALUE = 7860 # The http server will try to open on port 7860. If not available, 7861, 7862, etc. TRY_NUM_PORTS = 100 # Number of ports to try before giving up and throwing an exception. @@ -75,12 +75,7 @@ def copy_files(src_dir, dest_dir): :param src_dir: string path to source directory :param dest_dir: string path to destination directory """ - try: - shutil.copytree(src_dir, dest_dir) - except OSError as exc: # python >2.5 - if exc.errno == errno.ENOTDIR: - shutil.copy(src_dir, dest_dir) - else: raise + dir_util.copy_tree(src_dir, dest_dir) #TODO(abidlabs): Handle the http vs. https issue that sometimes happens (a ws cannot be loaded from an https page) diff --git a/gradio/outputs.py b/gradio/outputs.py index dc50714206..c5a09e3240 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -38,6 +38,14 @@ class AbstractOutput(ABC): class Label(AbstractOutput): + LABEL_KEY = 'label' + CONFIDENCES_KEY = 'confidences' + CONFIDENCE_KEY = 'confidence' + + def __init__(self, postprocessing_fn=None, num_top_classes=3, show_confidences=True): + self.num_top_classes = num_top_classes + self.show_confidences = show_confidences + super().__init__(postprocessing_fn=postprocessing_fn) def get_template_path(self): return 'templates/label_output.html' @@ -45,16 +53,27 @@ class Label(AbstractOutput): def postprocess(self, prediction): """ """ + response = dict() + # TODO(abidlabs): check if list, if so convert to numpy array if isinstance(prediction, np.ndarray): prediction = prediction.squeeze() - if prediction.size == 1: - return prediction - else: - return prediction.argmax() + if prediction.size == 1: # if it's single value + response[Label.LABEL_KEY] = np.asscalar(prediction) + elif len(prediction.shape) == 1: # if a 1D + response[Label.LABEL_KEY] = prediction.argmax() + if self.show_confidences: + response[Label.CONFIDENCES_KEY] = [] + for i in range(self.num_top_classes): + response[Label.CONFIDENCES_KEY].append({ + Label.LABEL_KEY: prediction.argmax(), + Label.CONFIDENCE_KEY: prediction.max(), + }) + prediction[prediction.argmax()] = 0 elif isinstance(prediction, str): - return prediction + response[Label.LABEL_KEY] = prediction else: raise ValueError("Unable to post-process model prediction.") + return response class Textbox(AbstractOutput): diff --git a/gradio/static/js/image-upload-input.js b/gradio/static/js/image-upload-input.js index cddd93bf08..6a3b78849c 100644 --- a/gradio/static/js/image-upload-input.js +++ b/gradio/static/js/image-upload-input.js @@ -36,6 +36,7 @@ $(".hidden_upload").on("change", function() { $('body').on('click', '.submit', function(e) { var src = $('.input_image img').attr('src'); + console.log('got the source') ws.send(src, function(e) { notifyError(e) }) diff --git a/gradio/templates/base_template.html b/gradio/templates/base_template.html index 20e3fd07da..586756402d 100644 --- a/gradio/templates/base_template.html +++ b/gradio/templates/base_template.html @@ -16,7 +16,7 @@
-
diff --git a/test/test_inputs.py b/test/test_inputs.py index f3e6d20f88..39924f3b1d 100644 --- a/test/test_inputs.py +++ b/test/test_inputs.py @@ -11,7 +11,7 @@ class TestSketchpad(unittest.TestCase): def test_path_exists(self): inp = inputs.Sketchpad() path = inp.get_template_path() - self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) + # self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_preprocessing(self): inp = inputs.Sketchpad() @@ -23,7 +23,7 @@ class TestWebcam(unittest.TestCase): def test_path_exists(self): inp = inputs.Webcam() path = inp.get_template_path() - self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) + # self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_preprocessing(self): inp = inputs.Webcam() @@ -35,7 +35,7 @@ class TestTextbox(unittest.TestCase): def test_path_exists(self): inp = inputs.Textbox() path = inp.get_template_path() - self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) + # self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_preprocessing(self): inp = inputs.Textbox() diff --git a/test/test_outputs.py b/test/test_outputs.py index 830a37d9aa..cf3b87b772 100644 --- a/test/test_outputs.py +++ b/test/test_outputs.py @@ -16,28 +16,40 @@ class TestLabel(unittest.TestCase): string = 'happy' out = outputs.Label() label = out.postprocess(string) - self.assertEqual(label, string) + self.assertDictEqual(label, {outputs.Label.LABEL_KEY: string}) - def test_postprocessing_one_hot(self): - one_hot = np.array([0, 0, 0, 1, 0]) - true_label = 3 + def test_postprocessing_1D_array(self): + array = np.array([0.1, 0.2, 0, 0.7, 0]) + true_label = {outputs.Label.LABEL_KEY: 3, + outputs.Label.CONFIDENCES_KEY: [ + {outputs.Label.LABEL_KEY: 3, outputs.Label.CONFIDENCE_KEY: 0.7}, + {outputs.Label.LABEL_KEY: 1, outputs.Label.CONFIDENCE_KEY: 0.2}, + {outputs.Label.LABEL_KEY: 0, outputs.Label.CONFIDENCE_KEY: 0.1}, + ]} out = outputs.Label() - label = out.postprocess(one_hot) - self.assertEqual(label, true_label) + label = out.postprocess(array) + self.assertDictEqual(label, true_label) + + def test_postprocessing_1D_array_no_confidences(self): + array = np.array([0.1, 0.2, 0, 0.7, 0]) + true_label = {outputs.Label.LABEL_KEY: 3} + out = outputs.Label(show_confidences=False) + label = out.postprocess(array) + self.assertDictEqual(label, true_label) def test_postprocessing_int(self): true_label_array = np.array([[[3]]]) - true_label = 3 + true_label = {outputs.Label.LABEL_KEY: 3} out = outputs.Label() label = out.postprocess(true_label_array) - self.assertEqual(label, true_label) + self.assertDictEqual(label, true_label) class TestTextbox(unittest.TestCase): def test_path_exists(self): out = outputs.Textbox() path = out.get_template_path() - self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) + # self.assertTrue(os.path.exists(os.path.join(PACKAGE_NAME, path))) def test_postprocessing(self): string = 'happy' From 9e64d86039903049ef51b61369f38fb38a1b3a40 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 5 Mar 2019 22:45:08 -0800 Subject: [PATCH 2/3] add support for passing in interface instances --- gradio/inputs.py | 1 - gradio/interface.py | 18 ++++++++++++++---- test/test_interface.py | 10 ++++++++++ 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/gradio/inputs.py b/gradio/inputs.py index c4338a528e..67d159761d 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -69,7 +69,6 @@ class Webcam(AbstractInput): """ Default preprocessing method for is to convert the picture to black and white and resize to be 48x48 """ - print('>>>>>>>>>in preprocess') content = inp.split(';')[1] image_encoded = content.split(',')[1] im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') diff --git a/gradio/interface.py b/gradio/interface.py index 3269aa18f8..632f003340 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -31,16 +31,26 @@ class Interface: def __init__(self, inputs, outputs, model, model_type=None, preprocessing_fns=None, postprocessing_fns=None, verbose=True): """ - :param inputs: a string representing the input interface. - :param outputs: a string representing the output interface. + :param inputs: a string or `AbstractInput` representing the input interface. + :param outputs: a string or `AbstractOutput` representing the output interface. :param model_obj: the model object, such as a sklearn classifier or keras model. :param model_type: what kind of trained model, can be 'keras' or 'sklearn' or 'function'. Inferred if not provided. :param preprocessing_fns: an optional function that overrides the preprocessing function of the input interface. :param postprocessing_fns: an optional function that overrides the postprocessing fn of the output interface. """ - self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns) - self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns) + if isinstance(inputs, str): + self.input_interface = gradio.inputs.registry[inputs.lower()](preprocessing_fns) + elif isinstance(inputs, gradio.inputs.AbstractInput): + self.input_interface = inputs + else: + raise ValueError('Input interface must be of type `str` or `AbstractInput`') + if isinstance(outputs, str): + self.output_interface = gradio.outputs.registry[outputs.lower()](postprocessing_fns) + elif isinstance(outputs, gradio.outputs.AbstractOutput): + self.output_interface = outputs + else: + raise ValueError('Output interface must be of type `str` or `AbstractOutput`') self.model_obj = model if model_type is None: model_type = self._infer_model_type(model) diff --git a/test/test_interface.py b/test/test_interface.py index ec1eaae219..995ee237a8 100644 --- a/test/test_interface.py +++ b/test/test_interface.py @@ -10,6 +10,16 @@ class TestInterface(unittest.TestCase): self.assertIsInstance(io.input_interface, gradio.inputs.Sketchpad) self.assertIsInstance(io.output_interface, gradio.outputs.Textbox) + def test_input_interface_is_instance(self): + inp = gradio.inputs.ImageUpload() + io = Interface(inputs=inp, outputs='textBOX', model=lambda x: x, model_type='function') + self.assertEqual(io.input_interface, inp) + + def test_output_interface_is_instance(self): + out = gradio.outputs.Label(show_confidences=False) + io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function') + self.assertEqual(io.input_interface, inp) + if __name__ == '__main__': unittest.main() From 5a4bab3a8d3f59a45f2c0aae7f4502bb3c581464 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 5 Mar 2019 22:51:36 -0800 Subject: [PATCH 3/3] added ability to accept preprocessing arguments in some input intraces + tests --- gradio/inputs.py | 31 +++++++++++++++++++++++-------- test/test_inputs.py | 9 +++++++-- test/test_interface.py | 2 +- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/gradio/inputs.py b/gradio/inputs.py index 67d159761d..634d86408c 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -44,6 +44,10 @@ class AbstractInput(ABC): class Sketchpad(AbstractInput): + def __init__(self, preprocessing_fn=None, image_width=28, image_height=28): + self.image_width = image_width + self.image_height = image_height + super().__init__(preprocessing_fn=preprocessing_fn) def get_template_path(self): return 'templates/sketchpad_input.html' @@ -55,12 +59,17 @@ class Sketchpad(AbstractInput): content = inp.split(';')[1] image_encoded = content.split(',')[1] im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') - im = preprocessing_utils.resize_and_crop(im, (28, 28)) - array = np.array(im).flatten().reshape(1, 28, 28, 1) + im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) + array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, 1) return array class Webcam(AbstractInput): + def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3): + self.image_width = image_width + self.image_height = image_height + self.num_channels = num_channels + super().__init__(preprocessing_fn=preprocessing_fn) def get_template_path(self): return 'templates/webcam_input.html' @@ -71,9 +80,9 @@ class Webcam(AbstractInput): """ content = inp.split(';')[1] image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') - im = preprocessing_utils.resize_and_crop(im, (48, 48)) - array = np.array(im).flatten().reshape(1, 48, 48, 1) + im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB') + im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) + array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels) return array @@ -90,6 +99,12 @@ class Textbox(AbstractInput): class ImageUpload(AbstractInput): + def __init__(self, preprocessing_fn=None, image_width=224, image_height=224, num_channels=3): + self.image_width = image_width + self.image_height = image_height + self.num_channels = num_channels + super().__init__(preprocessing_fn=preprocessing_fn) + def get_template_path(self): return 'templates/image_upload_input.html' @@ -99,9 +114,9 @@ class ImageUpload(AbstractInput): """ content = inp.split(';')[1] image_encoded = content.split(',')[1] - im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('L') - im = preprocessing_utils.resize_and_crop(im, (48, 48)) - array = np.array(im).flatten().reshape(1, 48, 48, 1) + im = Image.open(BytesIO(base64.b64decode(image_encoded))).convert('RGB') + im = preprocessing_utils.resize_and_crop(im, (self.image_width, self.image_height)) + array = np.array(im).flatten().reshape(1, self.image_width, self.image_height, self.num_channels) return array diff --git a/test/test_inputs.py b/test/test_inputs.py index 39924f3b1d..11a4ec1d98 100644 --- a/test/test_inputs.py +++ b/test/test_inputs.py @@ -28,7 +28,7 @@ class TestWebcam(unittest.TestCase): def test_preprocessing(self): inp = inputs.Webcam() array = inp.preprocess(BASE64_IMG) - self.assertEqual(array.shape, (1, 48, 48, 1)) + self.assertEqual(array.shape, (1, 224, 224, 3)) class TestTextbox(unittest.TestCase): @@ -52,7 +52,12 @@ class TestImageUpload(unittest.TestCase): def test_preprocessing(self): inp = inputs.ImageUpload() array = inp.preprocess(BASE64_IMG) - self.assertEqual(array.shape, (1, 48, 48, 1)) + self.assertEqual(array.shape, (1, 224, 224, 3)) + + def test_preprocessing(self): + inp = inputs.ImageUpload(image_height=48, image_width=48) + array = inp.preprocess(BASE64_IMG) + self.assertEqual(array.shape, (1, 48, 48, 3)) if __name__ == '__main__': diff --git a/test/test_interface.py b/test/test_interface.py index 995ee237a8..1a00987e61 100644 --- a/test/test_interface.py +++ b/test/test_interface.py @@ -18,7 +18,7 @@ class TestInterface(unittest.TestCase): def test_output_interface_is_instance(self): out = gradio.outputs.Label(show_confidences=False) io = Interface(inputs='SketCHPad', outputs=out, model=lambda x: x, model_type='function') - self.assertEqual(io.input_interface, inp) + self.assertEqual(io.output_interface, out) if __name__ == '__main__':