mirror of
https://github.com/gradio-app/gradio.git
synced 2025-02-05 11:10:03 +08:00
making fixes based on Ali's comments
This commit is contained in:
parent
6dc6ecb2cd
commit
e82308eee7
@ -7,7 +7,10 @@ import base64
|
||||
|
||||
|
||||
def get_huggingface_interface(model_name, api_key, alias):
|
||||
model_url = "https://huggingface.co/{}".format(model_name)
|
||||
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
|
||||
print("Fetching model from: {}".format(model_url))
|
||||
|
||||
if api_key is not None:
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
else:
|
||||
@ -59,7 +62,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
},
|
||||
'text-classification': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': outputs.Label(label="Classification"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {'Negative': r.json()[0][0]["score"],
|
||||
'Positive': r.json()[0][1]["score"]}
|
||||
@ -75,7 +78,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
inputs.Textbox(label="Possible class names ("
|
||||
"comma-separated)"),
|
||||
inputs.Checkbox(label="Allow multiple true classes")],
|
||||
'outputs': "label",
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i, c, m: {"inputs": i, "parameters":
|
||||
{"candidate_labels": c, "multi_class": m}},
|
||||
'postprocess': lambda r: {r.json()["labels"][i]: r.json()["scores"][i] for i in
|
||||
@ -90,7 +93,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
},
|
||||
'image-classification': {
|
||||
'inputs': inputs.Image(label="Input Image", type="filepath"),
|
||||
'outputs': outputs.Label(label="Classification"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
|
||||
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
|
||||
},
|
||||
@ -107,7 +110,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
inputs.Textbox(label="Source Sentence", default="That is a happy person"),
|
||||
inputs.Textbox(lines=7, label="Sentences to compare to", placeholder="Separate each sentence by a newline"),
|
||||
],
|
||||
'outputs': outputs.Label(label="Classification"),
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda src, sentences: {"inputs": {
|
||||
"source_sentence": src,
|
||||
"sentences": [s for s in sentences.splitlines() if s != ""],
|
||||
@ -184,6 +187,8 @@ def interface_params_from_config(config_dict):
|
||||
return config_dict
|
||||
|
||||
def get_spaces_interface(model_name, api_key, alias):
|
||||
space_url = "https://huggingface.co/spaces/{}".format(model_name)
|
||||
print("Fetching interface from: {}".format(space_url))
|
||||
iframe_url = "https://huggingface.co/gradioiframe/{}/+".format(model_name)
|
||||
api_url = "https://huggingface.co/gradioiframe/{}/api/predict/".format(model_name)
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
@ -205,7 +210,7 @@ def get_spaces_interface(model_name, api_key, alias):
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
fn.__name__ = alias if alias else model_name
|
||||
fn.__name__ = alias if (alias is not None) else model_name
|
||||
interface_info["fn"] = fn
|
||||
interface_info["api_mode"] = True
|
||||
|
||||
|
@ -34,7 +34,10 @@ class InputComponent(Component):
|
||||
|
||||
def serialize(self, x, called_directly):
|
||||
"""
|
||||
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
|
||||
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) into the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
|
||||
Parameters:
|
||||
x (Any): Input to interface
|
||||
called_directly (bool): if true, the interface was called(), otherwise, it is being used via the GUI
|
||||
"""
|
||||
return x
|
||||
|
||||
@ -664,12 +667,14 @@ class Image(InputComponent):
|
||||
"."+fmt.lower() if fmt is not None else ".png"))
|
||||
im.save(file_obj.name)
|
||||
if self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
|
||||
return file_obj
|
||||
else:
|
||||
return file_obj.name
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'pil', 'file', 'filepath'.")
|
||||
". Please choose from: 'numpy', 'pil', 'filepath'.")
|
||||
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x)
|
||||
@ -690,7 +695,7 @@ class Image(InputComponent):
|
||||
return processing_utils.encode_url_or_file_to_base64(file_obj.name)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'pil', 'file', 'filepath'.")
|
||||
". Please choose from: 'numpy', 'pil', 'filepath'.")
|
||||
|
||||
def set_interpret_parameters(self, segments=16):
|
||||
"""
|
||||
@ -902,6 +907,8 @@ class Audio(InputComponent):
|
||||
file_obj = processing_utils.decode_base64_to_file(
|
||||
file_data, file_path=file_name)
|
||||
if self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
|
||||
return file_obj
|
||||
elif self.type == "filepath":
|
||||
return file_obj.name
|
||||
@ -909,7 +916,7 @@ class Audio(InputComponent):
|
||||
return processing_utils.audio_from_file(file_obj.name)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'file', 'filepath'.")
|
||||
". Please choose from: 'numpy', 'filepath'.")
|
||||
|
||||
def preprocess_example(self, x):
|
||||
return processing_utils.encode_file_to_base64(x, type="audio")
|
||||
@ -918,6 +925,8 @@ class Audio(InputComponent):
|
||||
if self.type == "filepath" or called_directly:
|
||||
name = x
|
||||
elif self.type == "file":
|
||||
warnings.warn(
|
||||
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
|
||||
name = x.name
|
||||
elif self.type == "numpy":
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
@ -925,7 +934,7 @@ class Audio(InputComponent):
|
||||
processing_utils.audio_to_file(x[0], x[1], name)
|
||||
else:
|
||||
raise ValueError("Unknown type: " + str(self.type) +
|
||||
". Please choose from: 'numpy', 'file', 'filepath'.")
|
||||
". Please choose from: 'numpy', 'filepath'.")
|
||||
|
||||
file_data = processing_utils.encode_url_or_file_to_base64(name, type="audio")
|
||||
return {"name": name, "data": file_data, "is_example": False}
|
||||
|
@ -70,8 +70,8 @@ class Interface:
|
||||
capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
|
||||
title=None, description=None, article=None, thumbnail=None,
|
||||
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
|
||||
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False, api_mode=False,
|
||||
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False):
|
||||
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
|
||||
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False, api_mode=False):
|
||||
"""
|
||||
Parameters:
|
||||
fn (Callable): the function to wrap an interface around.
|
||||
@ -97,10 +97,10 @@ class Interface:
|
||||
allow_flagging (bool): if False, users will not see a button to flag an input and output.
|
||||
flagging_options (List[str]): if not None, provides options a user must select when flagging.
|
||||
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
|
||||
api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
|
||||
flagging_dir (str): what to name the dir where flagged data is stored.
|
||||
show_tips (bool): if True, will occasionally show tips about new Gradio features
|
||||
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
|
||||
api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
|
||||
"""
|
||||
if not isinstance(fn, list):
|
||||
fn = [fn]
|
||||
|
@ -121,6 +121,21 @@ class Label(OutputComponent):
|
||||
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
|
||||
"float label, or a dictionary whose keys are labels and values are confidences.")
|
||||
|
||||
def deserialize(self, y):
|
||||
# 4 cases: (1): {'label': 'lion'}, {'label': 'lion', 'confidences':...}, {'lion': 0.46, ...}, 'lion'
|
||||
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or ('label' in y and not('confidences' in y.keys())))):
|
||||
if isinstance(y, str):
|
||||
return y
|
||||
else:
|
||||
return y['label']
|
||||
elif self.type == "confidences" or self.type == "auto":
|
||||
if 'confidences' in y.keys() and isinstance(y['confidences'], list):
|
||||
return {k['label']:k['confidence'] for k in y['confidences']}
|
||||
else:
|
||||
return y
|
||||
raise ValueError("Unable to deserialize output: {}".format(y))
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_shortcut_implementations(cls):
|
||||
return {
|
||||
@ -203,7 +218,8 @@ class Image(OutputComponent):
|
||||
return out_y
|
||||
|
||||
def deserialize(self, x):
|
||||
return processing_utils.decode_base64_to_file(x).name
|
||||
y = processing_utils.decode_base64_to_file(x).name
|
||||
return y
|
||||
|
||||
def save_flagged(self, dir, label, data, encryption_key):
|
||||
"""
|
||||
|
@ -62,12 +62,18 @@ class TestCallingLoadInterface(unittest.TestCase):
|
||||
output = io("My name is Sarah and I live in London")
|
||||
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
|
||||
|
||||
def test_numerical_to_label_space(self):
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("male", 77, 10)
|
||||
self.assertLess(output['Survives'], 0.5)
|
||||
|
||||
def test_image_to_image_space(self):
|
||||
def assertIsFile(path):
|
||||
if not pathlib.Path(path).resolve().is_file():
|
||||
raise AssertionError("File does not exist: %s" % str(path))
|
||||
|
||||
interface_info = gr.external.load_interface("spaces/akhaliq/Face_Mesh")
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("images/lion.jpg")
|
||||
assertIsFile(output)
|
||||
|
Loading…
Reference in New Issue
Block a user