From f4eb532f63aa380d7abf8587ff1409a034c8b986 Mon Sep 17 00:00:00 2001 From: Ali Abid Date: Fri, 29 Jan 2021 10:23:17 -0800 Subject: [PATCH] auto flagging --- demo/calculator.py | 3 +- demo/image_classifier.py | 2 +- gradio.egg-info/PKG-INFO | 2 +- gradio/interface.py | 5 +- gradio/networking.py | 97 +++++++++++++++++++------------------ gradio/static/js/gradio.js | 4 +- gradio/templates/index.html | 17 ++++--- 7 files changed, 67 insertions(+), 63 deletions(-) diff --git a/demo/calculator.py b/demo/calculator.py index b34bb8fb80..a03fc02f11 100644 --- a/demo/calculator.py +++ b/demo/calculator.py @@ -19,7 +19,8 @@ iface = gr.Interface(calculator, [4, "divide", 2], [-4, "multiply", 2.5], [0, "subtract", 1.2], - ] + ], + allow_flagging="auto" ) if __name__ == "__main__": diff --git a/demo/image_classifier.py b/demo/image_classifier.py index 3eb1b549ec..a8453cb141 100644 --- a/demo/image_classifier.py +++ b/demo/image_classifier.py @@ -37,4 +37,4 @@ iface = gr.Interface(image_classifier, image, label, ]) if __name__ == "__main__": - iface.launch(share=True) + iface.launch() diff --git a/gradio.egg-info/PKG-INFO b/gradio.egg-info/PKG-INFO index b9f1670c19..a15b7c8689 100644 --- a/gradio.egg-info/PKG-INFO +++ b/gradio.egg-info/PKG-INFO @@ -1,6 +1,6 @@ Metadata-Version: 1.0 Name: gradio -Version: 1.4.4 +Version: 1.5.0 Summary: Python library for easily interacting with trained machine learning models Home-page: https://github.com/gradio-app/gradio-UI Author: Abubakar Abid diff --git a/gradio/interface.py b/gradio/interface.py index 5996d16b1c..ebdb75818c 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -140,7 +140,7 @@ class Interface: self.server_port = server_port self.simple_server = None self.allow_screenshot = allow_screenshot - self.allow_flagging = allow_flagging + self.allow_flagging = os.getenv("GRADIO_FLAGGING") or allow_flagging self.flagging_dir = flagging_dir Interface.instances.add(self) self.analytics_enabled=analytics_enabled @@ -384,8 +384,7 @@ class Interface: # Set up local flask server config = self.get_config_file() - networking.set_config(config) - networking.set_meta_tags(self.title, self.description, self.thumbnail) + self.config = config self.auth = auth # Launch local flask server diff --git a/gradio/networking.py b/gradio/networking.py index d73ca8bc3c..baa0c1e728 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -5,7 +5,7 @@ Defines helper methods useful for setting up ports, launching servers, and handl import os import socket import threading -from flask import Flask, request, jsonify, abort, send_file, render_template +from flask import Flask, request, jsonify, abort, send_file, render_template, redirect from flask_cachebuster import CacheBuster from flask_basicauth import BasicAuth from flask_cors import CORS @@ -44,24 +44,11 @@ app = Flask(__name__, CORS(app) cache_buster = CacheBuster(config={'extensions': ['.js', '.css'], 'hash_size': 5}) cache_buster.init_app(app) -app.app_globals = {} # Hide Flask default message cli = sys.modules['flask.cli'] cli.show_server_banner = lambda *x: None -def set_meta_tags(title, description, thumbnail): - app.app_globals.update({ - "title": title, - "description": description, - "thumbnail": thumbnail - }) - - -def set_config(config): - app.app_globals["config"] = config - - def get_local_ip_address(): try: ip_address = requests.get('https://api.ipify.org').text @@ -96,24 +83,43 @@ def get_first_available_port(initial, final): @app.route("/", methods=["GET"]) def main(): return render_template("index.html", - title=app.app_globals["title"], - description=app.app_globals["description"], - thumbnail=app.app_globals["thumbnail"], + config=app.interface.config, vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""), - css=app.interface.css + css=app.interface.css, + ) + +@app.route("/from_dir", methods=["GET"]) +def main_from_flagging_dir(): + return redirect("/from_dir/" + app.interface.flagging_dir) + +@app.route("/from_dir/", methods=["GET"]) +def main_from_dir(path): + log_file = os.path.join(path, "log.csv") + if not os.path.exists(log_file): + abort(404) + with open(log_file) as logs: + examples = list(csv.reader(logs)) + examples = examples[1:] #remove header + input_examples = [example[:len(app.interface.input_interfaces)] for example in examples] + return render_template("index.html", + config=app.interface.config, + vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""), + css=app.interface.css, + path=path, + examples=input_examples ) @app.route("/config/", methods=["GET"]) def config(): - return jsonify(app.app_globals["config"]) + return jsonify(app.interface.config) @app.route("/enable_sharing/", methods=["GET"]) def enable_sharing(path): if path == "None": path = None - app.app_globals["config"]["share_url"] = path + app.interface.config["share_url"] = path return jsonify(success=True) @@ -122,6 +128,11 @@ def predict(): raw_input = request.json["data"] prediction, durations = app.interface.process(raw_input) output = {"data": prediction, "durations": durations} + if app.interface.allow_flagging == "auto": + try: + flag_data(raw_input) + except: + pass return jsonify(output) def log_feature_analytics(feature): @@ -206,39 +217,29 @@ def predict_examples(): return jsonify(output) -@app.route("/api/flag/", methods=["POST"]) -def flag(): - log_feature_analytics('flag') +def flag_data(data): flag_path = os.path.join(app.cwd, app.interface.flagging_dir) - output = {'inputs': [app.interface.input_interfaces[ - i].rebuild( - flag_path, request.json['data']['input_data'][i]) for i - in range(len(app.interface.input_interfaces))], - 'outputs': [app.interface.output_interfaces[ - i].rebuild( - flag_path, request.json['data']['output_data'][i]) - for i - in range(len(app.interface.output_interfaces))]} + output = [app.interface.input_interfaces[i].rebuild( + flag_path, component_data) + for i, component_data in enumerate(data)] log_fp = "{}/log.csv".format(flag_path) - is_new = not os.path.exists(log_fp) with open(log_fp, "a") as csvfile: - headers = ["input_{}".format(i) for i in range(len( - output["inputs"]))] + ["output_{}".format(i) for i in - range(len(output["outputs"]))] - writer = csv.DictWriter(csvfile, delimiter=',', - lineterminator='\n', - fieldnames=headers) + writer = csv.writer(csvfile) if is_new: - writer.writeheader() + headers = [interface[1]["label"] for interface in app.interface.config["input_interfaces"]] + writer.writerow(headers) - writer.writerow( - dict(zip(headers, output["inputs"] + - output["outputs"])) - ) - return jsonify(success=True) + writer.writerow(output) + +@app.route("/api/flag/", methods=["POST"]) +def flag(): + log_feature_analytics('flag') + data = request.json['data']['input_data'] + flag_data(data) + return jsonify(success=True) @app.route("/api/interpret/", methods=["POST"]) @@ -263,9 +264,9 @@ def start_server(interface, server_name, server_port=None, auth=None): server_port, server_port + TRY_NUM_PORTS ) if auth is not None: - app.config['BASIC_AUTH_USERNAME'] = auth[0] - app.config['BASIC_AUTH_PASSWORD'] = auth[1] - app.config['BASIC_AUTH_FORCE'] = True + app.interface.config['BASIC_AUTH_USERNAME'] = auth[0] + app.interface.config['BASIC_AUTH_PASSWORD'] = auth[1] + app.interface.config['BASIC_AUTH_FORCE'] = True basic_auth = BasicAuth(app) app.interface = interface app.cwd = os.getcwd() diff --git a/gradio/static/js/gradio.js b/gradio/static/js/gradio.js index cb9a54f5f9..1a63283aef 100644 --- a/gradio/static/js/gradio.js +++ b/gradio/static/js/gradio.js @@ -205,13 +205,13 @@ function gradio(config, fn, target, example_file_path) { if (!config["allow_embedding"]) { target.find(".embedding").css("visibility", "hidden"); } - if (!config["allow_screenshot"] && !config["allow_flagging"] && !config["allow_interpretation"]) { + if (!config["allow_screenshot"] && config["allow_flagging"] !== true && !config["allow_interpretation"]) { target.find(".screenshot, .record, .flag, .interpret").css("visibility", "hidden"); } else { if (!config["allow_screenshot"]) { target.find(".screenshot, .record").hide(); } - if (!config["allow_flagging"]) { + if (config["allow_flagging"] !== true) { target.find(".flag").hide(); } if (!config["allow_interpretation"]) { diff --git a/gradio/templates/index.html b/gradio/templates/index.html index f5c6f8d07e..caecf48745 100644 --- a/gradio/templates/index.html +++ b/gradio/templates/index.html @@ -19,14 +19,14 @@ - - - + + + - - - + + + Gradio @@ -127,7 +127,10 @@