added remaining pipelines

This commit is contained in: parent 95147d859e, commit 77998232b0
@@ -18,7 +18,7 @@ jobs:
             . venv/bin/activate
             pip install --upgrade pip
             pip install -r gradio.egg-info/requires.txt
-            pip install shap IPython comet_ml wandb mlflow transformers
+            pip install shap IPython comet_ml wandb mlflow tensorflow transformers
             pip install selenium==4.0.0a6.post2 coverage scikit-image
       - run:
           command: |
@@ -90,21 +90,19 @@ def get_huggingface_interface(model_name, api_key, alias):
             'inputs': inputs.Textbox(label="Input"),
             'outputs': outputs.Label(label="Classification", type="confidences"),
             'preprocess': lambda x: {"inputs": x},
-            'postprocess': lambda r: {'Negative': r.json()[0][0]["score"],
-                                      'Positive': r.json()[0][1]["score"]}
-        },
-        'text2text-generation': {
-            'inputs': inputs.Textbox(label="Input"),
-            'outputs': outputs.Textbox(label="Generated Text"),
-            'preprocess': lambda x: {"inputs": x},
-            'postprocess': lambda r: r.json()[0]["generated_text"]
+            'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
         },
         'text-generation': {
             'inputs': inputs.Textbox(label="Input"),
             'outputs': outputs.Textbox(label="Output"),
             'preprocess': lambda x: {"inputs": x},
             'postprocess': lambda r: r.json()[0]["generated_text"],
-            # 'examples': [['My name is Clara and I am']]
         },
+        'text2text-generation': {
+            'inputs': inputs.Textbox(label="Input"),
+            'outputs': outputs.Textbox(label="Generated Text"),
+            'preprocess': lambda x: {"inputs": x},
+            'postprocess': lambda r: r.json()[0]["generated_text"]
+        },
         'translation': {
             'inputs': inputs.Textbox(label="Input"),
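
The new 'text-classification' postprocess no longer hardcodes 'Negative'/'Positive' keys; it builds the label-to-score dict from whatever labels the Inference API returns. A minimal sketch of that mapping, using a hypothetical stub in place of a real requests.Response (the stub and sample payload are illustrative, not part of the commit):

    # Hypothetical stand-in for a requests.Response from the HF Inference API
    class StubResponse:
        def json(self):
            return [[{"label": "NEGATIVE", "score": 0.02},
                     {"label": "POSITIVE", "score": 0.98}]]

    postprocess = lambda r: {i["label"].split(", ")[0]: i["score"] for i in r.json()[0]}
    print(postprocess(StubResponse()))  # {'NEGATIVE': 0.02, 'POSITIVE': 0.98}
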
@@ -123,6 +121,7 @@ def get_huggingface_interface(model_name, api_key, alias):
             'postprocess': lambda r: {r.json()["labels"][i]: r.json()["scores"][i] for i in
                                       range(len(r.json()["labels"]))}
         },
+        # Non-HF pipelines
         'sentence-similarity': {
             # example model: hf.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens
             'inputs': [
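
For reference, the zero-shot postprocess shown in context above zips the parallel labels/scores lists of the API response into a single dict. A sketch with an assumed payload shape (values are illustrative):

    # Assumed shape of the Inference API's zero-shot classification response
    payload = {"labels": ["sports", "politics", "weather"], "scores": [0.85, 0.10, 0.05]}
    print({payload["labels"][i]: payload["scores"][i] for i in range(len(payload["labels"]))})
    # {'sports': 0.85, 'politics': 0.1, 'weather': 0.05}
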
@@ -166,9 +165,7 @@ def get_huggingface_interface(model_name, api_key, alias):
         response = requests.request("POST", api_url, headers=headers, data=data)
         if not(response.status_code == 200):
             raise ValueError("Could not complete request to HuggingFace API, Error {}".format(response.status_code))
-        print("response>>>>>>", response.json())
         output = pipeline['postprocess'](response)
-        print("output>>>>>>", output)
         return output
 
     if alias is None:
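
With the two debug prints dropped, the query path is simply: POST, status check, postprocess. A sketch of that surviving flow; the URL, token, and payload below are placeholders, not values from the commit:

    import requests

    # Placeholder endpoint and credentials for illustration only
    api_url = "https://api-inference.huggingface.co/models/gpt2"
    headers = {"Authorization": "Bearer <api_key>"}
    response = requests.request("POST", api_url, headers=headers, data='{"inputs": "Hello, I am"}')
    if not(response.status_code == 200):
        # Same failure branch as in the hunk above
        raise ValueError("Could not complete request to HuggingFace API, Error {}".format(response.status_code))
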
@@ -295,9 +292,6 @@ def load_from_pipeline(pipeline):
             'preprocess': lambda i: {"images": i},
             'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
         }
-    elif hasattr(transformers, 'AutomaticSpeechRecognitionPipeline') and isinstance(pipeline, transformers.AutomaticSpeechRecognitionPipeline):
-        pipeline_info = {
-        }
     elif hasattr(transformers, 'QuestionAnsweringPipeline') and isinstance(pipeline, transformers.QuestionAnsweringPipeline):
         pipeline_info = {
             'inputs': [inputs.Textbox(label="Context", lines=7), inputs.Textbox(label="Question")],
@@ -305,13 +299,66 @@ def load_from_pipeline(pipeline):
             'preprocess': lambda c, q: {"context": c, "question": q},
             'postprocess': lambda r: (r["answer"], r["score"]),
         }
+    elif hasattr(transformers, 'SummarizationPipeline') and isinstance(pipeline, transformers.SummarizationPipeline):
+        pipeline_info = {
+            'inputs': inputs.Textbox(label="Input", lines=7),
+            'outputs': outputs.Textbox(label="Summary"),
+            'preprocess': lambda x: {"inputs": x},
+            'postprocess': lambda r: r[0]["summary_text"]
+        }
+    elif hasattr(transformers, 'TextClassificationPipeline') and isinstance(pipeline, transformers.TextClassificationPipeline):
+        pipeline_info = {
+            'inputs': inputs.Textbox(label="Input"),
+            'outputs': outputs.Label(label="Classification", type="confidences"),
+            'preprocess': lambda x: [x],
+            'postprocess': lambda r: {i["label"].split(", ")[0]: i["score"] for i in r}
+        }
+    elif hasattr(transformers, 'TextGenerationPipeline') and isinstance(pipeline, transformers.TextGenerationPipeline):
+        pipeline_info = {
+            'inputs': inputs.Textbox(label="Input"),
+            'outputs': outputs.Textbox(label="Output"),
+            'preprocess': lambda x: {"text_inputs": x},
+            'postprocess': lambda r: r[0]["generated_text"],
+        }
+    elif hasattr(transformers, 'TranslationPipeline') and isinstance(pipeline, transformers.TranslationPipeline):
+        pipeline_info = {
+            'inputs': inputs.Textbox(label="Input"),
+            'outputs': outputs.Textbox(label="Translation"),
+            'preprocess': lambda x: [x],
+            'postprocess': lambda r: r[0]["translation_text"]
+        }
+    elif hasattr(transformers, 'Text2TextGenerationPipeline') and isinstance(pipeline, transformers.Text2TextGenerationPipeline):
+        pipeline_info = {
+            'inputs': inputs.Textbox(label="Input"),
+            'outputs': outputs.Textbox(label="Generated Text"),
+            'preprocess': lambda x: [x],
+            'postprocess': lambda r: r[0]["generated_text"]
+        }
+    elif hasattr(transformers, 'ZeroShotClassificationPipeline') and isinstance(pipeline, transformers.ZeroShotClassificationPipeline):
+        pipeline_info = {
+            'inputs': [inputs.Textbox(label="Input"),
+                       inputs.Textbox(label="Possible class names ("
+                                            "comma-separated)"),
+                       inputs.Checkbox(label="Allow multiple true classes")],
+            'outputs': outputs.Label(label="Classification", type="confidences"),
+            'preprocess': lambda i, c, m: {"sequences": i,
+                                           "candidate_labels": c, "multi_label": m},
+            'postprocess': lambda r: {r["labels"][i]: r["scores"][i] for i in
+                                      range(len(r["labels"]))}
+        }
     else:
         raise ValueError("Unsupported pipeline type: {}".format(type(pipeline)))
 
     # define the function that will be called by the Interface
     def fn(*params):
         data = pipeline_info["preprocess"](*params)
-        data = pipeline(**data)
+        # special cases that needs to be handled differently
+        if isinstance(pipeline, (transformers.TextClassificationPipeline,
+                                 transformers.Text2TextGenerationPipeline,
+                                 transformers.TranslationPipeline)):
+            data = pipeline(*data)
+        else:
+            data = pipeline(**data)
         # print("Before postprocessing", data)
         output = pipeline_info["postprocess"](data)
         return output
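
Taken together, the new branches let any of these local transformers pipelines be wrapped directly. A hedged usage sketch, assuming load_from_pipeline returns Interface kwargs, as the tests' gr.Interface(**interface_info) pattern suggests; the model name is illustrative:

    import gradio as gr
    from transformers import pipeline

    # Build a local pipeline, derive Interface kwargs from it, and launch a demo
    pl = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
    interface_info = gr.external.load_from_pipeline(pl)
    gr.Interface(**interface_info).launch()

The preprocess/call split in fn explains the special cases: text-classification, text2text-generation, and translation pipelines take their text as positional arguments (hence preprocess returning a list and pipeline(*data)), while the rest accept keyword arguments (pipeline(**data)).
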
@@ -155,7 +155,7 @@ class TestLoadInterface(unittest.TestCase):
         io = gr.Interface(**interface_info)
         io.api_mode = True
         output = io("I am happy, I love you.")
-        self.assertGreater(output['Positive'], 0.5)
+        self.assertGreater(output['POSITIVE'], 0.5)
 
     def test_image_classification_model(self):
         interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
@@ -101,7 +101,7 @@ class TestInterface(unittest.TestCase):
     def test_interface_load(self):
         io = Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
         output = io("I am happy, I love you.")
-        self.assertGreater(output['Positive'], 0.5)
+        self.assertGreater(output['POSITIVE'], 0.5)
 
     def test_interface_none_interp(self):
         interface = Interface(lambda x: x, "textbox", "label", interpretation=[None])
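
Both test updates follow from the postprocess change: output keys now come from the model's own label strings rather than hardcoded 'Positive'/'Negative', and the sst-2 fine-tuned DistilBERT checkpoint emits 'POSITIVE'/'NEGATIVE'. A minimal sketch of the mapping, with an illustrative pipeline output:

    # Illustrative output shape for the sst-2 DistilBERT checkpoint
    r = [{"label": "POSITIVE", "score": 0.9998}]
    output = {i["label"].split(", ")[0]: i["score"] for i in r}
    assert output["POSITIVE"] > 0.5  # matches the updated test expectation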