diff --git a/gradio/inputs.py b/gradio/inputs.py index a130642f61..62fac86a19 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -63,6 +63,11 @@ class AbstractInput(ABC): """ return {} + def rebuild_flagged(self, dir, msg): + """ + All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64) + """ + pass class Textbox(AbstractInput): """ @@ -290,6 +295,16 @@ class Image(AbstractInput): else: return example + def rebuild_flagged(self, dir, msg): + """ + Default rebuild method to decode a base64 image + """ + im = preprocessing_utils.decode_base64_to_image(msg) + timestamp = datetime.datetime.now() + filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' + im.save(f'{dir}/{filename}', 'PNG') + return filename + class Sketchpad(AbstractInput): """ diff --git a/gradio/networking.py b/gradio/networking.py index ecae3abad3..3a36a4ced1 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -36,7 +36,7 @@ CONFIG_FILE = "static/config.json" ASSOCIATION_PATH_IN_STATIC = "static/apple-app-site-association" ASSOCIATION_PATH_IN_ROOT = "apple-app-site-association" -FLAGGING_DIRECTORY = 'static/flagged/' +FLAGGING_DIRECTORY = 'flagged/' FLAGGING_FILENAME = 'data.txt' analytics.write_key = "uxIFddIEuuUcFLf9VgH2teTEtPlWdkNy" analytics_url = 'https://api.gradio.app/' @@ -175,16 +175,6 @@ def serve_files_in_background(interface, port, directory_to_serve=None, server_n if interface.saliency is not None: saliency = interface.saliency(raw_input, prediction) output['saliency'] = saliency.tolist() - # if interface.always_flag: - # msg = json.loads(data_string) - # flag_dir = os.path.join(FLAGGING_DIRECTORY, str(interface.hash)) - # os.makedirs(flag_dir, exist_ok=True) - # output_flag = {'input': interface.input_interface.rebuild_flagged(flag_dir, msg['data']), - # 'output': interface.output_interface.rebuild_flagged(flag_dir, processed_output), - # } - # with open(os.path.join(flag_dir, FLAGGING_FILENAME), 'a+') as f: - # f.write(json.dumps(output_flag)) - # f.write("\n") self.wfile.write(json.dumps(output).encode()) diff --git a/gradio/outputs.py b/gradio/outputs.py index df68483ba2..553b93c7b3 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -44,6 +44,12 @@ class AbstractOutput(ABC): """ return {} + def rebuild_flagged(self, dir, msg): + """ + All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64) + """ + pass + class Textbox(AbstractOutput): ''' @@ -130,6 +136,11 @@ class Label(AbstractOutput): "label": {}, } + def rebuild_flagged(self, dir, msg): + """ + Default rebuild method for label + """ + return json.loads(msg) class Image(AbstractOutput): ''' diff --git a/gradio/static/css/gradio.css b/gradio/static/css/gradio.css index be4b7332e9..06293cc3d1 100644 --- a/gradio/static/css/gradio.css +++ b/gradio/static/css/gradio.css @@ -75,6 +75,10 @@ input.submit { input.submit:hover { background-color: #f39c12; } + +.flag.flagged { + background-color: pink; +} /* label:hover { background-color: lightgray; } */ diff --git a/gradio/static/js/all_io.js b/gradio/static/js/all_io.js index 7a7ca9eae7..52ebe0ecc7 100644 --- a/gradio/static/js/all_io.js +++ b/gradio/static/js/all_io.js @@ -56,12 +56,12 @@ var io_master_template = { this.target.find(".output_interfaces").css("opacity", 1); } }, - flag: function(message) { + flag: function() { var post_data = { 'data': { 'input_data' : toStringIfObject(this.last_input) , 'output_data' : toStringIfObject(this.last_output), - 'message' : message + 'message' : "no-message" } } $.ajax({type: "POST", diff --git a/gradio/static/js/gradio.js b/gradio/static/js/gradio.js index e21cfeb93f..eb4061343c 100644 --- a/gradio/static/js/gradio.js +++ b/gradio/static/js/gradio.js @@ -22,7 +22,7 @@ function gradio(config, fn, target) { - + `); let io_master = Object.create(io_master_template); @@ -117,6 +117,7 @@ function gradio(config, fn, target) { output_interface.clear(); } target.find(".flag").removeClass("flagged"); + target.find(".flag").val("FLAG"); target.find(".flag_message").empty(); target.find(".loading").addClass("invisible"); target.find(".loading_time").text(""); @@ -146,11 +147,22 @@ function gradio(config, fn, target) { target.find(".submit").click(function() { io_master.gather(); target.find(".flag").removeClass("flagged"); + target.find(".flag").val("FLAG"); }) } if (!config.show_input) { target.find(".input_panel").hide(); - } + } + + target.find(".flag").click(function() { + if (io_master.last_output) { + target.find(".flag").addClass("flagged"); + target.find(".flag").val("FLAGGED"); + io_master.flag(); + + // io_master.flag($(".flag_message").val()); + } + }) return io_master; }