mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
auto flagging
This commit is contained in:
parent
0adcdfd97d
commit
f4eb532f63
@ -19,7 +19,8 @@ iface = gr.Interface(calculator,
|
|||||||
[4, "divide", 2],
|
[4, "divide", 2],
|
||||||
[-4, "multiply", 2.5],
|
[-4, "multiply", 2.5],
|
||||||
[0, "subtract", 1.2],
|
[0, "subtract", 1.2],
|
||||||
]
|
],
|
||||||
|
allow_flagging="auto"
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -37,4 +37,4 @@ iface = gr.Interface(image_classifier, image, label,
|
|||||||
])
|
])
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
iface.launch(share=True)
|
iface.launch()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
Metadata-Version: 1.0
|
Metadata-Version: 1.0
|
||||||
Name: gradio
|
Name: gradio
|
||||||
Version: 1.4.4
|
Version: 1.5.0
|
||||||
Summary: Python library for easily interacting with trained machine learning models
|
Summary: Python library for easily interacting with trained machine learning models
|
||||||
Home-page: https://github.com/gradio-app/gradio-UI
|
Home-page: https://github.com/gradio-app/gradio-UI
|
||||||
Author: Abubakar Abid
|
Author: Abubakar Abid
|
||||||
|
@ -140,7 +140,7 @@ class Interface:
|
|||||||
self.server_port = server_port
|
self.server_port = server_port
|
||||||
self.simple_server = None
|
self.simple_server = None
|
||||||
self.allow_screenshot = allow_screenshot
|
self.allow_screenshot = allow_screenshot
|
||||||
self.allow_flagging = allow_flagging
|
self.allow_flagging = os.getenv("GRADIO_FLAGGING") or allow_flagging
|
||||||
self.flagging_dir = flagging_dir
|
self.flagging_dir = flagging_dir
|
||||||
Interface.instances.add(self)
|
Interface.instances.add(self)
|
||||||
self.analytics_enabled=analytics_enabled
|
self.analytics_enabled=analytics_enabled
|
||||||
@ -384,8 +384,7 @@ class Interface:
|
|||||||
|
|
||||||
# Set up local flask server
|
# Set up local flask server
|
||||||
config = self.get_config_file()
|
config = self.get_config_file()
|
||||||
networking.set_config(config)
|
self.config = config
|
||||||
networking.set_meta_tags(self.title, self.description, self.thumbnail)
|
|
||||||
self.auth = auth
|
self.auth = auth
|
||||||
|
|
||||||
# Launch local flask server
|
# Launch local flask server
|
||||||
|
@ -5,7 +5,7 @@ Defines helper methods useful for setting up ports, launching servers, and handl
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import threading
|
import threading
|
||||||
from flask import Flask, request, jsonify, abort, send_file, render_template
|
from flask import Flask, request, jsonify, abort, send_file, render_template, redirect
|
||||||
from flask_cachebuster import CacheBuster
|
from flask_cachebuster import CacheBuster
|
||||||
from flask_basicauth import BasicAuth
|
from flask_basicauth import BasicAuth
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
@ -44,24 +44,11 @@ app = Flask(__name__,
|
|||||||
CORS(app)
|
CORS(app)
|
||||||
cache_buster = CacheBuster(config={'extensions': ['.js', '.css'], 'hash_size': 5})
|
cache_buster = CacheBuster(config={'extensions': ['.js', '.css'], 'hash_size': 5})
|
||||||
cache_buster.init_app(app)
|
cache_buster.init_app(app)
|
||||||
app.app_globals = {}
|
|
||||||
|
|
||||||
# Hide Flask default message
|
# Hide Flask default message
|
||||||
cli = sys.modules['flask.cli']
|
cli = sys.modules['flask.cli']
|
||||||
cli.show_server_banner = lambda *x: None
|
cli.show_server_banner = lambda *x: None
|
||||||
|
|
||||||
def set_meta_tags(title, description, thumbnail):
|
|
||||||
app.app_globals.update({
|
|
||||||
"title": title,
|
|
||||||
"description": description,
|
|
||||||
"thumbnail": thumbnail
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def set_config(config):
|
|
||||||
app.app_globals["config"] = config
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_ip_address():
|
def get_local_ip_address():
|
||||||
try:
|
try:
|
||||||
ip_address = requests.get('https://api.ipify.org').text
|
ip_address = requests.get('https://api.ipify.org').text
|
||||||
@ -96,24 +83,43 @@ def get_first_available_port(initial, final):
|
|||||||
@app.route("/", methods=["GET"])
|
@app.route("/", methods=["GET"])
|
||||||
def main():
|
def main():
|
||||||
return render_template("index.html",
|
return render_template("index.html",
|
||||||
title=app.app_globals["title"],
|
config=app.interface.config,
|
||||||
description=app.app_globals["description"],
|
|
||||||
thumbnail=app.app_globals["thumbnail"],
|
|
||||||
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
|
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
|
||||||
css=app.interface.css
|
css=app.interface.css,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.route("/from_dir", methods=["GET"])
|
||||||
|
def main_from_flagging_dir():
|
||||||
|
return redirect("/from_dir/" + app.interface.flagging_dir)
|
||||||
|
|
||||||
|
@app.route("/from_dir/<path:path>", methods=["GET"])
|
||||||
|
def main_from_dir(path):
|
||||||
|
log_file = os.path.join(path, "log.csv")
|
||||||
|
if not os.path.exists(log_file):
|
||||||
|
abort(404)
|
||||||
|
with open(log_file) as logs:
|
||||||
|
examples = list(csv.reader(logs))
|
||||||
|
examples = examples[1:] #remove header
|
||||||
|
input_examples = [example[:len(app.interface.input_interfaces)] for example in examples]
|
||||||
|
return render_template("index.html",
|
||||||
|
config=app.interface.config,
|
||||||
|
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
|
||||||
|
css=app.interface.css,
|
||||||
|
path=path,
|
||||||
|
examples=input_examples
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/config/", methods=["GET"])
|
@app.route("/config/", methods=["GET"])
|
||||||
def config():
|
def config():
|
||||||
return jsonify(app.app_globals["config"])
|
return jsonify(app.interface.config)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||||
def enable_sharing(path):
|
def enable_sharing(path):
|
||||||
if path == "None":
|
if path == "None":
|
||||||
path = None
|
path = None
|
||||||
app.app_globals["config"]["share_url"] = path
|
app.interface.config["share_url"] = path
|
||||||
return jsonify(success=True)
|
return jsonify(success=True)
|
||||||
|
|
||||||
|
|
||||||
@ -122,6 +128,11 @@ def predict():
|
|||||||
raw_input = request.json["data"]
|
raw_input = request.json["data"]
|
||||||
prediction, durations = app.interface.process(raw_input)
|
prediction, durations = app.interface.process(raw_input)
|
||||||
output = {"data": prediction, "durations": durations}
|
output = {"data": prediction, "durations": durations}
|
||||||
|
if app.interface.allow_flagging == "auto":
|
||||||
|
try:
|
||||||
|
flag_data(raw_input)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
return jsonify(output)
|
return jsonify(output)
|
||||||
|
|
||||||
def log_feature_analytics(feature):
|
def log_feature_analytics(feature):
|
||||||
@ -206,39 +217,29 @@ def predict_examples():
|
|||||||
return jsonify(output)
|
return jsonify(output)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/flag/", methods=["POST"])
|
def flag_data(data):
|
||||||
def flag():
|
|
||||||
log_feature_analytics('flag')
|
|
||||||
flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
|
flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
|
||||||
output = {'inputs': [app.interface.input_interfaces[
|
output = [app.interface.input_interfaces[i].rebuild(
|
||||||
i].rebuild(
|
flag_path, component_data)
|
||||||
flag_path, request.json['data']['input_data'][i]) for i
|
for i, component_data in enumerate(data)]
|
||||||
in range(len(app.interface.input_interfaces))],
|
|
||||||
'outputs': [app.interface.output_interfaces[
|
|
||||||
i].rebuild(
|
|
||||||
flag_path, request.json['data']['output_data'][i])
|
|
||||||
for i
|
|
||||||
in range(len(app.interface.output_interfaces))]}
|
|
||||||
|
|
||||||
log_fp = "{}/log.csv".format(flag_path)
|
log_fp = "{}/log.csv".format(flag_path)
|
||||||
|
|
||||||
is_new = not os.path.exists(log_fp)
|
is_new = not os.path.exists(log_fp)
|
||||||
|
|
||||||
with open(log_fp, "a") as csvfile:
|
with open(log_fp, "a") as csvfile:
|
||||||
headers = ["input_{}".format(i) for i in range(len(
|
writer = csv.writer(csvfile)
|
||||||
output["inputs"]))] + ["output_{}".format(i) for i in
|
|
||||||
range(len(output["outputs"]))]
|
|
||||||
writer = csv.DictWriter(csvfile, delimiter=',',
|
|
||||||
lineterminator='\n',
|
|
||||||
fieldnames=headers)
|
|
||||||
if is_new:
|
if is_new:
|
||||||
writer.writeheader()
|
headers = [interface[1]["label"] for interface in app.interface.config["input_interfaces"]]
|
||||||
|
writer.writerow(headers)
|
||||||
|
|
||||||
writer.writerow(
|
writer.writerow(output)
|
||||||
dict(zip(headers, output["inputs"] +
|
|
||||||
output["outputs"]))
|
@app.route("/api/flag/", methods=["POST"])
|
||||||
)
|
def flag():
|
||||||
return jsonify(success=True)
|
log_feature_analytics('flag')
|
||||||
|
data = request.json['data']['input_data']
|
||||||
|
flag_data(data)
|
||||||
|
return jsonify(success=True)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/api/interpret/", methods=["POST"])
|
@app.route("/api/interpret/", methods=["POST"])
|
||||||
@ -263,9 +264,9 @@ def start_server(interface, server_name, server_port=None, auth=None):
|
|||||||
server_port, server_port + TRY_NUM_PORTS
|
server_port, server_port + TRY_NUM_PORTS
|
||||||
)
|
)
|
||||||
if auth is not None:
|
if auth is not None:
|
||||||
app.config['BASIC_AUTH_USERNAME'] = auth[0]
|
app.interface.config['BASIC_AUTH_USERNAME'] = auth[0]
|
||||||
app.config['BASIC_AUTH_PASSWORD'] = auth[1]
|
app.interface.config['BASIC_AUTH_PASSWORD'] = auth[1]
|
||||||
app.config['BASIC_AUTH_FORCE'] = True
|
app.interface.config['BASIC_AUTH_FORCE'] = True
|
||||||
basic_auth = BasicAuth(app)
|
basic_auth = BasicAuth(app)
|
||||||
app.interface = interface
|
app.interface = interface
|
||||||
app.cwd = os.getcwd()
|
app.cwd = os.getcwd()
|
||||||
|
@ -205,13 +205,13 @@ function gradio(config, fn, target, example_file_path) {
|
|||||||
if (!config["allow_embedding"]) {
|
if (!config["allow_embedding"]) {
|
||||||
target.find(".embedding").css("visibility", "hidden");
|
target.find(".embedding").css("visibility", "hidden");
|
||||||
}
|
}
|
||||||
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
|
if (!config["allow_screenshot"] && config["allow_flagging"] !== true && !config["allow_interpretation"]) {
|
||||||
target.find(".screenshot, .record, .flag, .interpret").css("visibility", "hidden");
|
target.find(".screenshot, .record, .flag, .interpret").css("visibility", "hidden");
|
||||||
} else {
|
} else {
|
||||||
if (!config["allow_screenshot"]) {
|
if (!config["allow_screenshot"]) {
|
||||||
target.find(".screenshot, .record").hide();
|
target.find(".screenshot, .record").hide();
|
||||||
}
|
}
|
||||||
if (!config["allow_flagging"]) {
|
if (config["allow_flagging"] !== true) {
|
||||||
target.find(".flag").hide();
|
target.find(".flag").hide();
|
||||||
}
|
}
|
||||||
if (!config["allow_interpretation"]) {
|
if (!config["allow_interpretation"]) {
|
||||||
|
@ -19,14 +19,14 @@
|
|||||||
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
||||||
<meta property="og:url" content="https://gradio.app/" />
|
<meta property="og:url" content="https://gradio.app/" />
|
||||||
<meta property="og:type" content="website" />
|
<meta property="og:type" content="website" />
|
||||||
<meta property="og:image" content="{{thumbnail}}" />
|
<meta property="og:image" content="{{ config['thumbnail'] or '' }}" />
|
||||||
<meta property="og:title" content="{{title}}" />
|
<meta property="og:title" content="{{ config['title'] or '' }}" />
|
||||||
<meta property="og:description" content="{{description}}" />
|
<meta property="og:description" content="{{ config['description'] or '' }}" />
|
||||||
<meta name="twitter:card" content="summary_large_image">
|
<meta name="twitter:card" content="summary_large_image">
|
||||||
<meta name="twitter:creator" content="@teamGradio">
|
<meta name="twitter:creator" content="@teamGradio">
|
||||||
<meta name="twitter:title" content="{{title}}">
|
<meta name="twitter:title" content="{{ config['title'] or '' }}">
|
||||||
<meta name="twitter:description" content="{{description}}">
|
<meta name="twitter:description" content="{{ config['description'] or '' }}">
|
||||||
<meta name="twitter:image" content="{{thumbnail}}">
|
<meta name="twitter:image" content="{{ config['thumbnail'] or '' }}">
|
||||||
|
|
||||||
<title>Gradio</title>
|
<title>Gradio</title>
|
||||||
<link href="https://fonts.googleapis.com/css?family=Open+Sans" rel="stylesheet">
|
<link href="https://fonts.googleapis.com/css?family=Open+Sans" rel="stylesheet">
|
||||||
@ -127,7 +127,10 @@
|
|||||||
<script src="{{ url_for('static', filename='js/gradio.js') }}"></script>
|
<script src="{{ url_for('static', filename='js/gradio.js') }}"></script>
|
||||||
<script>
|
<script>
|
||||||
$.getJSON("/config/", function(config) {
|
$.getJSON("/config/", function(config) {
|
||||||
io = gradio_url(config, "/api/", "#interface_target", "/file/");
|
{% if examples %}
|
||||||
|
config["examples"] = {{ examples|tojson }}
|
||||||
|
{% endif %}
|
||||||
|
io = gradio_url(config, "/api/", "#interface_target", "/file/{% if path %}{{ path }}/{% endif %}");
|
||||||
});
|
});
|
||||||
const copyToClipboard = str => {
|
const copyToClipboard = str => {
|
||||||
const el = document.createElement('textarea');
|
const el = document.createElement('textarea');
|
||||||
|
Loading…
Reference in New Issue
Block a user