mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
implementing output deserialization
This commit is contained in:
parent
3e9e861184
commit
47b9cd9e1f
@ -18,6 +18,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
assert response.status_code == 200, "Invalid model name or src"
|
||||
p = response.json().get('pipeline_tag')
|
||||
|
||||
# convert from binary to base64
|
||||
def post_process_binary_body(r: requests.Response):
|
||||
with tempfile.NamedTemporaryFile(delete=False) as fp:
|
||||
fp.write(r.content)
|
||||
@ -118,14 +119,14 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Audio(label="Audio"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': post_process_binary_body,
|
||||
'postprocess': lambda x: base64.b64encode(x),
|
||||
},
|
||||
'text-to-image': {
|
||||
# example model: hf.co/osanseviero/BigGAN-deep-128
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Image(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': post_process_binary_body,
|
||||
'postprocess': lambda x: base64.b64encode(x),
|
||||
},
|
||||
}
|
||||
|
||||
@ -175,13 +176,6 @@ def interface_params_from_config(config_dict):
|
||||
## instantiate input component and output component
|
||||
config_dict["inputs"] = [inputs.get_input_instance(component) for component in config_dict["input_components"]]
|
||||
config_dict["outputs"] = [outputs.get_output_instance(component) for component in config_dict["output_components"]]
|
||||
print(config_dict["outputs"])
|
||||
# # remove preprocessing and postprocessing (since they'll be performed remotely)
|
||||
# for component in config_dict["inputs"]:
|
||||
# component.preprocess = lambda x:x
|
||||
# for component in config_dict["outputs"]:
|
||||
# component.postprocess = lambda x:x
|
||||
# Remove keys that are not parameters to Interface() class
|
||||
not_parameters = ("allow_embedding", "allow_interpretation", "avg_durations", "function_count",
|
||||
"queue", "input_components", "output_components", "examples")
|
||||
for key in not_parameters:
|
||||
|
@ -209,6 +209,9 @@ class Image(OutputComponent):
|
||||
". Please choose from: 'numpy', 'pil', 'file', 'plot'.")
|
||||
return out_y, coordinates
|
||||
|
||||
def deserialize(self, x):
|
||||
raise processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
Returns: (str) path to image file
|
||||
@ -255,6 +258,9 @@ class Video(OutputComponent):
|
||||
"data": processing_utils.encode_file_to_base64(y, type="video")
|
||||
}
|
||||
|
||||
def deserialize(self, x):
|
||||
raise processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
Returns: (str) path to image file
|
||||
@ -376,6 +382,9 @@ class Audio(OutputComponent):
|
||||
raise ValueError("Unknown type: " + self.type +
|
||||
". Please choose from: 'numpy', 'file'.")
|
||||
|
||||
def deserialize(self, x):
|
||||
raise processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
Returns: (str) path to audio file
|
||||
|
Loading…
Reference in New Issue
Block a user