mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
implemented deserialization
This commit is contained in:
parent
47b9cd9e1f
commit
fa3cb474e9
@ -119,14 +119,14 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Audio(label="Audio"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda x: base64.b64encode(x),
|
||||
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
|
||||
},
|
||||
'text-to-image': {
|
||||
# example model: hf.co/osanseviero/BigGAN-deep-128
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Image(label="Output"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda x: base64.b64encode(x),
|
||||
'postprocess': lambda x: base64.b64encode(x.content).decode('utf-8'),
|
||||
},
|
||||
}
|
||||
|
||||
@ -201,13 +201,14 @@ def get_spaces_interface(model_name, api_key, alias):
|
||||
output = result["data"]
|
||||
if len(interface_info["outputs"])==1: # if the fn is supposed to return a single value, pop it
|
||||
output = output[0]
|
||||
if len(interface_info["outputs"])==1 and isinstance(output, list): # not sure why this is needed but it fixes the bug
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
fn.__name__ = alias if alias else model_name
|
||||
interface_info["fn"] = fn
|
||||
interface_info["api_mode"] = True
|
||||
|
||||
print(interface_info)
|
||||
return interface_info
|
||||
|
||||
repos = {
|
||||
|
@ -32,9 +32,9 @@ class InputComponent(Component):
|
||||
"""
|
||||
return x
|
||||
|
||||
def serialize(self, x):
|
||||
def serialize(self, x, called_directly):
|
||||
"""
|
||||
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API
|
||||
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
|
||||
"""
|
||||
return x
|
||||
|
||||
@ -674,8 +674,9 @@ class Image(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x)
|
||||
|
||||
def serialize(self, x):
|
||||
if self.type == "filepath":
|
||||
def serialize(self, x, called_directly=False):
|
||||
# if called directly, can assume it's a URL or filepath
|
||||
if self.type == "filepath" or called_directly:
|
||||
return processing_utils.encode_url_or_file_to_base64(x)
|
||||
elif self.type == "file":
|
||||
return processing_utils.encode_url_or_file_to_base64(x.name)
|
||||
@ -839,7 +840,7 @@ class Video(InputComponent):
|
||||
else:
|
||||
return file_name
|
||||
|
||||
def serialize(self, x):
|
||||
def serialize(self, x, called_directly):
|
||||
raise NotImplementedError()
|
||||
|
||||
def preprocess_example(self, x):
|
||||
@ -913,8 +914,8 @@ class Audio(InputComponent):
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x, type="audio")
|
||||
|
||||
def serialize(self, x):
|
||||
if self.type == "filepath":
|
||||
def serialize(self, x, called_directly):
|
||||
if self.type == "filepath" or called_directly:
|
||||
name = x
|
||||
elif self.type == "file":
|
||||
name = x.name
|
||||
|
@ -219,7 +219,7 @@ class Interface:
|
||||
|
||||
def __call__(self, *params):
|
||||
if self.api_mode: # skip the preprocessing/postprocessing if sending to a remote API
|
||||
output = self.run_prediction(params)
|
||||
output = self.run_prediction(params, called_directly=True)
|
||||
else:
|
||||
output, _ = self.process(params)
|
||||
return output[0] if len(output) == 1 else output
|
||||
@ -311,9 +311,9 @@ class Interface:
|
||||
config["examples"] = self.examples
|
||||
return config
|
||||
|
||||
def run_prediction(self, processed_input, return_duration=False):
|
||||
def run_prediction(self, processed_input, return_duration=False, called_directly=False):
|
||||
if self.api_mode: # Serialize the input
|
||||
processed_input = [input_component.serialize(processed_input[i])
|
||||
processed_input = [input_component.serialize(processed_input[i], called_directly)
|
||||
for i, input_component in enumerate(self.input_components)]
|
||||
predictions = []
|
||||
durations = []
|
||||
@ -336,13 +336,13 @@ class Interface:
|
||||
if len(self.output_components) == len(self.predict):
|
||||
prediction = [prediction]
|
||||
|
||||
if self.api_mode: # Serialize the input
|
||||
prediction = [output_component.deserialize(prediction[o])
|
||||
for o, output_component in enumerate(self.output_components)]
|
||||
|
||||
durations.append(duration)
|
||||
predictions.extend(prediction)
|
||||
|
||||
if self.api_mode: # Serialize the input
|
||||
predictions = [output_component.deserialize(predictions[o])
|
||||
for o, output_component in enumerate(self.output_components)]
|
||||
|
||||
if return_duration:
|
||||
return predictions, durations
|
||||
else:
|
||||
|
@ -210,7 +210,7 @@ class Image(OutputComponent):
|
||||
return out_y, coordinates
|
||||
|
||||
def deserialize(self, x):
|
||||
raise processing_utils.decode_base64_to_file(x).name
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
@ -259,7 +259,7 @@ class Video(OutputComponent):
|
||||
}
|
||||
|
||||
def deserialize(self, x):
|
||||
raise processing_utils.decode_base64_to_file(x).name
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
@ -383,7 +383,7 @@ class Audio(OutputComponent):
|
||||
". Please choose from: 'numpy', 'file'.")
|
||||
|
||||
def deserialize(self, x):
|
||||
raise processing_utils.decode_base64_to_file(x).name
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user