mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
updated PyPi version
This commit is contained in:
parent
076db7eaea
commit
998a5fb156
@ -1,8 +1,10 @@
|
||||
import gradio as gr
|
||||
import random
|
||||
import time
|
||||
|
||||
def calculator(num1, operation, num2):
|
||||
print(num1, operation, num2)
|
||||
time.sleep(10)
|
||||
if operation == "add":
|
||||
return num1 + num2
|
||||
elif operation == "subtract":
|
||||
@ -23,8 +25,9 @@ iface = gr.Interface(calculator,
|
||||
],
|
||||
title="test calculator",
|
||||
description="heres a sample toy calculator. enjoy!",
|
||||
flagging_options=["this", "or", "that"]
|
||||
flagging_options=["this", "or", "that"],
|
||||
enable_queue=True
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch()
|
||||
iface.launch(share=True)
|
||||
|
BIN
demo/gradio_queue.db
Normal file
BIN
demo/gradio_queue.db
Normal file
Binary file not shown.
@ -2,6 +2,6 @@ import React from "react";
|
||||
|
||||
export default class BaseComponent extends React.Component {
|
||||
static memo = (a, b) => {
|
||||
return a.value === b.value;
|
||||
return a.value === b.value && a.interpretation === b.interpretation;
|
||||
};
|
||||
}
|
||||
|
@ -12,7 +12,6 @@ class CheckboxInput extends BaseComponent {
|
||||
this.props.handleChange(this.props.value !== true);
|
||||
};
|
||||
render() {
|
||||
console.log("render checkbox");
|
||||
return (
|
||||
<div className="input_checkbox">
|
||||
<div
|
||||
|
@ -12,7 +12,6 @@ class TextboxInput extends BaseComponent {
|
||||
this.props.handleChange(evt.target.value);
|
||||
}
|
||||
render() {
|
||||
console.log("render text");
|
||||
if (this.props.interpretation !== null) {
|
||||
return (
|
||||
<div className="input_text">
|
||||
|
@ -82,6 +82,7 @@ export class GradioInterface extends React.Component {
|
||||
state["has_changed"] = false;
|
||||
state["example_id"] = null;
|
||||
state["flag_index"] = null;
|
||||
state["queue_index"] = null;
|
||||
return state;
|
||||
};
|
||||
clear = () => {
|
||||
@ -102,7 +103,7 @@ export class GradioInterface extends React.Component {
|
||||
flag_index: null
|
||||
});
|
||||
this.props
|
||||
.fn(input_state, "predict")
|
||||
.fn(input_state, "predict", this.queueCallback)
|
||||
.then((output) => {
|
||||
let index_start = this.props.input_components.length;
|
||||
let new_state = {};
|
||||
@ -182,6 +183,9 @@ export class GradioInterface extends React.Component {
|
||||
removeInterpret = () => {
|
||||
this.setState({ interpretation: null });
|
||||
};
|
||||
queueCallback = (queue_index) => {
|
||||
this.setState({"queue_index": queue_index});
|
||||
}
|
||||
takeScreenshot = () => {
|
||||
html2canvas(ReactDOM.findDOMNode(this).parentNode).then((canvas) => {
|
||||
saveAs(canvas.toDataURL(), "screenshot.png");
|
||||
@ -220,6 +224,7 @@ export class GradioInterface extends React.Component {
|
||||
if (this.state.submitting) {
|
||||
status = (
|
||||
<div className="loading">
|
||||
{this.state.queue_index !== null ? "queued @ " + this.state.queue_index : false}
|
||||
<img alt="loading" src={logo_loading} />
|
||||
</div>
|
||||
);
|
||||
|
@ -3,18 +3,57 @@ import ReactDOM from "react-dom";
|
||||
import { GradioPage } from "./gradio";
|
||||
import Login from "./login";
|
||||
|
||||
let fn = async (data, action) => {
|
||||
function delay(n) {
|
||||
return new Promise(function(resolve){
|
||||
setTimeout(resolve,n*1000);
|
||||
});
|
||||
}
|
||||
|
||||
let postData = async (url, body) => {
|
||||
const output = await fetch(
|
||||
process.env.REACT_APP_BACKEND_URL + "api/" + action + "/",
|
||||
process.env.REACT_APP_BACKEND_URL + url,
|
||||
{
|
||||
method: "POST",
|
||||
body: JSON.stringify({ data: data }),
|
||||
headers: {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
body: JSON.stringify(body),
|
||||
headers: {"Content-Type": "application/json"}
|
||||
}
|
||||
);
|
||||
return await output.json();
|
||||
return output;
|
||||
}
|
||||
|
||||
let fn = async (queue, data, action, queue_callback) => {
|
||||
if (queue && action == "predict") {
|
||||
const output = await postData(
|
||||
"api/queue/push/", { data: data },
|
||||
);
|
||||
let hash = await output.text();
|
||||
let status = "UNKNOWN";
|
||||
while (status != "COMPLETE" && status != "FAILED") {
|
||||
if (status != "UNKNOWN") {
|
||||
await delay(1);
|
||||
}
|
||||
const status_response = await postData(
|
||||
"api/queue/status/", { hash: hash },
|
||||
);
|
||||
var status_obj = await status_response.json();
|
||||
status = status_obj["status"];
|
||||
if (status === "QUEUED") {
|
||||
queue_callback(status_obj["data"]);
|
||||
} else if (status === "PENDING") {
|
||||
queue_callback(null);
|
||||
}
|
||||
}
|
||||
if (status == "FAILED") {
|
||||
throw new Error(status);
|
||||
} else {
|
||||
return status_obj["data"];
|
||||
}
|
||||
} else {
|
||||
const output = await postData(
|
||||
"api/" + action + "/", { data: data },
|
||||
);
|
||||
return await output.json();
|
||||
}
|
||||
};
|
||||
|
||||
async function get_config() {
|
||||
@ -39,7 +78,7 @@ get_config().then((config) => {
|
||||
style.appendChild(document.createTextNode(config.css));
|
||||
}
|
||||
ReactDOM.render(
|
||||
<GradioPage {...config} fn={fn} />,
|
||||
<GradioPage {...config} fn={fn.bind(null, config.queue)} />,
|
||||
document.getElementById("root")
|
||||
);
|
||||
}
|
||||
|
@ -48,7 +48,7 @@
|
||||
@apply absolute right-1;
|
||||
}
|
||||
.loading img {
|
||||
@apply h-5;
|
||||
@apply h-5 ml-2 inline-block;
|
||||
}
|
||||
.panels {
|
||||
@apply flex flex-wrap justify-center gap-4;
|
||||
@ -67,6 +67,7 @@
|
||||
}
|
||||
.component_set {
|
||||
@apply bg-gray-50 p-2 rounded flex flex-col flex-1 gap-2;
|
||||
min-height: 36px;
|
||||
}
|
||||
.panel_header {
|
||||
@apply mb-1 uppercase text-sm font-semibold;
|
||||
|
@ -1,4 +1,3 @@
|
||||
import base64
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Hash import SHA256
|
||||
from Crypto import Random
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"files": {
|
||||
"main.css": "/static/css/main.22c25130.css",
|
||||
"main.css": "/static/css/main.0a917430.css",
|
||||
"main.js": "/static/bundle.js",
|
||||
"index.html": "/index.html",
|
||||
"static/bundle.js.LICENSE.txt": "/static/bundle.js.LICENSE.txt",
|
||||
@ -9,7 +9,7 @@
|
||||
},
|
||||
"entrypoints": [
|
||||
"static/bundle.css",
|
||||
"static/css/main.22c25130.css",
|
||||
"static/css/main.0a917430.css",
|
||||
"static/bundle.js"
|
||||
]
|
||||
}
|
@ -8,4 +8,4 @@
|
||||
window.config = {{ config|tojson }};
|
||||
} catch (e) {
|
||||
window.config = {};
|
||||
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.22c25130.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
|
||||
}</script><script src="https://cdnjs.cloudflare.com/ajax/libs/iframe-resizer/4.3.1/iframeResizer.contentWindow.min.js"></script><title>Gradio</title><link href="static/bundle.css" rel="stylesheet"><link href="static/css/main.0a917430.css" rel="stylesheet"></head><body style="height:100%"><div id="root" style="height:100%"></div><script src="static/bundle.js"></script></body></html>
|
@ -4,14 +4,9 @@ This module defines various classes that can serve as the `input` to an interfac
|
||||
automatically added to a registry, which allows them to be easily referenced in other parts of the code.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
from gradio.component import Component
|
||||
import base64
|
||||
import numpy as np
|
||||
import PIL
|
||||
import scipy.io.wavfile
|
||||
@ -20,7 +15,6 @@ import pandas as pd
|
||||
from ffmpy import FFmpeg
|
||||
import math
|
||||
import tempfile
|
||||
from pandas.api.types import is_bool_dtype, is_numeric_dtype, is_string_dtype
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
|
@ -4,13 +4,12 @@ interface using the input and output types.
|
||||
"""
|
||||
|
||||
import gradio
|
||||
from gradio.inputs import InputComponent, get_input_instance
|
||||
from gradio.outputs import OutputComponent, get_output_instance
|
||||
from gradio.inputs import get_input_instance
|
||||
from gradio.outputs import get_output_instance
|
||||
from gradio import networking, strings, utils
|
||||
from gradio.interpretation import quantify_difference_in_label, get_regression_or_classification_value
|
||||
from gradio.external import load_interface
|
||||
from gradio import encryptor
|
||||
import pkg_resources
|
||||
import requests
|
||||
import random
|
||||
import time
|
||||
@ -70,7 +69,7 @@ class Interface:
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
|
||||
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
|
||||
show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True):
|
||||
show_tips=False, embedding=None, flagging_dir="flagged", analytics_enabled=True, enable_queue=False):
|
||||
|
||||
"""
|
||||
Parameters:
|
||||
@ -99,6 +98,7 @@ class Interface:
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
flagging_dir (str): what to name the dir where flagged data is stored.
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
"""
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
@ -171,6 +171,7 @@ class Interface:
|
||||
self.embedding = embedding
|
||||
self.show_tips = show_tips
|
||||
self.requires_permissions = any([component.requires_permissions for component in self.input_components])
|
||||
self.enable_queue = enable_queue
|
||||
|
||||
data = {'fn': fn,
|
||||
'inputs': inputs,
|
||||
@ -246,6 +247,7 @@ class Interface:
|
||||
"flagging_options": self.flagging_options,
|
||||
"allow_interpretation": self.interpretation is not None,
|
||||
"allow_embedding": self.embedding is not None,
|
||||
"queue": self.enable_queue
|
||||
}
|
||||
try:
|
||||
param_names = inspect.getfullargspec(self.predict[0])[0]
|
||||
@ -505,9 +507,8 @@ class Interface:
|
||||
self.encryption_key = encryptor.get_key(getpass("Enter key for encryption: "))
|
||||
|
||||
# Launch local flask server
|
||||
server_port, app, thread = networking.start_server(
|
||||
server_port, path_to_local_server, app, thread = networking.start_server(
|
||||
self, self.server_name, self.server_port, self.auth)
|
||||
path_to_local_server = "http://{}:{}/".format(self.server_name, server_port)
|
||||
self.local_url = path_to_local_server
|
||||
self.server_port = server_port
|
||||
self.status = "RUNNING"
|
||||
|
@ -25,6 +25,7 @@ import gradio as gr
|
||||
from gradio.embeddings import calculate_similarity, fit_pca_to_embeddings, transform_with_pca
|
||||
from gradio.tunneling import create_tunnel
|
||||
from gradio import encryptor
|
||||
from gradio import queue
|
||||
from functools import wraps
|
||||
import io
|
||||
|
||||
@ -392,12 +393,41 @@ def file(path):
|
||||
return send_file(os.path.join(app.cwd, path))
|
||||
|
||||
|
||||
@app.route("/api/queue/push/", methods=["POST"])
|
||||
@login_check
|
||||
def queue_push():
|
||||
data = request.json
|
||||
job_hash = queue.push(data)
|
||||
return job_hash
|
||||
|
||||
@app.route("/api/queue/status/", methods=["POST"])
|
||||
@login_check
|
||||
def queue_status():
|
||||
hash = request.json['hash']
|
||||
status, data = queue.get_status(hash)
|
||||
return {"status": status, "data": data}
|
||||
|
||||
def queue_thread(path_to_local_server):
|
||||
while True:
|
||||
next_job = queue.pop()
|
||||
if next_job is not None:
|
||||
_, hash, input_data = next_job
|
||||
queue.start_job(hash)
|
||||
response = requests.post(path_to_local_server + "/api/predict/", json=input_data)
|
||||
if response.status_code == 200:
|
||||
queue.pass_job(hash, response.json())
|
||||
else:
|
||||
queue.fail_job(hash, response.text)
|
||||
else:
|
||||
time.sleep(1)
|
||||
|
||||
def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
|
||||
if server_port is None:
|
||||
server_port = INITIAL_PORT_VALUE
|
||||
port = get_first_available_port(
|
||||
server_port, server_port + TRY_NUM_PORTS
|
||||
)
|
||||
path_to_local_server = "http://{}:{}/".format(server_name, port)
|
||||
if auth is not None:
|
||||
if not callable(auth):
|
||||
app.auth = {account[0]: account[1] for account in auth}
|
||||
@ -409,6 +439,10 @@ def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
|
||||
app.cwd = os.getcwd()
|
||||
log = logging.getLogger('werkzeug')
|
||||
log.setLevel(logging.ERROR)
|
||||
if app.interface.enable_queue:
|
||||
queue.init()
|
||||
app.queue_thread = threading.Thread(target=queue_thread, args=(path_to_local_server,))
|
||||
app.queue_thread.start()
|
||||
if interface.save_to is not None:
|
||||
interface.save_to["port"] = port
|
||||
app_kwargs = {"port": port, "host": server_name}
|
||||
@ -420,7 +454,7 @@ def start_server(interface, server_name, server_port=None, auth=None, ssl=None):
|
||||
daemon=True)
|
||||
thread.start()
|
||||
|
||||
return port, app, thread
|
||||
return port, path_to_local_server, app, thread
|
||||
|
||||
|
||||
def close_server(process):
|
||||
|
133
gradio/queue.py
Normal file
133
gradio/queue.py
Normal file
@ -0,0 +1,133 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
|
||||
DB_FILE = "gradio_queue.db"
|
||||
|
||||
def init():
|
||||
if os.path.exists(DB_FILE):
|
||||
os.remove(DB_FILE)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""CREATE TABLE queue (
|
||||
queue_index integer PRIMARY KEY,
|
||||
hash text,
|
||||
input_data text,
|
||||
popped integer DEFAULT 0
|
||||
);""")
|
||||
c.execute("""
|
||||
CREATE TABLE jobs (
|
||||
hash text PRIMARY KEY,
|
||||
status text,
|
||||
output_data text,
|
||||
error_message text
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
def pop():
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
SELECT queue_index, hash, input_data FROM queue
|
||||
WHERE popped = 0 ORDER BY queue_index ASC LIMIT 1;
|
||||
""")
|
||||
result = c.fetchone()
|
||||
if result is None:
|
||||
conn.commit()
|
||||
return None
|
||||
queue_index = result[0]
|
||||
c.execute("""
|
||||
UPDATE queue SET popped = 1, input_data = '' WHERE queue_index = ?;
|
||||
""", (queue_index,))
|
||||
conn.commit()
|
||||
return result[0], result[1], json.loads(result[2])
|
||||
|
||||
|
||||
def push(input_data):
|
||||
input_data = json.dumps(input_data)
|
||||
hash = uuid.uuid4().hex
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
INSERT INTO queue (hash, input_data)
|
||||
VALUES (?, ?);
|
||||
""", (hash, input_data))
|
||||
conn.commit()
|
||||
return hash
|
||||
|
||||
def get_status(hash):
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
SELECT queue_index FROM queue WHERE hash = ?;
|
||||
""", (hash,))
|
||||
result = c.fetchone()
|
||||
if result is None: # not in queue
|
||||
c.execute("""
|
||||
SELECT status, output_data, error_message FROM jobs WHERE hash = ?;
|
||||
""", (hash,))
|
||||
result = c.fetchone()
|
||||
if result is None:
|
||||
conn.commit()
|
||||
return "NOT FOUND", None
|
||||
else:
|
||||
status, output_data, error_message = result
|
||||
if status == "PENDING":
|
||||
conn.commit()
|
||||
return "PENDING", None
|
||||
elif status == "FAILED":
|
||||
conn.commit()
|
||||
return "FAILED", error_message
|
||||
elif status == "COMPLETE":
|
||||
c.execute("""
|
||||
UPDATE jobs SET output_data = '' WHERE hash = ?;
|
||||
""", (hash,))
|
||||
conn.commit()
|
||||
output_data = json.loads(output_data)
|
||||
return "COMPLETE", output_data
|
||||
else:
|
||||
queue_index = result[0]
|
||||
c.execute("""
|
||||
SELECT COUNT(*) FROM queue WHERE queue_index < ?;
|
||||
""", (queue_index,))
|
||||
result = c.fetchone()
|
||||
conn.commit()
|
||||
return "QUEUED", result[0]
|
||||
|
||||
def start_job(hash):
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
DELETE FROM queue WHERE hash = ?;
|
||||
""", (hash,))
|
||||
c.execute("""
|
||||
INSERT INTO jobs (hash, status) VALUES (?, 'PENDING');
|
||||
""", (hash,))
|
||||
conn.commit()
|
||||
|
||||
def fail_job(hash, error_message):
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
UPDATE jobs SET status = 'FAILED', error_message = ? WHERE hash = ?;
|
||||
""", (error_message, hash,))
|
||||
conn.commit()
|
||||
|
||||
def pass_job(hash, output_data):
|
||||
output_data = json.dumps(output_data)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
c.execute("""
|
||||
UPDATE jobs SET status = 'COMPLETE', output_data = ? WHERE hash = ?;
|
||||
""", (output_data, hash,))
|
||||
conn.commit()
|
||||
|
@ -1 +1 @@
|
||||
2.2.9.a2
|
||||
2.2.10
|
||||
|
Loading…
Reference in New Issue
Block a user