updated PyPi version

This commit is contained in:
root 2021-08-16 21:58:30 +00:00
parent 076db7eaea
commit 998a5fb156
17 changed files with 241 additions and 34 deletions

View File

@ -1,8 +1,10 @@
import gradio as gr
import random
import time
def calculator(num1, operation, num2):
print(num1, operation, num2)
time.sleep(10)
if operation == "add":
return num1 + num2
elif operation == "subtract":
@ -23,8 +25,9 @@ iface = gr.Interface(calculator,
],
title="test calculator",
description="heres a sample toy calculator. enjoy!",
flagging_options=["this", "or", "that"]
flagging_options=["this", "or", "that"],
enable_queue=True
)
if __name__ == "__main__":
iface.launch()
iface.launch(share=True)

BIN
demo/gradio_queue.db Normal file

Binary file not shown.

View File

@ -2,6 +2,6 @@ import React from "react";
export default class BaseComponent extends React.Component {
static memo = (a, b) => {
return a.value === b.value;
return a.value === b.value && a.interpretation === b.interpretation;
};
}

View File

@ -12,7 +12,6 @@ class CheckboxInput extends BaseComponent {
this.props.handleChange(this.props.value !== true);
};
render() {
console.log("render checkbox");
return (
<div className="input_checkbox">
<div

View File

@ -12,7 +12,6 @@ class TextboxInput extends BaseComponent {
this.props.handleChange(evt.target.value);
}
render() {
console.log("render text");
if (this.props.interpretation !== null) {
return (
<div className="input_text">

View File

@ -82,6 +82,7 @@ export class GradioInterface extends React.Component {
state["has_changed"] = false;
state["example_id"] = null;
state["flag_index"] = null;
state["queue_index"] = null;
return state;
};
clear = () => {
@ -102,7 +103,7 @@ export class GradioInterface extends React.Component {
flag_index: null
});
this.props
.fn(input_state, "predict")
.fn(input_state, "predict", this.queueCallback)
.then((output) => {
let index_start = this.props.input_components.length;
let new_state = {};
@ -182,6 +183,9 @@ export class GradioInterface extends React.Component {
removeInterpret = () => {
this.setState({ interpretation: null });
};
queueCallback = (queue_index) => {
this.setState({"queue_index": queue_index});
}
takeScreenshot = () => {
html2canvas(ReactDOM.findDOMNode(this).parentNode).then((canvas) => {
saveAs(canvas.toDataURL(), "screenshot.png");
@ -220,6 +224,7 @@ export class GradioInterface extends React.Component {
if (this.state.submitting) {
status = (
<div className="loading">
{this.state.queue_index !== null ? "queued @ " + this.state.queue_index : false}
<img alt="loading" src={logo_loading} />
</div>
);

View File

@ -3,18 +3,57 @@ import ReactDOM from "react-dom";
import { GradioPage } from "./gradio";
import Login from "./login";
let fn = async (data, action) => {
function delay(n) {
return new Promise(function(resolve){
setTimeout(resolve,n*1000);
});
}
let postData = async (url, body) => {
const output = await fetch(
process.env.REACT_APP_BACKEND_URL + "api/" + action + "/",
process.env.REACT_APP_BACKEND_URL + url,
{
method: "POST",
body: JSON.stringify({ data: data }),
headers: {
"Content-Type": "application/json"
}
body: JSON.stringify(body),
headers: {"Content-Type": "application/json"}
}
);
return await output.json();
return output;
}
let fn = async (queue, data, action, queue_callback) => {
if (queue && action == "predict") {
const output = await postData(
"api/queue/push/", { data: data },
);
let hash = await output.text();
let status = "UNKNOWN";
while (status != "COMPLETE" && status != "FAILED") {
if (status != "UNKNOWN") {
await delay(1);
}
const status_response = await postData(
"api/queue/status/", { hash: hash },
);
var status_obj = await status_response.json();
status = status_obj["status"];
if (status === "QUEUED") {
queue_callback(status_obj["data"]);
} else if (status === "PENDING") {
queue_callback(null);
}
}
if (status == "FAILED") {
throw new Error(status);
} else {
return status_obj["data"];
}
} else {
const output = await postData(
"api/" + action + "/", { data: data },
);
return await output.json();
}
};
async function get_config() {
@ -39,7 +78,7 @@ get_config().then((config) => {
style.appendChild(document.createTextNode(config.css));
}
ReactDOM.render(
<GradioPage {...config} fn={fn} />,
<GradioPage {...config} fn={fn.bind(null, config.queue)} />,
document.getElementById("root")
);
}

View File

@ -48,7 +48,7 @@
@apply absolute right-1;
}
.loading img {
@apply h-5;
@apply h-5 ml-2 inline-block;
}
.panels {
@apply flex flex-wrap justify-center gap-4;
@ -67,6 +67,7 @@
}
.component_set {
@apply bg-gray-50 p-2 rounded flex flex-col flex-1 gap-2;
min-height: 36px;
}
.panel_header {
@apply mb-1 uppercase text-sm font-semibold;

View File

@ -1,4 +1,3 @@
import base64
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
from Crypto import Random

View File

@ -1,6 +1,6 @@
{
"files": {
"main.css": "/static/css/main.22c25130.css",
"main.css": "/static/css/main.0a917430.css",
"main.js": "/static/bundle.js",
"index.html": "/index.html",
"static/bundle.js.LICENSE.txt": "/static/bundle.js.LICENSE.txt",
@ -9,7 +9,7 @@
},
"entrypoints": [
"static/bundle.css",
"static/css/main.22c25130.css",
"static/css/main.0a917430.css",
"static/bundle.js"
]
}

View File

@ -8,4 +8,4 @@
window.config = {{ config|tojson }};
} catch (e) {
window.config = {};
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.22c25130.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.0a917430.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>

View File

@ -4,14 +4,9 @@ This module defines various classes that can serve as the `input` to an interfac
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
"""
import datetime
import json
import os
import shutil
import time
import warnings
from gradio.component import Component
import base64
import numpy as np
import PIL
import scipy.io.wavfile
@ -20,7 +15,6 @@ import pandas as pd
from ffmpy import FFmpeg
import math
import tempfile
from pandas.api.types import is_bool_dtype, is_numeric_dtype, is_string_dtype
from pathlib import Path

View File

@ -4,13 +4,12 @@ interface using the input and output types.
"""
import gradio
from gradio.inputs import InputComponent, get_input_instance
from gradio.outputs import OutputComponent, get_output_instance
from gradio.inputs import get_input_instance
from gradio.outputs import get_output_instance
from gradio import networking, strings, utils
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
from gradio.external import load_interface
from gradio import encryptor
import pkg_resources
import requests
import random
import time
@ -70,7 +69,7 @@ class Interface:
title=None, description=None, article=None, thumbnail=None,
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True):
show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True, enable_queue=False):
"""
Parameters:
@ -99,6 +98,7 @@ class Interface:
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
flagging_dir (str): what to name the dir where flagged data is stored.
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
"""
if not isinstance(fn, list):
fn = [fn]
@ -171,6 +171,7 @@ class Interface:
self.embedding = embedding
self.show_tips = show_tips
self.requires_permissions = any([component.requires_permissions for component in self.input_components])
self.enable_queue = enable_queue
data = {'fn': fn,
'inputs': inputs,
@ -246,6 +247,7 @@ class Interface:
"flagging_options": self.flagging_options,
"allow_interpretation": self.interpretation is not None,
"allow_embedding": self.embedding is not None,
"queue": self.enable_queue
}
try:
param_names = inspect.getfullargspec(self.predict[0])[0]
@ -505,9 +507,8 @@ class Interface:
self.encryption_key = encryptor.get_key(getpass("Enter key for encryption: "))
# Launch local flask server
server_port, app, thread = networking.start_server(
server_port, path_to_local_server, app, thread = networking.start_server(
self, self.server_name, self.server_port, self.auth)
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
self.local_url = path_to_local_server
self.server_port = server_port
self.status = "RUNNING"

View File

@ -25,6 +25,7 @@ import gradio as gr
from gradio.embeddings import calculate_similarity, fit_pca_to_embeddings, transform_with_pca
from gradio.tunneling import create_tunnel
from gradio import encryptor
from gradio import queue
from functools import wraps
import io
@ -392,12 +393,41 @@ def file(path):
return send_file(os.path.join(app.cwd, path))
@app.route("/api/queue/push/", methods=["POST"])
@login_check
def queue_push():
data = request.json
job_hash = queue.push(data)
return job_hash
@app.route("/api/queue/status/", methods=["POST"])
@login_check
def queue_status():
hash = request.json['hash']
status, data = queue.get_status(hash)
return {"status": status, "data": data}
def queue_thread(path_to_local_server):
while True:
next_job = queue.pop()
if next_job is not None:
_, hash, input_data = next_job
queue.start_job(hash)
response = requests.post(path_to_local_server + "/api/predict/", json=input_data)
if response.status_code == 200:
queue.pass_job(hash, response.json())
else:
queue.fail_job(hash, response.text)
else:
time.sleep(1)
def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
if server_port is None:
server_port = INITIAL_PORT_VALUE
port = get_first_available_port(
server_port, server_port + TRY_NUM_PORTS
)
path_to_local_server = "http://{}:{}/".format(server_name, port)
if auth is not None:
if not callable(auth):
app.auth = {account[0]: account[1] for account in auth}
@ -409,6 +439,10 @@ def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
app.cwd = os.getcwd()
log = logging.getLogger('werkzeug')
log.setLevel(logging.ERROR)
if app.interface.enable_queue:
queue.init()
app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
app.queue_thread.start()
if interface.save_to is not None:
interface.save_to["port"] = port
app_kwargs = {"port": port, "host": server_name}
@ -420,7 +454,7 @@ def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
daemon=True)
thread.start()
return port, app, thread
return port, path_to_local_server, app, thread
def close_server(process):

133
gradio/queue.py Normal file
View File

@ -0,0 +1,133 @@
import sqlite3
import os
import uuid
import json
DB_FILE = "gradio_queue.db"
def init():
if os.path.exists(DB_FILE):
os.remove(DB_FILE)
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""CREATE TABLE queue (
queue_index integer PRIMARY KEY,
hash text,
input_data text,
popped integer DEFAULT 0
);""")
c.execute("""
CREATE TABLE jobs (
hash text PRIMARY KEY,
status text,
output_data text,
error_message text
);
""")
conn.commit()
def pop():
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
SELECT queue_index, hash, input_data FROM queue
WHERE popped = 0 ORDER BY queue_index ASC LIMIT 1;
""")
result = c.fetchone()
if result is None:
conn.commit()
return None
queue_index = result[0]
c.execute("""
UPDATE queue SET popped = 1, input_data = '' WHERE queue_index = ?;
""", (queue_index,))
conn.commit()
return result[0], result[1], json.loads(result[2])
def push(input_data):
input_data = json.dumps(input_data)
hash = uuid.uuid4().hex
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
INSERT INTO queue (hash, input_data)
VALUES (?, ?);
""", (hash, input_data))
conn.commit()
return hash
def get_status(hash):
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
SELECT queue_index FROM queue WHERE hash = ?;
""", (hash,))
result = c.fetchone()
if result is None: # not in queue
c.execute("""
SELECT status, output_data, error_message FROM jobs WHERE hash = ?;
""", (hash,))
result = c.fetchone()
if result is None:
conn.commit()
return "NOT FOUND", None
else:
status, output_data, error_message = result
if status == "PENDING":
conn.commit()
return "PENDING", None
elif status == "FAILED":
conn.commit()
return "FAILED", error_message
elif status == "COMPLETE":
c.execute("""
UPDATE jobs SET output_data = '' WHERE hash = ?;
""", (hash,))
conn.commit()
output_data = json.loads(output_data)
return "COMPLETE", output_data
else:
queue_index = result[0]
c.execute("""
SELECT COUNT(*) FROM queue WHERE queue_index < ?;
""", (queue_index,))
result = c.fetchone()
conn.commit()
return "QUEUED", result[0]
def start_job(hash):
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
DELETE FROM queue WHERE hash = ?;
""", (hash,))
c.execute("""
INSERT INTO jobs (hash, status) VALUES (?, 'PENDING');
""", (hash,))
conn.commit()
def fail_job(hash, error_message):
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
UPDATE jobs SET status = 'FAILED', error_message = ? WHERE hash = ?;
""", (error_message, hash,))
conn.commit()
def pass_job(hash, output_data):
output_data = json.dumps(output_data)
conn = sqlite3.connect(DB_FILE)
c = conn.cursor()
c.execute("BEGIN EXCLUSIVE")
c.execute("""
UPDATE jobs SET status = 'COMPLETE', output_data = ? WHERE hash = ?;
""", (output_data, hash,))
conn.commit()

View File

@ -1 +1 @@
2.2.9.a2
2.2.10

View File

@ -5,7 +5,7 @@ except ImportError:
setup(
name='gradio',
version='2.2.9.a2',
version='2.2.10',
include_package_data=True,
description='Python library for easily interacting with trained machine learning models',
author='Abubakar Abid',