implemented deserialization

This commit is contained in:
Abubakar Abid 2021-10-19 03:04:17 -05:00
parent 47b9cd9e1f
commit fa3cb474e9
4 changed files with 22 additions and 20 deletions

View File

@ -119,14 +119,14 @@ def get_huggingface_interface(model_name, api_key, alias):
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Audio(label="Audio"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda x: base64.b64encode(x),
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
},
'text-to-image': {
# example model: hf.co/osanseviero/BigGAN-deep-128
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Image(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda x: base64.b64encode(x),
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
},
}
@ -201,13 +201,14 @@ def get_spaces_interface(model_name, api_key, alias):
output = result["data"]
if len(interface_info["outputs"])==1: # if the fn is supposed to return a single value, pop it
output = output[0]
if len(interface_info["outputs"])==1 and isinstance(output, list): # not sure why this is needed but it fixes the bug
output = output[0]
return output
fn.__name__ = alias if alias else model_name
interface_info["fn"] = fn
interface_info["api_mode"] = True
print(interface_info)
return interface_info
repos = {

View File

@ -32,9 +32,9 @@ class InputComponent(Component):
"""
return x
def serialize(self, x):
def serialize(self, x, called_directly):
"""
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
"""
return x
@ -674,8 +674,9 @@ class Image(InputComponent):
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x)
def serialize(self, x):
if self.type == "filepath":
def serialize(self, x, called_directly=False):
# if called directly, can assume it's a URL or filepath
if self.type == "filepath" or called_directly:
return processing_utils.encode_url_or_file_to_base64(x)
elif self.type == "file":
return processing_utils.encode_url_or_file_to_base64(x.name)
@ -839,7 +840,7 @@ class Video(InputComponent):
else:
return file_name
def serialize(self, x):
def serialize(self, x, called_directly):
raise NotImplementedError()
def preprocess_example(self, x):
@ -913,8 +914,8 @@ class Audio(InputComponent):
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x, type="audio")
def serialize(self, x):
if self.type == "filepath":
def serialize(self, x, called_directly):
if self.type == "filepath" or called_directly:
name = x
elif self.type == "file":
name = x.name

View File

@ -219,7 +219,7 @@ class Interface:
def __call__(self, *params):
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
output = self.run_prediction(params)
output = self.run_prediction(params, called_directly=True)
else:
output, _ = self.process(params)
return output[0] if len(output) == 1 else output
@ -311,9 +311,9 @@ class Interface:
config["examples"] = self.examples
return config
def run_prediction(self, processed_input, return_duration=False):
def run_prediction(self, processed_input, return_duration=False, called_directly=False):
if self.api_mode: # Serialize the input
processed_input = [input_component.serialize(processed_input[i])
processed_input = [input_component.serialize(processed_input[i], called_directly)
for i, input_component in enumerate(self.input_components)]
predictions = []
durations = []
@ -336,13 +336,13 @@ class Interface:
if len(self.output_components) == len(self.predict):
prediction = [prediction]
if self.api_mode: # Serialize the input
prediction = [output_component.deserialize(prediction[o])
for o, output_component in enumerate(self.output_components)]
durations.append(duration)
predictions.extend(prediction)
if self.api_mode: # Serialize the input
predictions = [output_component.deserialize(predictions[o])
for o, output_component in enumerate(self.output_components)]
if return_duration:
return predictions, durations
else:

View File

@ -210,7 +210,7 @@ class Image(OutputComponent):
return out_y, coordinates
def deserialize(self, x):
raise processing_utils.decode_base64_to_file(x).name
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
@ -259,7 +259,7 @@ class Video(OutputComponent):
}
def deserialize(self, x):
raise processing_utils.decode_base64_to_file(x).name
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
@ -383,7 +383,7 @@ class Audio(OutputComponent):
". Please choose from: 'numpy', 'file'.")
def deserialize(self, x):
raise processing_utils.decode_base64_to_file(x).name
return processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""