making fixes based on Ali's comments

This commit is contained in:
Abubakar Abid 2021-10-22 06:50:26 -05:00
parent 6dc6ecb2cd
commit e82308eee7
5 changed files with 51 additions and 15 deletions

View File

@ -7,7 +7,10 @@ import base64
def get_huggingface_interface(model_name, api_key, alias):
model_url = "https://huggingface.co/{}".format(model_name)
api_url = "https://api-inference.huggingface.co/models/{}".format(model_name)
print("Fetching model from: {}".format(model_url))
if api_key is not None:
headers = {"Authorization": f"Bearer {api_key}"}
else:
@ -59,7 +62,7 @@ def get_huggingface_interface(model_name, api_key, alias):
},
'text-classification': {
'inputs': inputs.Textbox(label="Input"),
'outputs': outputs.Label(label="Classification"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {'Negative': r.json()[0][0]["score"],
'Positive': r.json()[0][1]["score"]}
@ -75,7 +78,7 @@ def get_huggingface_interface(model_name, api_key, alias):
inputs.Textbox(label="Possible class names ("
"comma-separated)"),
inputs.Checkbox(label="Allow multiple true classes")],
'outputs': "label",
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i, c, m: {"inputs": i, "parameters":
{"candidate_labels": c, "multi_class": m}},
'postprocess': lambda r: {r.json()["labels"][i]: r.json()["scores"][i] for i in
@ -90,7 +93,7 @@ def get_huggingface_interface(model_name, api_key, alias):
},
'image-classification': {
'inputs': inputs.Image(label="Input Image", type="filepath"),
'outputs': outputs.Label(label="Classification"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda i: base64.b64decode(i.split(",")[1]), # convert the base64 representation to binary
'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()}
},
@ -107,7 +110,7 @@ def get_huggingface_interface(model_name, api_key, alias):
inputs.Textbox(label="Source Sentence", default="That is a happy person"),
inputs.Textbox(lines=7, label="Sentences to compare to", placeholder="Separate each sentence by a newline"),
],
'outputs': outputs.Label(label="Classification"),
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda src, sentences: {"inputs": {
"source_sentence": src,
"sentences": [s for s in sentences.splitlines() if s != ""],
@ -184,6 +187,8 @@ def interface_params_from_config(config_dict):
return config_dict
def get_spaces_interface(model_name, api_key, alias):
space_url = "https://huggingface.co/spaces/{}".format(model_name)
print("Fetching interface from: {}".format(space_url))
iframe_url = "https://huggingface.co/gradioiframe/{}/+".format(model_name)
api_url = "https://huggingface.co/gradioiframe/{}/api/predict/".format(model_name)
headers = {'Content-Type': 'application/json'}
@ -205,7 +210,7 @@ def get_spaces_interface(model_name, api_key, alias):
output = output[0]
return output
fn.__name__ = alias if alias else model_name
fn.__name__ = alias if (alias is not None) else model_name
interface_info["fn"] = fn
interface_info["api_mode"] = True

View File

@ -34,7 +34,10 @@ class InputComponent(Component):
def serialize(self, x, called_directly):
"""
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) used to call() the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
Convert from a human-readable version of the input (path of an image, URL of a video, etc.) into the interface to a serialized version (e.g. base64) to pass into an API. May do different things if the interface is called() vs. used via GUI.
Parameters:
x (Any): Input to interface
called_directly (bool): if true, the interface was called(), otherwise, it is being used via the GUI
"""
return x
@ -664,12 +667,14 @@ class Image(InputComponent):
"."+fmt.lower() if fmt is not None else ".png"))
im.save(file_obj.name)
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
return file_obj
else:
return file_obj.name
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'pil', 'file', 'filepath'.")
". Please choose from: 'numpy', 'pil', 'filepath'.")
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x)
@ -690,7 +695,7 @@ class Image(InputComponent):
return processing_utils.encode_url_or_file_to_base64(file_obj.name)
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'pil', 'file', 'filepath'.")
". Please choose from: 'numpy', 'pil', 'filepath'.")
def set_interpret_parameters(self, segments=16):
"""
@ -902,6 +907,8 @@ class Audio(InputComponent):
file_obj = processing_utils.decode_base64_to_file(
file_data, file_path=file_name)
if self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
return file_obj
elif self.type == "filepath":
return file_obj.name
@ -909,7 +916,7 @@ class Audio(InputComponent):
return processing_utils.audio_from_file(file_obj.name)
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'file', 'filepath'.")
". Please choose from: 'numpy', 'filepath'.")
def preprocess_example(self, x):
return processing_utils.encode_file_to_base64(x, type="audio")
@ -918,6 +925,8 @@ class Audio(InputComponent):
if self.type == "filepath" or called_directly:
name = x
elif self.type == "file":
warnings.warn(
"The 'file' type has been deprecated. Set parameter 'type' to 'filepath' instead.", DeprecationWarning)
name = x.name
elif self.type == "numpy":
file = tempfile.NamedTemporaryFile(delete=False)
@ -925,7 +934,7 @@ class Audio(InputComponent):
processing_utils.audio_to_file(x[0], x[1], name)
else:
raise ValueError("Unknown type: " + str(self.type) +
". Please choose from: 'numpy', 'file', 'filepath'.")
". Please choose from: 'numpy', 'filepath'.")
file_data = processing_utils.encode_url_or_file_to_base64(name, type="audio")
return {"name": name, "data": file_data, "is_example": False}

View File

@ -70,8 +70,8 @@ class Interface:
capture_session=False, interpretation=None, num_shap=2.0, theme=None, repeat_outputs_per_model=True,
title=None, description=None, article=None, thumbnail=None,
css=None, server_port=None, server_name=networking.LOCALHOST_NAME, height=500, width=900,
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False, api_mode=False,
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False):
allow_screenshot=True, allow_flagging=True, flagging_options=None, encrypt=False,
show_tips=False, flagging_dir="flagged", analytics_enabled=True, enable_queue=False, api_mode=False):
"""
Parameters:
fn (Callable): the function to wrap an interface around.
@ -97,10 +97,10 @@ class Interface:
allow_flagging (bool): if False, users will not see a button to flag an input and output.
flagging_options (List[str]): if not None, provides options a user must select when flagging.
encrypt (bool): If True, flagged data will be encrypted by key provided by creator at launch
api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
flagging_dir (str): what to name the dir where flagged data is stored.
show_tips (bool): if True, will occasionally show tips about new Gradio features
enable_queue (bool): if True, inference requests will be served through a queue instead of with parallel threads. Required for longer inference times (> 1min) to prevent timeout.
api_mode (bool): If True, will skip preprocessing steps when the Interface is called() as a function (should remain False unless the Interface is loaded from an external repo)
"""
if not isinstance(fn, list):
fn = [fn]

View File

@ -121,6 +121,21 @@ class Label(OutputComponent):
raise ValueError("The `Label` output interface expects one of: a string label, or an int label, a "
"float label, or a dictionary whose keys are labels and values are confidences.")
def deserialize(self, y):
# 4 cases: (1): {'label': 'lion'}, {'label': 'lion', 'confidences':...}, {'lion': 0.46, ...}, 'lion'
if self.type == "label" or (self.type == "auto" and (isinstance(y, str) or ('label' in y and not('confidences' in y.keys())))):
if isinstance(y, str):
return y
else:
return y['label']
elif self.type == "confidences" or self.type == "auto":
if 'confidences' in y.keys() and isinstance(y['confidences'], list):
return {k['label']:k['confidence'] for k in y['confidences']}
else:
return y
raise ValueError("Unable to deserialize output: {}".format(y))
@classmethod
def get_shortcut_implementations(cls):
return {
@ -203,7 +218,8 @@ class Image(OutputComponent):
return out_y
def deserialize(self, x):
return processing_utils.decode_base64_to_file(x).name
y = processing_utils.decode_base64_to_file(x).name
return y
def save_flagged(self, dir, label, data, encryption_key):
"""

View File

@ -62,12 +62,18 @@ class TestCallingLoadInterface(unittest.TestCase):
output = io("My name is Sarah and I live in London")
self.assertEquals(output, 'Mein Name ist Sarah und ich lebe in London')
def test_numerical_to_label_space(self):
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
io = gr.Interface(**interface_info)
output = io("male", 77, 10)
self.assertLess(output['Survives'], 0.5)
def test_image_to_image_space(self):
def assertIsFile(path):
if not pathlib.Path(path).resolve().is_file():
raise AssertionError("File does not exist: %s" % str(path))
interface_info = gr.external.load_interface("spaces/akhaliq/Face_Mesh")
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
io = gr.Interface(**interface_info)
output = io("images/lion.jpg")
assertIsFile(output)