diff --git a/.gitignore b/.gitignore index 6fa9e787f1..ceb8b79025 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ build/ gradio/launches.json gradio/frontend/static workspace.code-workspace +*.db diff --git a/demo/calculator.py b/demo/calculator.py index 1c6591172b..04ed3e511c 100644 --- a/demo/calculator.py +++ b/demo/calculator.py @@ -4,7 +4,6 @@ import time def calculator(num1, operation, num2): print(num1, operation, num2) - time.sleep(10) if operation == "add": return num1 + num2 elif operation == "subtract": @@ -26,8 +25,7 @@ iface = gr.Interface(calculator, title="test calculator", description="heres a sample toy calculator. enjoy!", flagging_options=["this", "or", "that"], - enable_queue=True ) if __name__ == "__main__": - iface.launch(share=True) + iface.launch() diff --git a/demo/gender_sentence_custom_interpretation.py b/demo/gender_sentence_custom_interpretation.py index 3e763d9145..d605c6d48d 100644 --- a/demo/gender_sentence_custom_interpretation.py +++ b/demo/gender_sentence_custom_interpretation.py @@ -24,6 +24,6 @@ def interpret_gender(sentence): iface = gr.Interface( fn=gender_of_sentence, inputs=gr.inputs.Textbox(default="She went to his house to get her keys."), - outputs="label", interpretation=interpret_gender) + outputs="label", interpretation=interpret_gender, enable_queue=True) if __name__ == "__main__": iface.launch() \ No newline at end of file diff --git a/demo/gradio_queue.db b/demo/gradio_queue.db index 29925d36f0..e865c561cc 100644 Binary files a/demo/gradio_queue.db and b/demo/gradio_queue.db differ diff --git a/frontend/src/gradio.jsx b/frontend/src/gradio.jsx index 2840d2b105..76ba7f92e0 100644 --- a/frontend/src/gradio.jsx +++ b/frontend/src/gradio.jsx @@ -165,7 +165,7 @@ export class GradioInterface extends React.Component { } this.setState({ submitting: true, has_changed: false, error: false }); this.props - .fn(input_state, "interpret") + .fn(input_state, "interpret", this.queueCallback) .then((output) => { this.setState({ interpretation: output["interpretation_scores"], diff --git a/frontend/src/index.jsx b/frontend/src/index.jsx index e857d81f79..98b0967156 100644 --- a/frontend/src/index.jsx +++ b/frontend/src/index.jsx @@ -22,9 +22,9 @@ let postData = async (url, body) => { } let fn = async (queue, data, action, queue_callback) => { - if (queue && action == "predict") { + if (queue && ["predict", "interpret"].includes(action)) { const output = await postData( - "api/queue/push/", { data: data }, + "api/queue/push/", { data: data, action:action }, ); let hash = await output.text(); let status = "UNKNOWN"; diff --git a/gradio.egg-info/PKG-INFO b/gradio.egg-info/PKG-INFO index 9e10a78674..9804bee7c0 100644 --- a/gradio.egg-info/PKG-INFO +++ b/gradio.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 1.0 Name: gradio -Version: 2.2.11 +Version: 2.2.12 Summary: Python library for easily interacting with trained machine learning models Home-page: https://github.com/gradio-app/gradio-UI Author: Abubakar Abid diff --git a/gradio/interface.py b/gradio/interface.py index 3929c48de6..506ac8e30e 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -10,6 +10,7 @@ from gradio import networking, strings, utils from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value from gradio.external import load_interface from gradio import encryptor +from gradio import queue import requests import random import time @@ -462,6 +463,8 @@ class Interface: print("Keyboard interruption in main thread... closing server.") thread.keep_running = False networking.url_ok(path_to_local_server) # Hit the server one more time to close it + if self.enable_queue: + queue.close() def test_launch(self): for predict_fn in self.predict: diff --git a/gradio/networking.py b/gradio/networking.py index f06ffe6d39..1cbdfce062 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -396,8 +396,9 @@ def file(path): @app.route("/api/queue/push/", methods=["POST"]) @login_check def queue_push(): - data = request.json - job_hash = queue.push(data) + data = request.json["data"] + action = request.json["action"] + job_hash = queue.push({"data": data}, action) return job_hash @app.route("/api/queue/status/", methods=["POST"]) @@ -409,17 +410,22 @@ def queue_status(): def queue_thread(path_to_local_server): while True: - next_job = queue.pop() - if next_job is not None: - _, hash, input_data = next_job - queue.start_job(hash) - response = requests.post(path_to_local_server + "/api/predict/", json=input_data) - if response.status_code == 200: - queue.pass_job(hash, response.json()) + try: + next_job = queue.pop() + if next_job is not None: + _, hash, input_data, task_type = next_job + queue.start_job(hash) + response = requests.post( + path_to_local_server + "/api/" + task_type + "/", json=input_data) + if response.status_code == 200: + queue.pass_job(hash, response.json()) + else: + queue.fail_job(hash, response.text) else: - queue.fail_job(hash, response.text) - else: + time.sleep(1) + except Exception as e: time.sleep(1) + pass def start_server(interface, server_name, server_port=None, auth=None, ssl=None): if server_port is None: @@ -440,6 +446,8 @@ def start_server(interface, server_name, server_port=None, auth=None, ssl=None): log = logging.getLogger('werkzeug') log.setLevel(logging.ERROR) if app.interface.enable_queue: + if auth is not None or app.interface.encrypt: + raise ValueError("Cannot queue with encryption or authenitcation enabled.") queue.init() app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,)) app.queue_thread.start() diff --git a/gradio/queue.py b/gradio/queue.py index c5e389b6be..b642eff913 100644 --- a/gradio/queue.py +++ b/gradio/queue.py @@ -5,6 +5,20 @@ import json DB_FILE = "gradio_queue.db" +def generate_hash(): + generate = True + while generate: + hash = uuid.uuid4().hex + conn = sqlite3.connect(DB_FILE) + c = conn.cursor() + c.execute(""" + SELECT hash FROM queue + WHERE hash = ?; + """, (hash,)) + conn.commit() + generate = c.fetchone() is not None + return hash + def init(): if os.path.exists(DB_FILE): os.remove(DB_FILE) @@ -15,6 +29,7 @@ def init(): queue_index integer PRIMARY KEY, hash text, input_data text, + action text, popped integer DEFAULT 0 );""") c.execute(""" @@ -27,12 +42,16 @@ def init(): """) conn.commit() +def close(): + if os.path.exists(DB_FILE): + os.remove(DB_FILE) + def pop(): conn = sqlite3.connect(DB_FILE) c = conn.cursor() c.execute("BEGIN EXCLUSIVE") c.execute(""" - SELECT queue_index, hash, input_data FROM queue + SELECT queue_index, hash, input_data, action FROM queue WHERE popped = 0 ORDER BY queue_index ASC LIMIT 1; """) result = c.fetchone() @@ -44,19 +63,19 @@ def pop(): UPDATE queue SET popped = 1, input_data = '' WHERE queue_index = ?; """, (queue_index,)) conn.commit() - return result[0], result[1], json.loads(result[2]) + return result[0], result[1], json.loads(result[2]), result[3] -def push(input_data): +def push(input_data, action): input_data = json.dumps(input_data) - hash = uuid.uuid4().hex + hash = generate_hash() conn = sqlite3.connect(DB_FILE) c = conn.cursor() c.execute("BEGIN EXCLUSIVE") c.execute(""" - INSERT INTO queue (hash, input_data) - VALUES (?, ?); - """, (hash, input_data)) + INSERT INTO queue (hash, input_data, action) + VALUES (?, ?, ?); + """, (hash, input_data, action)) conn.commit() return hash @@ -65,10 +84,10 @@ def get_status(hash): c = conn.cursor() c.execute("BEGIN EXCLUSIVE") c.execute(""" - SELECT queue_index FROM queue WHERE hash = ?; + SELECT queue_index, popped FROM queue WHERE hash = ?; """, (hash,)) result = c.fetchone() - if result is None: # not in queue + if result[1] == 1: # in jobs c.execute(""" SELECT status, output_data, error_message FROM jobs WHERE hash = ?; """, (hash,)) @@ -94,7 +113,7 @@ def get_status(hash): else: queue_index = result[0] c.execute(""" - SELECT COUNT(*) FROM queue WHERE queue_index < ?; + SELECT COUNT(*) FROM queue WHERE queue_index < ? and popped = 0; """, (queue_index,)) result = c.fetchone() conn.commit() @@ -105,7 +124,7 @@ def start_job(hash): c = conn.cursor() c.execute("BEGIN EXCLUSIVE") c.execute(""" - DELETE FROM queue WHERE hash = ?; + UPDATE queue SET popped = 1 WHERE hash = ?; """, (hash,)) c.execute(""" INSERT INTO jobs (hash, status) VALUES (?, 'PENDING'); diff --git a/gradio/version.txt b/gradio/version.txt index 0b6e43134b..98c938ec34 100644 --- a/gradio/version.txt +++ b/gradio/version.txt @@ -1 +1 @@ -2.2.11 +2.2.12 diff --git a/setup.py b/setup.py index 1270ea1649..0dbc97060d 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ except ImportError: setup( name='gradio', - version='2.2.11', + version='2.2.12', include_package_data=True, description='Python library for easily interacting with trained machine learning models', author='Abubakar Abid',