mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
auto flagging
This commit is contained in:
parent
0adcdfd97d
commit
f4eb532f63
@ -19,7 +19,8 @@ iface = gr.Interface(calculator,
|
||||
[4, "divide", 2],
|
||||
[-4, "multiply", 2.5],
|
||||
[0, "subtract", 1.2],
|
||||
]
|
||||
],
|
||||
allow_flagging="auto"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -37,4 +37,4 @@ iface = gr.Interface(image_classifier, image, label,
|
||||
])
|
||||
|
||||
if __name__ == "__main__":
|
||||
iface.launch(share=True)
|
||||
iface.launch()
|
||||
|
@ -1,6 +1,6 @@
|
||||
Metadata-Version: 1.0
|
||||
Name: gradio
|
||||
Version: 1.4.4
|
||||
Version: 1.5.0
|
||||
Summary: Python library for easily interacting with trained machine learning models
|
||||
Home-page: https://github.com/gradio-app/gradio-UI
|
||||
Author: Abubakar Abid
|
||||
|
@ -140,7 +140,7 @@ class Interface:
|
||||
self.server_port = server_port
|
||||
self.simple_server = None
|
||||
self.allow_screenshot = allow_screenshot
|
||||
self.allow_flagging = allow_flagging
|
||||
self.allow_flagging = os.getenv("GRADIO_FLAGGING") or allow_flagging
|
||||
self.flagging_dir = flagging_dir
|
||||
Interface.instances.add(self)
|
||||
self.analytics_enabled=analytics_enabled
|
||||
@ -384,8 +384,7 @@ class Interface:
|
||||
|
||||
# Set up local flask server
|
||||
config = self.get_config_file()
|
||||
networking.set_config(config)
|
||||
networking.set_meta_tags(self.title, self.description, self.thumbnail)
|
||||
self.config = config
|
||||
self.auth = auth
|
||||
|
||||
# Launch local flask server
|
||||
|
@ -5,7 +5,7 @@ Defines helper methods useful for setting up ports, launching servers, and handl
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
from flask import Flask, request, jsonify, abort, send_file, render_template
|
||||
from flask import Flask, request, jsonify, abort, send_file, render_template, redirect
|
||||
from flask_cachebuster import CacheBuster
|
||||
from flask_basicauth import BasicAuth
|
||||
from flask_cors import CORS
|
||||
@ -44,24 +44,11 @@ app = Flask(__name__,
|
||||
CORS(app)
|
||||
cache_buster = CacheBuster(config={'extensions': ['.js', '.css'], 'hash_size': 5})
|
||||
cache_buster.init_app(app)
|
||||
app.app_globals = {}
|
||||
|
||||
# Hide Flask default message
|
||||
cli = sys.modules['flask.cli']
|
||||
cli.show_server_banner = lambda *x: None
|
||||
|
||||
def set_meta_tags(title, description, thumbnail):
|
||||
app.app_globals.update({
|
||||
"title": title,
|
||||
"description": description,
|
||||
"thumbnail": thumbnail
|
||||
})
|
||||
|
||||
|
||||
def set_config(config):
|
||||
app.app_globals["config"] = config
|
||||
|
||||
|
||||
def get_local_ip_address():
|
||||
try:
|
||||
ip_address = requests.get('https://api.ipify.org').text
|
||||
@ -96,24 +83,43 @@ def get_first_available_port(initial, final):
|
||||
@app.route("/", methods=["GET"])
|
||||
def main():
|
||||
return render_template("index.html",
|
||||
title=app.app_globals["title"],
|
||||
description=app.app_globals["description"],
|
||||
thumbnail=app.app_globals["thumbnail"],
|
||||
config=app.interface.config,
|
||||
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
|
||||
css=app.interface.css
|
||||
css=app.interface.css,
|
||||
)
|
||||
|
||||
@app.route("/from_dir", methods=["GET"])
|
||||
def main_from_flagging_dir():
|
||||
return redirect("/from_dir/" + app.interface.flagging_dir)
|
||||
|
||||
@app.route("/from_dir/<path:path>", methods=["GET"])
|
||||
def main_from_dir(path):
|
||||
log_file = os.path.join(path, "log.csv")
|
||||
if not os.path.exists(log_file):
|
||||
abort(404)
|
||||
with open(log_file) as logs:
|
||||
examples = list(csv.reader(logs))
|
||||
examples = examples[1:] #remove header
|
||||
input_examples = [example[:len(app.interface.input_interfaces)] for example in examples]
|
||||
return render_template("index.html",
|
||||
config=app.interface.config,
|
||||
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
|
||||
css=app.interface.css,
|
||||
path=path,
|
||||
examples=input_examples
|
||||
)
|
||||
|
||||
|
||||
@app.route("/config/", methods=["GET"])
|
||||
def config():
|
||||
return jsonify(app.app_globals["config"])
|
||||
return jsonify(app.interface.config)
|
||||
|
||||
|
||||
@app.route("/enable_sharing/<path:path>", methods=["GET"])
|
||||
def enable_sharing(path):
|
||||
if path == "None":
|
||||
path = None
|
||||
app.app_globals["config"]["share_url"] = path
|
||||
app.interface.config["share_url"] = path
|
||||
return jsonify(success=True)
|
||||
|
||||
|
||||
@ -122,6 +128,11 @@ def predict():
|
||||
raw_input = request.json["data"]
|
||||
prediction, durations = app.interface.process(raw_input)
|
||||
output = {"data": prediction, "durations": durations}
|
||||
if app.interface.allow_flagging == "auto":
|
||||
try:
|
||||
flag_data(raw_input)
|
||||
except:
|
||||
pass
|
||||
return jsonify(output)
|
||||
|
||||
def log_feature_analytics(feature):
|
||||
@ -206,39 +217,29 @@ def predict_examples():
|
||||
return jsonify(output)
|
||||
|
||||
|
||||
@app.route("/api/flag/", methods=["POST"])
|
||||
def flag():
|
||||
log_feature_analytics('flag')
|
||||
def flag_data(data):
|
||||
flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
|
||||
output = {'inputs': [app.interface.input_interfaces[
|
||||
i].rebuild(
|
||||
flag_path, request.json['data']['input_data'][i]) for i
|
||||
in range(len(app.interface.input_interfaces))],
|
||||
'outputs': [app.interface.output_interfaces[
|
||||
i].rebuild(
|
||||
flag_path, request.json['data']['output_data'][i])
|
||||
for i
|
||||
in range(len(app.interface.output_interfaces))]}
|
||||
output = [app.interface.input_interfaces[i].rebuild(
|
||||
flag_path, component_data)
|
||||
for i, component_data in enumerate(data)]
|
||||
|
||||
log_fp = "{}/log.csv".format(flag_path)
|
||||
|
||||
is_new = not os.path.exists(log_fp)
|
||||
|
||||
with open(log_fp, "a") as csvfile:
|
||||
headers = ["input_{}".format(i) for i in range(len(
|
||||
output["inputs"]))] + ["output_{}".format(i) for i in
|
||||
range(len(output["outputs"]))]
|
||||
writer = csv.DictWriter(csvfile, delimiter=',',
|
||||
lineterminator='\n',
|
||||
fieldnames=headers)
|
||||
writer = csv.writer(csvfile)
|
||||
if is_new:
|
||||
writer.writeheader()
|
||||
headers = [interface[1]["label"] for interface in app.interface.config["input_interfaces"]]
|
||||
writer.writerow(headers)
|
||||
|
||||
writer.writerow(
|
||||
dict(zip(headers, output["inputs"] +
|
||||
output["outputs"]))
|
||||
)
|
||||
return jsonify(success=True)
|
||||
writer.writerow(output)
|
||||
|
||||
@app.route("/api/flag/", methods=["POST"])
|
||||
def flag():
|
||||
log_feature_analytics('flag')
|
||||
data = request.json['data']['input_data']
|
||||
flag_data(data)
|
||||
return jsonify(success=True)
|
||||
|
||||
|
||||
@app.route("/api/interpret/", methods=["POST"])
|
||||
@ -263,9 +264,9 @@ def start_server(interface, server_name, server_port=None, auth=None):
|
||||
server_port, server_port + TRY_NUM_PORTS
|
||||
)
|
||||
if auth is not None:
|
||||
app.config['BASIC_AUTH_USERNAME'] = auth[0]
|
||||
app.config['BASIC_AUTH_PASSWORD'] = auth[1]
|
||||
app.config['BASIC_AUTH_FORCE'] = True
|
||||
app.interface.config['BASIC_AUTH_USERNAME'] = auth[0]
|
||||
app.interface.config['BASIC_AUTH_PASSWORD'] = auth[1]
|
||||
app.interface.config['BASIC_AUTH_FORCE'] = True
|
||||
basic_auth = BasicAuth(app)
|
||||
app.interface = interface
|
||||
app.cwd = os.getcwd()
|
||||
|
@ -205,13 +205,13 @@ function gradio(config, fn, target, example_file_path) {
|
||||
if (!config["allow_embedding"]) {
|
||||
target.find(".embedding").css("visibility", "hidden");
|
||||
}
|
||||
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
|
||||
if (!config["allow_screenshot"] && config["allow_flagging"] !== true && !config["allow_interpretation"]) {
|
||||
target.find(".screenshot, .record, .flag, .interpret").css("visibility", "hidden");
|
||||
} else {
|
||||
if (!config["allow_screenshot"]) {
|
||||
target.find(".screenshot, .record").hide();
|
||||
}
|
||||
if (!config["allow_flagging"]) {
|
||||
if (config["allow_flagging"] !== true) {
|
||||
target.find(".flag").hide();
|
||||
}
|
||||
if (!config["allow_interpretation"]) {
|
||||
|
@ -19,14 +19,14 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
|
||||
<meta property="og:url" content="https://gradio.app/" />
|
||||
<meta property="og:type" content="website" />
|
||||
<meta property="og:image" content="{{thumbnail}}" />
|
||||
<meta property="og:title" content="{{title}}" />
|
||||
<meta property="og:description" content="{{description}}" />
|
||||
<meta property="og:image" content="{{ config['thumbnail'] or '' }}" />
|
||||
<meta property="og:title" content="{{ config['title'] or '' }}" />
|
||||
<meta property="og:description" content="{{ config['description'] or '' }}" />
|
||||
<meta name="twitter:card" content="summary_large_image">
|
||||
<meta name="twitter:creator" content="@teamGradio">
|
||||
<meta name="twitter:title" content="{{title}}">
|
||||
<meta name="twitter:description" content="{{description}}">
|
||||
<meta name="twitter:image" content="{{thumbnail}}">
|
||||
<meta name="twitter:title" content="{{ config['title'] or '' }}">
|
||||
<meta name="twitter:description" content="{{ config['description'] or '' }}">
|
||||
<meta name="twitter:image" content="{{ config['thumbnail'] or '' }}">
|
||||
|
||||
<title>Gradio</title>
|
||||
<link href="https://fonts.googleapis.com/css?family=Open+Sans" rel="stylesheet">
|
||||
@ -127,7 +127,10 @@
|
||||
<script src="{{ url_for('static', filename='js/gradio.js') }}"></script>
|
||||
<script>
|
||||
$.getJSON("/config/", function(config) {
|
||||
io = gradio_url(config, "/api/", "#interface_target", "/file/");
|
||||
{% if examples %}
|
||||
config["examples"] = {{ examples|tojson }}
|
||||
{% endif %}
|
||||
io = gradio_url(config, "/api/", "#interface_target", "/file/{% if path %}{{ path }}/{% endif %}");
|
||||
});
|
||||
const copyToClipboard = str => {
|
||||
const el = document.createElement('textarea');
|
||||
|
Loading…
Reference in New Issue
Block a user