From 998a5fb1563a290b897959baae23547d395b8f06 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 16 Aug 2021 21:58:30 +0000 Subject: [PATCH] updated PyPi version --- demo/calculator.py | 7 +- demo/gradio_queue.db | Bin 0 -> 16384 bytes frontend/src/components/base_component.jsx | 2 +- frontend/src/components/input/checkbox.jsx | 1 - frontend/src/components/input/textbox.jsx | 1 - frontend/src/gradio.jsx | 7 +- frontend/src/index.jsx | 55 +++++++-- frontend/src/themes/defaults.scss | 3 +- gradio/encryptor.py | 1 - gradio/frontend/asset-manifest.json | 4 +- gradio/frontend/index.html | 2 +- gradio/inputs.py | 6 - gradio/interface.py | 13 +- gradio/networking.py | 36 +++++- gradio/queue.py | 133 +++++++++++++++++++++ gradio/version.txt | 2 +- setup.py | 2 +- 17 files changed, 241 insertions(+), 34 deletions(-) create mode 100644 demo/gradio_queue.db create mode 100644 gradio/queue.py diff --git a/demo/calculator.py b/demo/calculator.py index b8a6712608..1c6591172b 100644 --- a/demo/calculator.py +++ b/demo/calculator.py @@ -1,8 +1,10 @@ import gradio as gr import random +import time def calculator(num1, operation, num2): print(num1, operation, num2) + time.sleep(10) if operation == "add": return num1 + num2 elif operation == "subtract": @@ -23,8 +25,9 @@ iface = gr.Interface(calculator, ], title="test calculator", description="heres a sample toy calculator. enjoy!", - flagging_options=["this", "or", "that"] + flagging_options=["this", "or", "that"], + enable_queue=True ) if __name__ == "__main__": - iface.launch() + iface.launch(share=True) diff --git a/demo/gradio_queue.db b/demo/gradio_queue.db new file mode 100644 index 0000000000000000000000000000000000000000..29925d36f0db5a79eaac2dc5203e88a411f88f11 GIT binary patch literal 16384 zcmeI1U27aw7{_O`d2e>6UW5oD45?7X>@w&5oc2P~cB{o~jWN9_61L~HO=+^}?kp6f z7wxqm_!;~Nf=IrEUIzRGeh77DWVem9nOtgk4xHJU^ZfUD`JK5r&wTyW)_(5OFc==_ zJSEZPsL_Z%P17ifR553ZIT&FjyNiQ$Z26?Hy2X(iX9W^AM6Mr@OJ&QOC12mk>f z00e*l5C8%|00;m9AOHmZ&jen4JhQO2y4rZDa&4aY#{)AeXY=MxcXPL!?rz@L>ZavN z`r&f=+IXZ#kJ8*f%hTx6yao%@2*Ke@;;czhQ zANbKoKlGJ->h+_m&H1(KtBq*?(D`Shr%#I0abJ(~LG|1()9j=2TAqX#%WZjay|J*i zw$^xguPX29*pGd6nK@Npwe+tFtF?aRIo&_Z{X;)|d#zI(rz_n*JfqN~!O@X-Q{T6` zw>R%??WWnQa$98ym!l{-Nsg1__?Pk#1`q%OKmZ5;0U!VbfB+Bx0zd!=0D=Dyfkyd3 z`ht$D=c}uGrtCaUTJayzN^2!b#>wsExA-{zF5Zhft-o5oChxWm3j+)w00e*l5C8%| z00;m9AOHk_zy&1mU}diN{wN~BYcCzjFya&iW2EptV+J`=+*p3|%ii`@cenfPdii;I z{aX4>8+GvBM!N3CL!Ivr4oBr8$~qa!C>EG-#zdBhz2^@uRS~SEQX#}8;hGXgWFcZy z6YG%|*y}e%D7?KT1QL`XEIC1Z?|FNvh|r3dw2oL|gkqX1N7Bb|Wr-B;ok0#xSvi%@S8~og zV@?WcgQQ*s&82aH1Wb%|Tx3)ETqP$~7nufbttC_=tA!Ov5sEMq%)2R__G~3rA|T|A z4~R)>JW|YK6Fl*hWP)Nl@t*t`b(u5{we-G z`60QI93;QTkK>d0&*bNG#|0q}00KY&2mk>f00e*l5C8%|00>-s0v|Mcjk>PMN-3&q zd|WC;b!YD7N>tO}SSm$zZH>iBRMX5@C`ENGjQLVj*SMG~MRo0p*(zMk>AhKs>beax LSDL-0x*o&d6(6Qy literal 0 HcmV?d00001 diff --git a/frontend/src/components/base_component.jsx b/frontend/src/components/base_component.jsx index 48c555d2a3..d0853f28a8 100644 --- a/frontend/src/components/base_component.jsx +++ b/frontend/src/components/base_component.jsx @@ -2,6 +2,6 @@ import React from "react"; export default class BaseComponent extends React.Component { static memo = (a, b) => { - return a.value === b.value; + return a.value === b.value && a.interpretation === b.interpretation; }; } diff --git a/frontend/src/components/input/checkbox.jsx b/frontend/src/components/input/checkbox.jsx index 8582dddc0e..85812a8052 100644 --- a/frontend/src/components/input/checkbox.jsx +++ b/frontend/src/components/input/checkbox.jsx @@ -12,7 +12,6 @@ class CheckboxInput extends BaseComponent { this.props.handleChange(this.props.value !== true); }; render() { - console.log("render checkbox"); return (
diff --git a/frontend/src/gradio.jsx b/frontend/src/gradio.jsx index 3f188be13f..2840d2b105 100644 --- a/frontend/src/gradio.jsx +++ b/frontend/src/gradio.jsx @@ -82,6 +82,7 @@ export class GradioInterface extends React.Component { state["has_changed"] = false; state["example_id"] = null; state["flag_index"] = null; + state["queue_index"] = null; return state; }; clear = () => { @@ -102,7 +103,7 @@ export class GradioInterface extends React.Component { flag_index: null }); this.props - .fn(input_state, "predict") + .fn(input_state, "predict", this.queueCallback) .then((output) => { let index_start = this.props.input_components.length; let new_state = {}; @@ -182,6 +183,9 @@ export class GradioInterface extends React.Component { removeInterpret = () => { this.setState({ interpretation: null }); }; + queueCallback = (queue_index) => { + this.setState({"queue_index": queue_index}); + } takeScreenshot = () => { html2canvas(ReactDOM.findDOMNode(this).parentNode).then((canvas) => { saveAs(canvas.toDataURL(), "screenshot.png"); @@ -220,6 +224,7 @@ export class GradioInterface extends React.Component { if (this.state.submitting) { status = (
+ {this.state.queue_index !== null ? "queued @ " + this.state.queue_index : false} loading
); diff --git a/frontend/src/index.jsx b/frontend/src/index.jsx index fd9cc71088..e857d81f79 100644 --- a/frontend/src/index.jsx +++ b/frontend/src/index.jsx @@ -3,18 +3,57 @@ import ReactDOM from "react-dom"; import { GradioPage } from "./gradio"; import Login from "./login"; -let fn = async (data, action) => { +function delay(n) { + return new Promise(function(resolve){ + setTimeout(resolve,n*1000); + }); +} + +let postData = async (url, body) => { const output = await fetch( - process.env.REACT_APP_BACKEND_URL + "api/" + action + "/", + process.env.REACT_APP_BACKEND_URL + url, { method: "POST", - body: JSON.stringify({ data: data }), - headers: { - "Content-Type": "application/json" - } + body: JSON.stringify(body), + headers: {"Content-Type": "application/json"} } ); - return await output.json(); + return output; +} + +let fn = async (queue, data, action, queue_callback) => { + if (queue && action == "predict") { + const output = await postData( + "api/queue/push/", { data: data }, + ); + let hash = await output.text(); + let status = "UNKNOWN"; + while (status != "COMPLETE" && status != "FAILED") { + if (status != "UNKNOWN") { + await delay(1); + } + const status_response = await postData( + "api/queue/status/", { hash: hash }, + ); + var status_obj = await status_response.json(); + status = status_obj["status"]; + if (status === "QUEUED") { + queue_callback(status_obj["data"]); + } else if (status === "PENDING") { + queue_callback(null); + } + } + if (status == "FAILED") { + throw new Error(status); + } else { + return status_obj["data"]; + } + } else { + const output = await postData( + "api/" + action + "/", { data: data }, + ); + return await output.json(); + } }; async function get_config() { @@ -39,7 +78,7 @@ get_config().then((config) => { style.appendChild(document.createTextNode(config.css)); } ReactDOM.render( - , + , document.getElementById("root") ); } diff --git a/frontend/src/themes/defaults.scss b/frontend/src/themes/defaults.scss index 21b6135328..ab91daa9ef 100644 --- a/frontend/src/themes/defaults.scss +++ b/frontend/src/themes/defaults.scss @@ -48,7 +48,7 @@ @apply absolute right-1; } .loading img { - @apply h-5; + @apply h-5 ml-2 inline-block; } .panels { @apply flex flex-wrap justify-center gap-4; @@ -67,6 +67,7 @@ } .component_set { @apply bg-gray-50 p-2 rounded flex flex-col flex-1 gap-2; + min-height: 36px; } .panel_header { @apply mb-1 uppercase text-sm font-semibold; diff --git a/gradio/encryptor.py b/gradio/encryptor.py index 3eaa983d30..739c231bea 100644 --- a/gradio/encryptor.py +++ b/gradio/encryptor.py @@ -1,4 +1,3 @@ -import base64 from Crypto.Cipher import AES from Crypto.Hash import SHA256 from Crypto import Random diff --git a/gradio/frontend/asset-manifest.json b/gradio/frontend/asset-manifest.json index 4f2d8aba95..3a956a2887 100644 --- a/gradio/frontend/asset-manifest.json +++ b/gradio/frontend/asset-manifest.json @@ -1,6 +1,6 @@ { "files": { - "main.css": "/static/css/main.22c25130.css", + "main.css": "/static/css/main.0a917430.css", "main.js": "/static/bundle.js", "index.html": "/index.html", "static/bundle.js.LICENSE.txt": "/static/bundle.js.LICENSE.txt", @@ -9,7 +9,7 @@ }, "entrypoints": [ "static/bundle.css", - "static/css/main.22c25130.css", + "static/css/main.0a917430.css", "static/bundle.js" ] } \ No newline at end of file diff --git a/gradio/frontend/index.html b/gradio/frontend/index.html index 695d6e2f32..406837ea6d 100644 --- a/gradio/frontend/index.html +++ b/gradio/frontend/index.html @@ -8,4 +8,4 @@ window.config = {{ config|tojson }}; } catch (e) { window.config = {}; - }Gradio
\ No newline at end of file + }Gradio
\ No newline at end of file diff --git a/gradio/inputs.py b/gradio/inputs.py index c13ea547aa..988a962b71 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -4,14 +4,9 @@ This module defines various classes that can serve as the `input` to an interfac automatically added to a registry, which allows them to be easily referenced in other parts of the code. """ -import datetime import json -import os -import shutil -import time import warnings from gradio.component import Component -import base64 import numpy as np import PIL import scipy.io.wavfile @@ -20,7 +15,6 @@ import pandas as pd from ffmpy import FFmpeg import math import tempfile -from pandas.api.types import is_bool_dtype, is_numeric_dtype, is_string_dtype from pathlib import Path diff --git a/gradio/interface.py b/gradio/interface.py index 31f8829a23..9aee452cc3 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -4,13 +4,12 @@ interface using the input and output types. """ import gradio -from gradio.inputs import InputComponent, get_input_instance -from gradio.outputs import OutputComponent, get_output_instance +from gradio.inputs import get_input_instance +from gradio.outputs import get_output_instance from gradio import networking, strings, utils from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value from gradio.external import load_interface from gradio import encryptor -import pkg_resources import requests import random import time @@ -70,7 +69,7 @@ class Interface: title=None, description=None, article=None, thumbnail=None, css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900, allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False, - show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True): + show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True, enable_queue=False): """ Parameters: @@ -99,6 +98,7 @@ class Interface: encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch flagging_dir (str): what to name the dir where flagged data is stored. show_tips (bool): if True, will occasionally show tips about new Gradio features + enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout. """ if not isinstance(fn, list): fn = [fn] @@ -171,6 +171,7 @@ class Interface: self.embedding = embedding self.show_tips = show_tips self.requires_permissions = any([component.requires_permissions for component in self.input_components]) + self.enable_queue = enable_queue data = {'fn': fn, 'inputs': inputs, @@ -246,6 +247,7 @@ class Interface: "flagging_options": self.flagging_options, "allow_interpretation": self.interpretation is not None, "allow_embedding": self.embedding is not None, + "queue": self.enable_queue } try: param_names = inspect.getfullargspec(self.predict[0])[0] @@ -505,9 +507,8 @@ class Interface: self.encryption_key = encryptor.get_key(getpass("Enter key for encryption: ")) # Launch local flask server - server_port, app, thread = networking.start_server( + server_port, path_to_local_server, app, thread = networking.start_server( self, self.server_name, self.server_port, self.auth) - path_to_local_server = "http://{}:{}/".format(self.server_name, server_port) self.local_url = path_to_local_server self.server_port = server_port self.status = "RUNNING" diff --git a/gradio/networking.py b/gradio/networking.py index 6796fd0fa9..f06ffe6d39 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -25,6 +25,7 @@ import gradio as gr from gradio.embeddings import calculate_similarity, fit_pca_to_embeddings, transform_with_pca from gradio.tunneling import create_tunnel from gradio import encryptor +from gradio import queue from functools import wraps import io @@ -392,12 +393,41 @@ def file(path): return send_file(os.path.join(app.cwd, path)) +@app.route("/api/queue/push/", methods=["POST"]) +@login_check +def queue_push(): + data = request.json + job_hash = queue.push(data) + return job_hash + +@app.route("/api/queue/status/", methods=["POST"]) +@login_check +def queue_status(): + hash = request.json['hash'] + status, data = queue.get_status(hash) + return {"status": status, "data": data} + +def queue_thread(path_to_local_server): + while True: + next_job = queue.pop() + if next_job is not None: + _, hash, input_data = next_job + queue.start_job(hash) + response = requests.post(path_to_local_server + "/api/predict/", json=input_data) + if response.status_code == 200: + queue.pass_job(hash, response.json()) + else: + queue.fail_job(hash, response.text) + else: + time.sleep(1) + def start_server(interface, server_name, server_port=None, auth=None, ssl=None): if server_port is None: server_port = INITIAL_PORT_VALUE port = get_first_available_port( server_port, server_port + TRY_NUM_PORTS ) + path_to_local_server = "http://{}:{}/".format(server_name, port) if auth is not None: if not callable(auth): app.auth = {account[0]: account[1] for account in auth} @@ -409,6 +439,10 @@ def start_server(interface, server_name, server_port=None, auth=None, ssl=None): app.cwd = os.getcwd() log = logging.getLogger('werkzeug') log.setLevel(logging.ERROR) + if app.interface.enable_queue: + queue.init() + app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,)) + app.queue_thread.start() if interface.save_to is not None: interface.save_to["port"] = port app_kwargs = {"port": port, "host": server_name} @@ -420,7 +454,7 @@ def start_server(interface, server_name, server_port=None, auth=None, ssl=None): daemon=True) thread.start() - return port, app, thread + return port, path_to_local_server, app, thread def close_server(process): diff --git a/gradio/queue.py b/gradio/queue.py new file mode 100644 index 0000000000..c5e389b6be --- /dev/null +++ b/gradio/queue.py @@ -0,0 +1,133 @@ +import sqlite3 +import os +import uuid +import json + +DB_FILE = "gradio_queue.db" + +def init(): + if os.path.exists(DB_FILE): + os.remove(DB_FILE) + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute("BEGIN EXCLUSIVE") + c.execute("""CREATE TABLE queue ( + queue_index integer PRIMARY KEY, + hash text, + input_data text, + popped integer DEFAULT 0 + );""") + c.execute(""" + CREATE TABLE jobs ( + hash text PRIMARY KEY, + status text, + output_data text, + error_message text + ); + """) + conn.commit() + +def pop(): + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute("BEGIN EXCLUSIVE") + c.execute(""" + SELECT queue_index, hash, input_data FROM queue + WHERE popped = 0 ORDER BY queue_index ASC LIMIT 1; + """) + result = c.fetchone() + if result is None: + conn.commit() + return None + queue_index = result[0] + c.execute(""" + UPDATE queue SET popped = 1, input_data = '' WHERE queue_index = ?; + """, (queue_index,)) + conn.commit() + return result[0], result[1], json.loads(result[2]) + + +def push(input_data): + input_data = json.dumps(input_data) + hash = uuid.uuid4().hex + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute("BEGIN EXCLUSIVE") + c.execute(""" + INSERT INTO queue (hash, input_data) + VALUES (?, ?); + """, (hash, input_data)) + conn.commit() + return hash + +def get_status(hash): + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute("BEGIN EXCLUSIVE") + c.execute(""" + SELECT queue_index FROM queue WHERE hash = ?; + """, (hash,)) + result = c.fetchone() + if result is None: # not in queue + c.execute(""" + SELECT status, output_data, error_message FROM jobs WHERE hash = ?; + """, (hash,)) + result = c.fetchone() + if result is None: + conn.commit() + return "NOT FOUND", None + else: + status, output_data, error_message = result + if status == "PENDING": + conn.commit() + return "PENDING", None + elif status == "FAILED": + conn.commit() + return "FAILED", error_message + elif status == "COMPLETE": + c.execute(""" + UPDATE jobs SET output_data = '' WHERE hash = ?; + """, (hash,)) + conn.commit() + output_data = json.loads(output_data) + return "COMPLETE", output_data + else: + queue_index = result[0] + c.execute(""" + SELECT COUNT(*) FROM queue WHERE queue_index < ?; + """, (queue_index,)) + result = c.fetchone() + conn.commit() + return "QUEUED", result[0] + +def start_job(hash): + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute("BEGIN EXCLUSIVE") + c.execute(""" + DELETE FROM queue WHERE hash = ?; + """, (hash,)) + c.execute(""" + INSERT INTO jobs (hash, status) VALUES (?, 'PENDING'); + """, (hash,)) + conn.commit() + +def fail_job(hash, error_message): + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute("BEGIN EXCLUSIVE") + c.execute(""" + UPDATE jobs SET status = 'FAILED', error_message = ? WHERE hash = ?; + """, (error_message, hash,)) + conn.commit() + +def pass_job(hash, output_data): + output_data = json.dumps(output_data) + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute("BEGIN EXCLUSIVE") + c.execute(""" + UPDATE jobs SET status = 'COMPLETE', output_data = ? WHERE hash = ?; + """, (output_data, hash,)) + conn.commit() + diff --git a/gradio/version.txt b/gradio/version.txt index e3d11902be..0d3ad67afa 100644 --- a/gradio/version.txt +++ b/gradio/version.txt @@ -1 +1 @@ -2.2.9.a2 +2.2.10 diff --git a/setup.py b/setup.py index 9cb0339d8e..5132d5b1f9 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ except ImportError: setup( name='gradio', - version='2.2.9.a2', + version='2.2.10', include_package_data=True, description='Python library for easily interacting with trained machine learning models', author='Abubakar Abid',