auto flagging

This commit is contained in:
Ali Abid 2021-01-29 10:23:17 -08:00
parent 0adcdfd97d
commit f4eb532f63
7 changed files with 67 additions and 63 deletions

View File

@ -19,7 +19,8 @@ iface = gr.Interface(calculator,
[4, "divide", 2],
[-4, "multiply", 2.5],
[0, "subtract", 1.2],
]
],
allow_flagging="auto"
)
if __name__ == "__main__":

View File

@ -37,4 +37,4 @@ iface = gr.Interface(image_classifier, image, label,
])
if __name__ == "__main__":
iface.launch(share=True)
iface.launch()

View File

@ -1,6 +1,6 @@
Metadata-Version: 1.0
Name: gradio
Version: 1.4.4
Version: 1.5.0
Summary: Python library for easily interacting with trained machine learning models
Home-page: https://github.com/gradio-app/gradio-UI
Author: Abubakar Abid

View File

@ -140,7 +140,7 @@ class Interface:
self.server_port = server_port
self.simple_server = None
self.allow_screenshot = allow_screenshot
self.allow_flagging = allow_flagging
self.allow_flagging = os.getenv("GRADIO_FLAGGING") or allow_flagging
self.flagging_dir = flagging_dir
Interface.instances.add(self)
self.analytics_enabled=analytics_enabled
@ -384,8 +384,7 @@ class Interface:
# Set up local flask server
config = self.get_config_file()
networking.set_config(config)
networking.set_meta_tags(self.title, self.description, self.thumbnail)
self.config = config
self.auth = auth
# Launch local flask server

View File

@ -5,7 +5,7 @@ Defines helper methods useful for setting up ports, launching servers, and handl
import os
import socket
import threading
from flask import Flask, request, jsonify, abort, send_file, render_template
from flask import Flask, request, jsonify, abort, send_file, render_template, redirect
from flask_cachebuster import CacheBuster
from flask_basicauth import BasicAuth
from flask_cors import CORS
@ -44,24 +44,11 @@ app = Flask(__name__,
CORS(app)
cache_buster = CacheBuster(config={'extensions': ['.js', '.css'], 'hash_size': 5})
cache_buster.init_app(app)
app.app_globals = {}
# Hide Flask default message
cli = sys.modules['flask.cli']
cli.show_server_banner = lambda *x: None
def set_meta_tags(title, description, thumbnail):
app.app_globals.update({
"title": title,
"description": description,
"thumbnail": thumbnail
})
def set_config(config):
app.app_globals["config"] = config
def get_local_ip_address():
try:
ip_address = requests.get('https://api.ipify.org').text
@ -96,24 +83,43 @@ def get_first_available_port(initial, final):
@app.route("/", methods=["GET"])
def main():
return render_template("index.html",
title=app.app_globals["title"],
description=app.app_globals["description"],
thumbnail=app.app_globals["thumbnail"],
config=app.interface.config,
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
css=app.interface.css
css=app.interface.css,
)
@app.route("/from_dir", methods=["GET"])
def main_from_flagging_dir():
return redirect("/from_dir/" + app.interface.flagging_dir)
@app.route("/from_dir/<path:path>", methods=["GET"])
def main_from_dir(path):
log_file = os.path.join(path, "log.csv")
if not os.path.exists(log_file):
abort(404)
with open(log_file) as logs:
examples = list(csv.reader(logs))
examples = examples[1:] #remove header
input_examples = [example[:len(app.interface.input_interfaces)] for example in examples]
return render_template("index.html",
config=app.interface.config,
vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""),
css=app.interface.css,
path=path,
examples=input_examples
)
@app.route("/config/", methods=["GET"])
def config():
return jsonify(app.app_globals["config"])
return jsonify(app.interface.config)
@app.route("/enable_sharing/<path:path>", methods=["GET"])
def enable_sharing(path):
if path == "None":
path = None
app.app_globals["config"]["share_url"] = path
app.interface.config["share_url"] = path
return jsonify(success=True)
@ -122,6 +128,11 @@ def predict():
raw_input = request.json["data"]
prediction, durations = app.interface.process(raw_input)
output = {"data": prediction, "durations": durations}
if app.interface.allow_flagging == "auto":
try:
flag_data(raw_input)
except:
pass
return jsonify(output)
def log_feature_analytics(feature):
@ -206,39 +217,29 @@ def predict_examples():
return jsonify(output)
@app.route("/api/flag/", methods=["POST"])
def flag():
log_feature_analytics('flag')
def flag_data(data):
flag_path = os.path.join(app.cwd, app.interface.flagging_dir)
output = {'inputs': [app.interface.input_interfaces[
i].rebuild(
flag_path, request.json['data']['input_data'][i]) for i
in range(len(app.interface.input_interfaces))],
'outputs': [app.interface.output_interfaces[
i].rebuild(
flag_path, request.json['data']['output_data'][i])
for i
in range(len(app.interface.output_interfaces))]}
output = [app.interface.input_interfaces[i].rebuild(
flag_path, component_data)
for i, component_data in enumerate(data)]
log_fp = "{}/log.csv".format(flag_path)
is_new = not os.path.exists(log_fp)
with open(log_fp, "a") as csvfile:
headers = ["input_{}".format(i) for i in range(len(
output["inputs"]))] + ["output_{}".format(i) for i in
range(len(output["outputs"]))]
writer = csv.DictWriter(csvfile, delimiter=',',
lineterminator='\n',
fieldnames=headers)
writer = csv.writer(csvfile)
if is_new:
writer.writeheader()
headers = [interface[1]["label"] for interface in app.interface.config["input_interfaces"]]
writer.writerow(headers)
writer.writerow(
dict(zip(headers, output["inputs"] +
output["outputs"]))
)
return jsonify(success=True)
writer.writerow(output)
@app.route("/api/flag/", methods=["POST"])
def flag():
log_feature_analytics('flag')
data = request.json['data']['input_data']
flag_data(data)
return jsonify(success=True)
@app.route("/api/interpret/", methods=["POST"])
@ -263,9 +264,9 @@ def start_server(interface, server_name, server_port=None, auth=None):
server_port, server_port + TRY_NUM_PORTS
)
if auth is not None:
app.config['BASIC_AUTH_USERNAME'] = auth[0]
app.config['BASIC_AUTH_PASSWORD'] = auth[1]
app.config['BASIC_AUTH_FORCE'] = True
app.interface.config['BASIC_AUTH_USERNAME'] = auth[0]
app.interface.config['BASIC_AUTH_PASSWORD'] = auth[1]
app.interface.config['BASIC_AUTH_FORCE'] = True
basic_auth = BasicAuth(app)
app.interface = interface
app.cwd = os.getcwd()

