utils and external unit tests

This commit is contained in:
dawoodkhan82 2021-11-02 17:22:52 -04:00
parent a3922a169e
commit 65a5ed5147
3 changed files with 122 additions and 25 deletions

View File

@ -69,7 +69,7 @@ def get_huggingface_interface(model_name, api_key, alias):
},
'fill-mask': {
'inputs': inputs.Textbox(label="Input"),
'outputs': "label",
'outputs': outputs.Label(label="Classification", type="confidences"),
'preprocess': lambda x: {"inputs": x},
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
},

View File

@ -7,22 +7,105 @@ WARNING: These tests have an external dependency: namely that Hugging Face's Hub
"""
class TestHuggingFaceModelAPI(unittest.TestCase):
def test_question_answering(self):
model_type = "question-answering"
interface_info = gr.external.get_huggingface_interface(
"deepset/roberta-base-squad2", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
self.assertIsInstance(interface_info["outputs"][1], gr.outputs.Label)
def test_text_generation(self):
model_type = "text_generation"
interface_info = gr.external.get_huggingface_interface("gpt2", api_key=None, alias=None)
self.assertEqual(interface_info["fn"].__name__, "gpt2")
interface_info = gr.external.get_huggingface_interface("gpt2",
api_key=None,
alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_sentiment_classifier(self):
model_type = "sentiment_classifier"
def test_summarization(self):
model_type = "summarization"
interface_info = gr.external.get_huggingface_interface(
"distilbert-base-uncased-finetuned-sst-2-english", api_key=None,
alias=model_type)
"facebook/bart-large-cnn", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_translation(self):
model_type = "translation"
interface_info = gr.external.get_huggingface_interface(
"facebook/bart-large-cnn", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_text2text_generation(self):
model_type = "text2text-generation"
interface_info = gr.external.get_huggingface_interface(
"sshleifer/tiny-mbart", api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_text_classification(self):
model_type = "text-classification"
interface_info = gr.external.get_huggingface_interface(
"distilbert-base-uncased-finetuned-sst-2-english",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_fill_mask(self):
model_type = "fill-mask"
interface_info = gr.external.get_huggingface_interface(
"bert-base-uncased",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_zero_shot_classification(self):
model_type = "zero-shot-classification"
interface_info = gr.external.get_huggingface_interface(
"facebook/bart-large-mnli",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
self.assertIsInstance(interface_info["inputs"][2], gr.inputs.Checkbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_automatic_speech_recognition(self):
model_type = "automatic-speech-recognition"
interface_info = gr.external.get_huggingface_interface(
"facebook/wav2vec2-base-960h",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
def test_image_classification(self):
model_type = "image-classification"
interface_info = gr.external.get_huggingface_interface(
"google/vit-base-patch16-224",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_feature_extraction(self):
model_type = "feature-extraction"
interface_info = gr.external.get_huggingface_interface(
"sentence-transformers/distilbert-base-nli-mean-tokens",
api_key=None, alias=model_type)
self.assertEqual(interface_info["fn"].__name__, model_type)
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Dataframe)
def test_sentence_similarity(self):
model_type = "text-to-speech"
interface_info = gr.external.get_huggingface_interface(
@ -50,8 +133,6 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
class TestHuggingFaceSpaceAPI(unittest.TestCase):
def test_english_to_spanish(self):
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
@ -61,21 +142,8 @@ class TestLoadInterface(unittest.TestCase):
def test_english_to_spanish(self):
interface_info = gr.external.load_interface("spaces/abidlabs/english_to_spanish")
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
def test_distilbert_classification(self):
interface_info = gr.external.load_interface("distilbert-base-uncased-finetuned-sst-2-english", src="huggingface", alias="sentiment_classifier")
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
def test_models_src(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
class TestCallingLoadInterface(unittest.TestCase):
def test_sentiment_model(self):
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
io = gr.Interface(**interface_info)
@ -85,7 +153,7 @@ class TestCallingLoadInterface(unittest.TestCase):
def test_image_classification_model(self):
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
io = gr.Interface(**interface_info)
output = io("test/images/lion.jpg")
output = io("images/lion.jpg")
self.assertGreater(output['lion'], 0.5)
def test_translation_model(self):
@ -107,7 +175,7 @@ class TestCallingLoadInterface(unittest.TestCase):
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
io = gr.Interface(**interface_info)
output = io("test/images/lion.jpg")
output = io("images/lion.jpg")
assertIsFile(output)
if __name__ == '__main__':

File diff suppressed because one or more lines are too long