auto flagging

This commit is contained in:
Ali Abid 2021-01-29 10:23:17 -08:00
parent 0adcdfd97d
commit f4eb532f63
7 changed files with 67 additions and 63 deletions

View File

@ -19,7 +19,8 @@ iface = gr.Interface(calculator,
[4, "divide", 2], [4, "divide", 2],
[-4, "multiply", 2.5], [-4, "multiply", 2.5],
[0, "subtract", 1.2], [0, "subtract", 1.2],
] ],
allow_flagging="auto"
) )
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -37,4 +37,4 @@ iface = gr.Interface(image_classifier, image, label,
]) ])
if __name__ == "__main__": if __name__ == "__main__":
iface.launch(share=True) iface.launch()

View File

@ -1,6 +1,6 @@
Metadata-Version: 1.0 Metadata-Version: 1.0
Name: gradio Name: gradio
Version: 1.4.4 Version: 1.5.0
Summary: Python library for easily interacting with trained machine learning models Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/gradio-app/gradio-UI Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid Author: Abubakar Abid

View File

@ -140,7 +140,7 @@ class Interface:
self.server_port = server_port self.server_port = server_port
self.simple_server = None self.simple_server = None
self.allow_screenshot = allow_screenshot self.allow_screenshot = allow_screenshot
self.allow_flagging = allow_flagging self.allow_flagging = os.getenv("GRADIO_FLAGGING") or allow_flagging
self.flagging_dir = flagging_dir self.flagging_dir = flagging_dir
Interface.instances.add(self) Interface.instances.add(self)
self.analytics_enabled=analytics_enabled self.analytics_enabled=analytics_enabled
@ -384,8 +384,7 @@ class Interface:
# Set up local flask server # Set up local flask server
config = self.get_config_file() config = self.get_config_file()
networking.set_config(config) self.config = config
networking.set_meta_tags(self.title, self.description, self.thumbnail)
self.auth = auth self.auth = auth
# Launch local flask server # Launch local flask server

View File

@ -5,7 +5,7 @@ Defines helper methods useful for setting up ports, launching servers, and handl
import os import os
import socket import socket
import threading import threading
from flask import Flask, request, jsonify, abort, send_file, render_template from flask import Flask, request, jsonify, abort, send_file, render_template, redirect
from flask_cachebuster import CacheBuster from flask_cachebuster import CacheBuster
from flask_basicauth import BasicAuth from flask_basicauth import BasicAuth
from flask_cors import CORS from flask_cors import CORS
@ -44,24 +44,11 @@ app = Flask(__name__,
CORS(app) CORS(app)
cache_buster = CacheBuster(config={'extensions': ['.js', '.css'], 'hash_size': 5}) cache_buster = CacheBuster(config={'extensions': ['.js', '.css'], 'hash_size': 5})
cache_buster.init_app(app) cache_buster.init_app(app)
app.app_globals = {}
# Hide Flask default message # Hide Flask default message
cli = sys.modules['flask.cli'] cli = sys.modules['flask.cli']
cli.show_server_banner = lambda *x: None cli.show_server_banner = lambda *x: None
def set_meta_tags(title, description, thumbnail):
app.app_globals.update({
"title": title,
"description": description,
"thumbnail": thumbnail
})
def set_config(config):
app.app_globals["config"] = config
def get_local_ip_address(): def get_local_ip_address():
try: try:
ip_address = requests.get('https://api.ipify.org').text ip_address = requests.get('https://api.ipify.org').text
@ -96,24 +83,43 @@ def get_first_available_port(initial, final):
@app.route("/", methods=["GET"]) @app.route("/", methods=["GET"])
def main(): def main():
return render_template("index.html", return render_template("index.html",
title=app.app_globals["title"], config=app.interface.config,
description=app.app_globals["description"],
thumbnail=app.app_globals["thumbnail"],
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""), vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
css=app.interface.css css=app.interface.css,
)
@app.route("/from_dir", methods=["GET"])
def main_from_flagging_dir():
return redirect("/from_dir/" + app.interface.flagging_dir)
@app.route("/from_dir/<path:path>", methods=["GET"])
def main_from_dir(path):
log_file = os.path.join(path, "log.csv")
if not os.path.exists(log_file):
abort(404)
with open(log_file) as logs:
examples = list(csv.reader(logs))
examples = examples[1:] #remove header
input_examples = [example[:len(app.interface.input_interfaces)] for example in examples]
return render_template("index.html",
config=app.interface.config,
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
css=app.interface.css,
path=path,
examples=input_examples
) )
@app.route("/config/", methods=["GET"]) @app.route("/config/", methods=["GET"])
def config(): def config():
return jsonify(app.app_globals["config"]) return jsonify(app.interface.config)
@app.route("/enable_sharing/<path:path>", methods=["GET"]) @app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path): def enable_sharing(path):
if path == "None": if path == "None":
path = None path = None
app.app_globals["config"]["share_url"] = path app.interface.config["share_url"] = path
return jsonify(success=True) return jsonify(success=True)
@ -122,6 +128,11 @@ def predict():
raw_input = request.json["data"] raw_input = request.json["data"]
prediction, durations = app.interface.process(raw_input) prediction, durations = app.interface.process(raw_input)
output = {"data": prediction, "durations": durations} output = {"data": prediction, "durations": durations}
if app.interface.allow_flagging == "auto":
try:
flag_data(raw_input)
except:
pass
return jsonify(output) return jsonify(output)
def log_feature_analytics(feature): def log_feature_analytics(feature):
@ -206,39 +217,29 @@ def predict_examples():
return jsonify(output) return jsonify(output)
@app.route("/api/flag/", methods=["POST"]) def flag_data(data):
def flag():
log_feature_analytics('flag')
flag_path = os.path.join(app.cwd, app.interface.flagging_dir) flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
output = {'inputs': [app.interface.input_interfaces[ output = [app.interface.input_interfaces[i].rebuild(
i].rebuild( flag_path, component_data)
flag_path, request.json['data']['input_data'][i]) for i for i, component_data in enumerate(data)]
in range(len(app.interface.input_interfaces))],
'outputs': [app.interface.output_interfaces[
i].rebuild(
flag_path, request.json['data']['output_data'][i])
for i
in range(len(app.interface.output_interfaces))]}
log_fp = "{}/log.csv".format(flag_path) log_fp = "{}/log.csv".format(flag_path)
is_new = not os.path.exists(log_fp) is_new = not os.path.exists(log_fp)
with open(log_fp, "a") as csvfile: with open(log_fp, "a") as csvfile:
headers = ["input_{}".format(i) for i in range(len( writer = csv.writer(csvfile)
output["inputs"]))] + ["output_{}".format(i) for i in
range(len(output["outputs"]))]
writer = csv.DictWriter(csvfile, delimiter=',',
lineterminator='\n',
fieldnames=headers)
if is_new: if is_new:
writer.writeheader() headers = [interface[1]["label"] for interface in app.interface.config["input_interfaces"]]
writer.writerow(headers)
writer.writerow( writer.writerow(output)
dict(zip(headers, output["inputs"] +
output["outputs"])) @app.route("/api/flag/", methods=["POST"])
) def flag():
return jsonify(success=True) log_feature_analytics('flag')
data = request.json['data']['input_data']
flag_data(data)
return jsonify(success=True)
@app.route("/api/interpret/", methods=["POST"]) @app.route("/api/interpret/", methods=["POST"])
@ -263,9 +264,9 @@ def start_server(interface, server_name, server_port=None, auth=None):
server_port, server_port + TRY_NUM_PORTS server_port, server_port + TRY_NUM_PORTS
) )
if auth is not None: if auth is not None:
app.config['BASIC_AUTH_USERNAME'] = auth[0] app.interface.config['BASIC_AUTH_USERNAME'] = auth[0]
app.config['BASIC_AUTH_PASSWORD'] = auth[1] app.interface.config['BASIC_AUTH_PASSWORD'] = auth[1]
app.config['BASIC_AUTH_FORCE'] = True app.interface.config['BASIC_AUTH_FORCE'] = True
basic_auth = BasicAuth(app) basic_auth = BasicAuth(app)
app.interface = interface app.interface = interface
app.cwd = os.getcwd() app.cwd = os.getcwd()