View File

@ -205,13 +205,13 @@ function gradio(config, fn, target, example_file_path) {
if (!config["allow_embedding"]) {
target.find(".embedding").css("visibility", "hidden");
}
if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) {
if (!config["allow_screenshot"] && config["allow_flagging"] !== true && !config["allow_interpretation"]) {
target.find(".screenshot, .record, .flag, .interpret").css("visibility", "hidden");
} else {
if (!config["allow_screenshot"]) {
target.find(".screenshot, .record").hide();
}
if (!config["allow_flagging"]) {
if (config["allow_flagging"] !== true) {
target.find(".flag").hide();
}
if (!config["allow_interpretation"]) {

View File

@ -19,14 +19,14 @@
<meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
<meta property="og:url" content="https://gradio.app/" />
<meta property="og:type" content="website" />
<meta property="og:image" content="{{thumbnail}}" />
<meta property="og:title" content="{{title}}" />
<meta property="og:description" content="{{description}}" />
<meta property="og:image" content="{{ config['thumbnail'] or '' }}" />
<meta property="og:title" content="{{ config['title'] or '' }}" />
<meta property="og:description" content="{{ config['description'] or '' }}" />
<meta name="twitter:card" content="summary_large_image">
<meta name="twitter:creator" content="@teamGradio">
<meta name="twitter:title" content="{{title}}">
<meta name="twitter:description" content="{{description}}">
<meta name="twitter:image" content="{{thumbnail}}">
<meta name="twitter:title" content="{{ config['title'] or '' }}">
<meta name="twitter:description" content="{{ config['description'] or '' }}">
<meta name="twitter:image" content="{{ config['thumbnail'] or '' }}">
<title>Gradio</title>
<link href="https://fonts.googleapis.com/css?family=Open+Sans" rel="stylesheet">
@ -127,7 +127,10 @@
<script src="{{ url_for('static', filename='js/gradio.js') }}"></script>
<script>
$.getJSON("/config/", function(config) {
io = gradio_url(config, "/api/", "#interface_target", "/file/");
{% if examples %}
config["examples"] = {{ examples|tojson }}
{% endif %}
io = gradio_url(config, "/api/", "#interface_target", "/file/{% if path %}{{ path }}/{% endif %}");
});
const copyToClipboard = str => {
const el = document.createElement('textarea');