diff --git a/gradio/component.py b/gradio/component.py index 74b674426e..5722664f9d 100644 --- a/gradio/component.py +++ b/gradio/component.py @@ -1,3 +1,7 @@ +import os +import shutil +from gradio import processing_utils + class Component(): """ A class for defining the methods that all gradio input and output components should have. @@ -19,12 +23,34 @@ class Component(): """ return {} - def rebuild(self, dir, data): + def save_flagged(self, dir, label, data): """ - All interfaces should define a method that rebuilds the flagged input when it's passed back (i.e. rebuilds image from base64) + Saves flagged data from component """ return data + def restore_flagged(self, data): + """ + Restores flagged data from logs + """ + return data + + def save_flagged_file(self, dir, label, data): + file = processing_utils.decode_base64_to_file(data) + old_file_name = file.name + output_dir = os.path.join(dir, label) + if os.path.exists(output_dir): + file_index = len(os.listdir(output_dir)) + else: + os.mkdir(output_dir) + file_index = 0 + new_file_name = str(file_index) + if "." in old_file_name: + uploaded_format = old_file_name.split(".")[-1].lower() + new_file_name += "." + uploaded_format + shutil.move(old_file_name, os.path.join(dir, label, new_file_name)) + return label + "/" + new_file_name + @classmethod def get_all_shortcut_implementations(cls): shortcuts = {} diff --git a/gradio/inputs.py b/gradio/inputs.py index daa269d59b..77283519a6 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -7,6 +7,7 @@ automatically added to a registry, which allows them to be easily referenced in import datetime import json import os +import shutil import time import warnings from gradio.component import Component @@ -466,6 +467,14 @@ class CheckboxGroup(InputComponent): else: raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'value', 'index'.") + def save_flagged(self, dir, label, data): + """ + Returns: (List[str]]) + """ + return json.dumps(data) + + def restore_flagged(self, data): + return json.loads(data) class Radio(InputComponent): @@ -713,15 +722,11 @@ class Image(InputComponent): im = processing_utils.resize_and_crop(im, (shape[0], shape[1])) return np.asarray(im).flatten() - def rebuild(self, dir, data): + def save_flagged(self, dir, label, data): """ - Default rebuild method to decode a base64 image + Returns: (str) path to image file """ - im = processing_utils.decode_base64_to_image(data) - timestamp = datetime.datetime.now() - filename = f'input_{timestamp.strftime("%Y-%m-%d-%H-%M-%S")}.png' - im.save(f'{dir}/{filename}', 'PNG') - return filename + return self.save_flagged_file(dir, label, data) class Video(InputComponent): @@ -766,6 +771,12 @@ class Video(InputComponent): def preprocess_example(self, x): return processing_utils.encode_file_to_base64(x) + def save_flagged(self, dir, label, data): + """ + Returns: (str) path to video file + """ + return self.save_flagged_file(dir, label, data) + class Audio(InputComponent): """ Component accepts audio input files. @@ -865,6 +876,12 @@ class Audio(InputComponent): else: raise ValueError("Unknown type: " + str(self.type) + ". Please choose from: 'numpy', 'mfcc', 'file'.") + def save_flagged(self, dir, label, data): + """ + Returns: (str) path to audio file + """ + return self.save_flagged_file(dir, label, data) + class File(InputComponent): """ @@ -906,6 +923,12 @@ class File(InputComponent): def embed(self, x): raise NotImplementedError("File doesn't currently support embeddings") + def save_flagged(self, dir, label, data): + """ + Returns: (str) path to file + """ + return self.save_flagged_file(dir, label, data["data"]) + class Dataframe(InputComponent): """ @@ -1000,6 +1023,15 @@ class Dataframe(InputComponent): def embed(self, x): raise NotImplementedError("DataFrame doesn't currently support embeddings") + def save_flagged(self, dir, label, data): + """ + Returns: (List[List[Union[str, float]]]) 2D array + """ + return json.dumps(data) + + def restore_flagged(self, data): + return json.loads(data) + ####################### # DEPRECATED COMPONENTS @@ -1050,7 +1082,7 @@ class Sketchpad(InputComponent): def process_example(self, example): return processing_utils.encode_file_to_base64(example) - def rebuild(self, dir, data): + def save_flagged(self, dir, label, data): """ Default rebuild method to decode a base64 image """ @@ -1089,7 +1121,7 @@ class Webcam(InputComponent): im, (self.image_width, self.image_height)) return np.array(im) - def rebuild(self, dir, data): + def save_flagged(self, dir, label, data): """ Default rebuild method to decode a base64 image """ @@ -1131,7 +1163,7 @@ class Microphone(InputComponent): return signal - def rebuild(self, dir, data): + def save_flagged(self, dir, label, data): inp = data.split(';')[1].split(',')[1] wav_obj = base64.b64decode(inp) timestamp = datetime.datetime.now() diff --git a/gradio/interface.py b/gradio/interface.py index bf69086ed0..1f35b249e5 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -222,14 +222,8 @@ class Interface: except ValueError: pass - if self.examples is not None: - processed_examples = [] - for example_set in self.examples: - processed_set = [] - for iface, example in zip(self.input_interfaces, example_set): - processed_set.append(example) - processed_examples.append(processed_set) - config["examples"] = processed_examples + if self.examples is not None and not isinstance(self.examples, str): + config["examples"] = self.examples return config def run_prediction(self, processed_input, return_duration=False): diff --git a/gradio/networking.py b/gradio/networking.py index 55918b59ba..e8895b455f 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -82,6 +82,8 @@ def get_first_available_port(initial, final): @app.route("/", methods=["GET"]) def main(): + if isinstance(app.interface.examples, str): + return redirect("/from_dir/" + app.interface.examples) return render_template("index.html", config=app.interface.config, vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""), @@ -100,13 +102,16 @@ def main_from_dir(path): with open(log_file) as logs: examples = list(csv.reader(logs)) examples = examples[1:] #remove header - input_examples = [example[:len(app.interface.input_interfaces)] for example in examples] + for i, example in enumerate(examples): + for j, (interface, cell) in enumerate(zip(app.interface.input_interfaces + app.interface.output_interfaces, example)): + examples[i][j] = interface.restore_flagged(cell) + examples = [example[:len(app.interface.input_interfaces) + len(app.interface.output_interfaces)] for example in examples] return render_template("index.html", config=app.interface.config, vendor_prefix=(GRADIO_STATIC_ROOT if app.interface.share else ""), css=app.interface.css, path=path, - examples=input_examples + examples=examples ) @@ -217,11 +222,13 @@ def predict_examples(): return jsonify(output) -def flag_data(data): +def flag_data(input_data, output_data): flag_path = os.path.join(app.cwd, app.interface.flagging_dir) - output = [app.interface.input_interfaces[i].rebuild( - flag_path, component_data) - for i, component_data in enumerate(data)] + csv_data = [] + for i, interface in enumerate(app.interface.input_interfaces): + csv_data.append(interface.save_flagged(flag_path, app.interface.config["input_interfaces"][i][1]["label"], input_data[i])) + for i, interface in enumerate(app.interface.output_interfaces): + csv_data.append(interface.save_flagged(flag_path, app.interface.config["output_interfaces"][i][1]["label"], output_data[i])) log_fp = "{}/log.csv".format(flag_path) is_new = not os.path.exists(log_fp) @@ -230,15 +237,16 @@ def flag_data(data): writer = csv.writer(csvfile) if is_new: headers = [interface[1]["label"] for interface in app.interface.config["input_interfaces"]] + headers += [interface[1]["label"] for interface in app.interface.config["output_interfaces"]] writer.writerow(headers) - writer.writerow(output) + writer.writerow(csv_data) @app.route("/api/flag/", methods=["POST"]) def flag(): log_feature_analytics('flag') - data = request.json['data']['input_data'] - flag_data(data) + input_data, output_data = request.json['data']['input_data'], request.json['data']['output_data'] + flag_data(input_data, output_data) return jsonify(success=True) diff --git a/gradio/outputs.py b/gradio/outputs.py index ef9367df26..16e876998a 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -118,12 +118,21 @@ class Label(OutputComponent): "label": {}, } - def rebuild(self, dir, data): + def save_flagged(self, dir, label, data): """ - Default rebuild method for label + Returns: (Union[str, Dict[str, number]]): Either a string representing the main category label, or a dictionary with category keys mapping to confidence levels. """ - # return json.loads(data) - return data + if "confidences" in data: + return json.dumps({example["label"]: example["confidence"] for example in data["confidences"]}) + else: + return data["label"] + + def restore_flagged(self, data): + try: + data = json.loads(data) + return data + except: + return data class Image(OutputComponent): ''' @@ -186,15 +195,12 @@ class Image(OutputComponent): raise ValueError("Unknown type: " + dtype + ". Please choose from: 'numpy', 'pil', 'file', 'plot'.") return out_y, coordinates - def rebuild(self, dir, data): + def save_flagged(self, dir, label, data): """ - Default rebuild method to decode a base64 image + Returns: (str) path to image file """ - im = processing_utils.decode_base64_to_image(data) - timestamp = datetime.datetime.now() - filename = 'output_{}_{}.png'.format(self.label, timestamp.strftime("%Y-%m-%d-%H-%M-%S")) - im.save('{}/{}'.format(dir, filename), 'PNG') - return filename + return self.save_flagged_file(dir, label, data[0]) + class Video(OutputComponent): ''' @@ -218,6 +224,12 @@ class Video(OutputComponent): def postprocess(self, y): return processing_utils.encode_file_to_base64(y, type="video") + def save_flagged(self, dir, label, data): + """ + Returns: (str) path to image file + """ + return self.save_flagged_file(dir, label, data) + class KeyValues(OutputComponent): ''' @@ -246,6 +258,12 @@ class KeyValues(OutputComponent): return { "key_values": {}, } + + def save_flagged(self, dir, label, data): + return json.dumps(data) + + def restore_flagged(self, data): + return json.loads(data) class HighlightedText(OutputComponent): @@ -279,6 +297,12 @@ class HighlightedText(OutputComponent): def postprocess(self, y): return y + def save_flagged(self, dir, label, data): + return json.dumps(data) + + def restore_flagged(self, data): + return json.loads(data) + class Audio(OutputComponent): ''' @@ -316,6 +340,12 @@ class Audio(OutputComponent): else: raise ValueError("Unknown type: " + self.type + ". Please choose from: 'numpy', 'file'.") + def save_flagged(self, dir, label, data): + """ + Returns: (str) path to audio file + """ + return self.save_flagged_file(dir, label, data) + class JSON(OutputComponent): ''' @@ -343,6 +373,12 @@ class JSON(OutputComponent): "json": {}, } + def save_flagged(self, dir, label, data): + return json.dumps(data) + + def restore_flagged(self, data): + return json.loads(data) + class HTML(OutputComponent): ''' @@ -392,6 +428,12 @@ class File(OutputComponent): "data": processing_utils.encode_file_to_base64(y, header=False) } + def save_flagged(self, dir, label, data): + """ + Returns: (str) path to image file + """ + return self.save_flagged_file(dir, label, data["data"]) + class Dataframe(OutputComponent): """ @@ -446,3 +488,12 @@ class Dataframe(OutputComponent): return {"data": y} else: raise ValueError("Unknown type: " + self.type + ". Please choose from: 'pandas', 'numpy', 'array'.") + + def save_flagged(self, dir, label, data): + """ + Returns: (List[List[Union[str, float]]]) 2D array + """ + return json.dumps(data["data"]) + + def restore_flagged(self, data): + return json.loads(data) \ No newline at end of file diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 1f053abcc3..08a9d7441f 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -68,13 +68,16 @@ def resize_and_crop(img, size, crop_type='center'): ################## def decode_base64_to_binary(encoding): - header, data = encoding.split(",") - header = header[5:] - if ";base64" in header: - header = header[0:header.index(";base64")] extension = None - if "/" in header: - extension = header[header.index("/") + 1:] + if "," in encoding: + header, data = encoding.split(",") + header = header[5:] + if ";base64" in header: + header = header[0:header.index(";base64")] + if "/" in header: + extension = header[header.index("/") + 1:] + else: + data = encoding return base64.b64decode(data), extension def decode_base64_to_file(encoding): diff --git a/gradio/static/css/style.css b/gradio/static/css/style.css index de7de86652..82a5709346 100644 --- a/gradio/static/css/style.css +++ b/gradio/static/css/style.css @@ -74,15 +74,40 @@ h4 { .close_explain { cursor: pointer; } -.examples > button { +.backward { + display: inline-block; + -moz-transform: scale(-1, 1); + -webkit-transform: scale(-1, 1); + transform: scale(-1, 1); +} +.examples_control button { padding: 8px 16px; border-radius: 2px; - margin-right: 4px; + margin-right: 8px; background-color: whitesmoke; } +.examples_control { + display: flex; +} +.examples_control > div { + display: flex; + align-items: stretch; +} + +.examples_control button small { + display: block; + font-weight: bold; +} +.examples_control_right { + padding-left: 8px; + border-left: solid 2px whitesmoke; +} +.examples_control_right .current { + background-color: #e67e22; + color: white; +} .examples > table { border-collapse: collapse; - font-family: monospace; padding: 8px; background-color: whitesmoke; border-right: solid 4px whitesmoke; @@ -90,6 +115,20 @@ h4 { border-bottom: solid 4px whitesmoke; margin-top: 8px; } +.examples > table.gallery { + background-color: white; + border: none; +} +.examples > table.gallery > thead { + display: none; +} +.examples > table.gallery > tbody > tr { + padding: 4px; + border-radius: 4px; + margin: 0 8px 8px 0; + background-color: whitesmoke; + display: inline-block; +} .examples > table th { padding: 8px 16px; text-align: left; @@ -118,7 +157,7 @@ h4 { background-color: lightgray; } .examples_body > tr.current_example { - background-color: #ffb573; + background-color: lightgray !important; } #credit { text-align: center; diff --git a/gradio/static/js/gradio.js b/gradio/static/js/gradio.js index 4f179cb5e5..e35c8668b7 100644 --- a/gradio/static/js/gradio.js +++ b/gradio/static/js/gradio.js @@ -48,13 +48,25 @@ function gradio(config, fn, target, example_file_path) {