View File

@ -205,13 +205,13 @@ function gradio(config, fn, target, example_file_path) {
if (!config["allow_embedding"]) { if (!config["allow_embedding"]) {
target.find(".embedding").css("visibility", "hidden"); target.find(".embedding").css("visibility", "hidden");
} }
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) { if (!config["allow_screenshot"] && config["allow_flagging"] !== true && !config["allow_interpretation"]) {
target.find(".screenshot, .record, .flag, .interpret").css("visibility", "hidden"); target.find(".screenshot, .record, .flag, .interpret").css("visibility", "hidden");
} else { } else {
if (!config["allow_screenshot"]) { if (!config["allow_screenshot"]) {
target.find(".screenshot, .record").hide(); target.find(".screenshot, .record").hide();
} }
if (!config["allow_flagging"]) { if (config["allow_flagging"] !== true) {
target.find(".flag").hide(); target.find(".flag").hide();
} }
if (!config["allow_interpretation"]) { if (!config["allow_interpretation"]) {

View File

@ -19,14 +19,14 @@
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"> <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<meta property="og:url" content="https://gradio.app/" /> <meta property="og:url" content="https://gradio.app/" />
<meta property="og:type" content="website" /> <meta property="og:type" content="website" />
<meta property="og:image" content="{{thumbnail}}" /> <meta property="og:image" content="{{ config['thumbnail'] or '' }}" />
<meta property="og:title" content="{{title}}" /> <meta property="og:title" content="{{ config['title'] or '' }}" />
<meta property="og:description" content="{{description}}" /> <meta property="og:description" content="{{ config['description'] or '' }}" />
<meta name="twitter:card" content="summary_large_image"> <meta name="twitter:card" content="summary_large_image">
<meta name="twitter:creator" content="@teamGradio"> <meta name="twitter:creator" content="@teamGradio">
<meta name="twitter:title" content="{{title}}"> <meta name="twitter:title" content="{{ config['title'] or '' }}">
<meta name="twitter:description" content="{{description}}"> <meta name="twitter:description" content="{{ config['description'] or '' }}">
<meta name="twitter:image" content="{{thumbnail}}"> <meta name="twitter:image" content="{{ config['thumbnail'] or '' }}">
<title>Gradio</title> <title>Gradio</title>
<link href="https://fonts.googleapis.com/css?family=Open+Sans" rel="stylesheet"> <link href="https://fonts.googleapis.com/css?family=Open+Sans" rel="stylesheet">
@ -127,7 +127,10 @@
<script src="{{ url_for('static', filename='js/gradio.js') }}"></script> <script src="{{ url_for('static', filename='js/gradio.js') }}"></script>
<script> <script>
$.getJSON("/config/", function(config) { $.getJSON("/config/", function(config) {
io = gradio_url(config, "/api/", "#interface_target", "/file/"); {% if examples %}
config["examples"] = {{ examples|tojson }}
{% endif %}
io = gradio_url(config, "/api/", "#interface_target", "/file/{% if path %}{{ path }}/{% endif %}");
}); });
const copyToClipboard = str => { const copyToClipboard = str => {
const el = document.createElement('textarea'); const el = document.createElement('textarea');