implementing output deserialization

This commit is contained in:
Abubakar Abid 2021-10-19 01:07:23 -05:00
parent 3e9e861184
commit 47b9cd9e1f
2 changed files with 12 additions and 9 deletions

View File

@ -18,6 +18,7 @@ def get_huggingface_interface(model_name, api_key, alias):
assert response.status_code == 200, "Invalid model name or src"
p = response.json().get('pipeline_tag')
# convert from binary to base64
def post_process_binary_body(r: requests.Response):
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(r.content)
@ -118,14 +119,14 @@ def get_huggingface_interface(model_name, api_key, alias):
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Audio(label="Audio"),
'preprocess': lambda x: {"inputs": x},
'postprocess': post_process_binary_body,
'postprocess': lambda x: base64.b64encode(x),
},
'text-to-image': {
# example model: hf.co/osanseviero/BigGAN-deep-128
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Image(label="Output"),
'preprocess': lambda x: {"inputs": x},
'postprocess': post_process_binary_body,
'postprocess': lambda x: base64.b64encode(x),
},
}
@ -175,13 +176,6 @@ def interface_params_from_config(config_dict):
## instantiate input component and output component
config_dict["inputs"] = [inputs.get_input_instance(component) for component in config_dict["input_components"]]
config_dict["outputs"] = [outputs.get_output_instance(component) for component in config_dict["output_components"]]
print(config_dict["outputs"])
# # remove preprocessing and postprocessing (since they'll be performed remotely)
# for component in config_dict["inputs"]:
# component.preprocess = lambda x:x
# for component in config_dict["outputs"]:
# component.postprocess = lambda x:x
# Remove keys that are not parameters to Interface() class
not_parameters = ("allow_embedding", "allow_interpretation", "avg_durations", "function_count",
"queue", "input_components", "output_components", "examples")
for key in not_parameters:

View File

@ -209,6 +209,9 @@ class Image(OutputComponent):
". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
return out_y, coordinates
def deserialize(self, x):
raise processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
@ -255,6 +258,9 @@ class Video(OutputComponent):
"data": processing_utils.encode_file_to_base64(y, type="video")
}
def deserialize(self, x):
raise processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to image file
@ -376,6 +382,9 @@ class Audio(OutputComponent):
raise ValueError("Unknown type: " + self.type +
". Please choose from: 'numpy', 'file'.")
def deserialize(self, x):
raise processing_utils.decode_base64_to_file(x).name
def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to audio file