From 47b9cd9e1f88f1911d555346b7e751d69f1ae4a4 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 19 Oct 2021 01:07:23 -0500 Subject: [PATCH] implementing output deserialization --- gradio/external.py | 12 +++--------- gradio/outputs.py | 9 +++++++++ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/gradio/external.py b/gradio/external.py index 7851bed207..eb5a8633d6 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -18,6 +18,7 @@ def get_huggingface_interface(model_name, api_key, alias): assert response.status_code == 200, "Invalid model name or src" p = response.json().get('pipeline_tag') + # convert from binary to base64 def post_process_binary_body(r: requests.Response): with tempfile.NamedTemporaryFile(delete=False) as fp: fp.write(r.content) @@ -118,14 +119,14 @@ def get_huggingface_interface(model_name, api_key, alias): 'inputs': inputs.Textbox(label="Input"), 'outputs': outputs.Audio(label="Audio"), 'preprocess': lambda x: {"inputs": x}, - 'postprocess': post_process_binary_body, + 'postprocess': lambda x: base64.b64encode(x), }, 'text-to-image': { # example model: hf.co/osanseviero/BigGAN-deep-128 'inputs': inputs.Textbox(label="Input"), 'outputs': outputs.Image(label="Output"), 'preprocess': lambda x: {"inputs": x}, - 'postprocess': post_process_binary_body, + 'postprocess': lambda x: base64.b64encode(x), }, } @@ -175,13 +176,6 @@ def interface_params_from_config(config_dict): ## instantiate input component and output component config_dict["inputs"] = [inputs.get_input_instance(component) for component in config_dict["input_components"]] config_dict["outputs"] = [outputs.get_output_instance(component) for component in config_dict["output_components"]] - print(config_dict["outputs"]) - # # remove preprocessing and postprocessing (since they'll be performed remotely) - # for component in config_dict["inputs"]: - # component.preprocess = lambda x:x - # for component in config_dict["outputs"]: - # component.postprocess = lambda x:x - # Remove keys that are not parameters to Interface() class not_parameters = ("allow_embedding", "allow_interpretation", "avg_durations", "function_count", "queue", "input_components", "output_components", "examples") for key in not_parameters: diff --git a/gradio/outputs.py b/gradio/outputs.py index baccff4969..7b5ab1eda3 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -209,6 +209,9 @@ class Image(OutputComponent): ". Please choose from: 'numpy', 'pil', 'file', 'plot'.") return out_y, coordinates + def deserialize(self, x): + raise processing_utils.decode_base64_to_file(x).name + def save_flagged(self, dir, label, data, encryption_key): """ Returns: (str) path to image file @@ -255,6 +258,9 @@ class Video(OutputComponent): "data": processing_utils.encode_file_to_base64(y, type="video") } + def deserialize(self, x): + raise processing_utils.decode_base64_to_file(x).name + def save_flagged(self, dir, label, data, encryption_key): """ Returns: (str) path to image file @@ -376,6 +382,9 @@ class Audio(OutputComponent): raise ValueError("Unknown type: " + self.type + ". Please choose from: 'numpy', 'file'.") + def deserialize(self, x): + raise processing_utils.decode_base64_to_file(x).name + def save_flagged(self, dir, label, data, encryption_key): """ Returns: (str) path to audio file