From fa3cb474e9feb5f7be2a51ae56325c5c2c663394 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Tue, 19 Oct 2021 03:04:17 -0500 Subject: [PATCH] implemented deserialization --- gradio/external.py | 7 ++++--- gradio/inputs.py | 15 ++++++++------- gradio/interface.py | 14 +++++++------- gradio/outputs.py | 6 +++--- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/gradio/external.py b/gradio/external.py index eb5a8633d6..067fe2e695 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -119,14 +119,14 @@ def get_huggingface_interface(model_name, api_key, alias): 'inputs': inputs.Textbox(label="Input"), 'outputs': outputs.Audio(label="Audio"), 'preprocess': lambda x: {"inputs": x}, - 'postprocess': lambda x: base64.b64encode(x), + 'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'), }, 'text-to-image': { # example model: hf.co/osanseviero/BigGAN-deep-128 'inputs': inputs.Textbox(label="Input"), 'outputs': outputs.Image(label="Output"), 'preprocess': lambda x: {"inputs": x}, - 'postprocess': lambda x: base64.b64encode(x), + 'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'), }, } @@ -201,13 +201,14 @@ def get_spaces_interface(model_name, api_key, alias): output = result["data"] if len(interface_info["outputs"])==1: # if the fn is supposed to return a single value, pop it output = output[0] + if len(interface_info["outputs"])==1 and isinstance(output, list): # not sure why this is needed but it fixes the bug + output = output[0] return output fn.__name__ = alias if alias else model_name interface_info["fn"] = fn interface_info["api_mode"] = True - print(interface_info) return interface_info repos = { diff --git a/gradio/inputs.py b/gradio/inputs.py index a871cdc819..3e050e30fe 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -32,9 +32,9 @@ class InputComponent(Component): """ return x - def serialize(self, x): + def serialize(self, x, called_directly): """ - Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API + Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI. """ return x @@ -674,8 +674,9 @@ class Image(InputComponent): def preprocess_example(self, x): return processing_utils.encode_file_to_base64(x) - def serialize(self, x): - if self.type == "filepath": + def serialize(self, x, called_directly=False): + # if called directly, can assume it's a URL or filepath + if self.type == "filepath" or called_directly: return processing_utils.encode_url_or_file_to_base64(x) elif self.type == "file": return processing_utils.encode_url_or_file_to_base64(x.name) @@ -839,7 +840,7 @@ class Video(InputComponent): else: return file_name - def serialize(self, x): + def serialize(self, x, called_directly): raise NotImplementedError() def preprocess_example(self, x): @@ -913,8 +914,8 @@ class Audio(InputComponent): def preprocess_example(self, x): return processing_utils.encode_file_to_base64(x, type="audio") - def serialize(self, x): - if self.type == "filepath": + def serialize(self, x, called_directly): + if self.type == "filepath" or called_directly: name = x elif self.type == "file": name = x.name diff --git a/gradio/interface.py b/gradio/interface.py index eb8b1eda92..6fd85d6df5 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -219,7 +219,7 @@ class Interface: def __call__(self, *params): if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API - output = self.run_prediction(params) + output = self.run_prediction(params, called_directly=True) else: output, _ = self.process(params) return output[0] if len(output) == 1 else output @@ -311,9 +311,9 @@ class Interface: config["examples"] = self.examples return config - def run_prediction(self, processed_input, return_duration=False): + def run_prediction(self, processed_input, return_duration=False, called_directly=False): if self.api_mode: # Serialize the input - processed_input = [input_component.serialize(processed_input[i]) + processed_input = [input_component.serialize(processed_input[i], called_directly) for i, input_component in enumerate(self.input_components)] predictions = [] durations = [] @@ -336,13 +336,13 @@ class Interface: if len(self.output_components) == len(self.predict): prediction = [prediction] + if self.api_mode: # Serialize the input + prediction = [output_component.deserialize(prediction[o]) + for o, output_component in enumerate(self.output_components)] + durations.append(duration) predictions.extend(prediction) - if self.api_mode: # Serialize the input - predictions = [output_component.deserialize(predictions[o]) - for o, output_component in enumerate(self.output_components)] - if return_duration: return predictions, durations else: diff --git a/gradio/outputs.py b/gradio/outputs.py index 7b5ab1eda3..089b786e2b 100644 --- a/gradio/outputs.py +++ b/gradio/outputs.py @@ -210,7 +210,7 @@ class Image(OutputComponent): return out_y, coordinates def deserialize(self, x): - raise processing_utils.decode_base64_to_file(x).name + return processing_utils.decode_base64_to_file(x).name def save_flagged(self, dir, label, data, encryption_key): """ @@ -259,7 +259,7 @@ class Video(OutputComponent): } def deserialize(self, x): - raise processing_utils.decode_base64_to_file(x).name + return processing_utils.decode_base64_to_file(x).name def save_flagged(self, dir, label, data, encryption_key): """ @@ -383,7 +383,7 @@ class Audio(OutputComponent): ". Please choose from: 'numpy', 'file'.") def deserialize(self, x): - raise processing_utils.decode_base64_to_file(x).name + return processing_utils.decode_base64_to_file(x).name def save_flagged(self, dir, label, data, encryption_key): """