mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-30 11:00:11 +08:00
utils and external unit tests
This commit is contained in:
parent
a3922a169e
commit
65a5ed5147
@ -69,7 +69,7 @@ def get_huggingface_interface(model_name, api_key, alias):
|
||||
},
|
||||
'fill-mask': {
|
||||
'inputs': inputs.Textbox(label="Input"),
|
||||
'outputs': "label",
|
||||
'outputs': outputs.Label(label="Classification", type="confidences"),
|
||||
'preprocess': lambda x: {"inputs": x},
|
||||
'postprocess': lambda r: {i["token_str"]: i["score"] for i in r.json()}
|
||||
},
|
||||
|
@ -7,22 +7,105 @@ WARNING: These tests have an external dependency: namely that Hugging Face's Hub
|
||||
"""
|
||||
|
||||
class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
def test_question_answering(self):
|
||||
model_type = "question-answering"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"deepset/roberta-base-squad2", api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"][1], gr.outputs.Label)
|
||||
|
||||
def test_text_generation(self):
|
||||
model_type = "text_generation"
|
||||
interface_info = gr.external.get_huggingface_interface("gpt2", api_key=None, alias=None)
|
||||
self.assertEqual(interface_info["fn"].__name__, "gpt2")
|
||||
interface_info = gr.external.get_huggingface_interface("gpt2",
|
||||
api_key=None,
|
||||
alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
|
||||
def test_sentiment_classifier(self):
|
||||
model_type = "sentiment_classifier"
|
||||
def test_summarization(self):
|
||||
model_type = "summarization"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"distilbert-base-uncased-finetuned-sst-2-english", api_key=None,
|
||||
alias=model_type)
|
||||
"facebook/bart-large-cnn", api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
|
||||
def test_translation(self):
|
||||
model_type = "translation"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/bart-large-cnn", api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
|
||||
def test_text2text_generation(self):
|
||||
model_type = "text2text-generation"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"sshleifer/tiny-mbart", api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
|
||||
def test_text_classification(self):
|
||||
model_type = "text-classification"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"distilbert-base-uncased-finetuned-sst-2-english",
|
||||
api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
def test_fill_mask(self):
|
||||
model_type = "fill-mask"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"bert-base-uncased",
|
||||
api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
def test_zero_shot_classification(self):
|
||||
model_type = "zero-shot-classification"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/bart-large-mnli",
|
||||
api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["inputs"][2], gr.inputs.Checkbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
def test_automatic_speech_recognition(self):
|
||||
model_type = "automatic-speech-recognition"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"facebook/wav2vec2-base-960h",
|
||||
api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)
|
||||
|
||||
def test_image_classification(self):
|
||||
model_type = "image-classification"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"google/vit-base-patch16-224",
|
||||
api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
def test_feature_extraction(self):
|
||||
model_type = "feature-extraction"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
"sentence-transformers/distilbert-base-nli-mean-tokens",
|
||||
api_key=None, alias=model_type)
|
||||
self.assertEqual(interface_info["fn"].__name__, model_type)
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Dataframe)
|
||||
|
||||
def test_sentence_similarity(self):
|
||||
model_type = "text-to-speech"
|
||||
interface_info = gr.external.get_huggingface_interface(
|
||||
@ -50,8 +133,6 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)
|
||||
|
||||
|
||||
class TestHuggingFaceSpaceAPI(unittest.TestCase):
|
||||
def test_english_to_spanish(self):
|
||||
interface_info = gr.external.get_spaces_interface("abidlabs/english_to_spanish", api_key=None, alias=None)
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
@ -61,21 +142,8 @@ class TestLoadInterface(unittest.TestCase):
|
||||
def test_english_to_spanish(self):
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/english_to_spanish")
|
||||
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
|
||||
|
||||
def test_distilbert_classification(self):
|
||||
interface_info = gr.external.load_interface("distilbert-base-uncased-finetuned-sst-2-english", src="huggingface", alias="sentiment_classifier")
|
||||
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
def test_models_src(self):
|
||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
self.assertEqual(interface_info["fn"].__name__, "sentiment_classifier")
|
||||
self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
|
||||
self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
|
||||
|
||||
class TestCallingLoadInterface(unittest.TestCase):
|
||||
def test_sentiment_model(self):
|
||||
interface_info = gr.external.load_interface("models/distilbert-base-uncased-finetuned-sst-2-english", alias="sentiment_classifier")
|
||||
io = gr.Interface(**interface_info)
|
||||
@ -85,7 +153,7 @@ class TestCallingLoadInterface(unittest.TestCase):
|
||||
def test_image_classification_model(self):
|
||||
interface_info = gr.external.load_interface("models/google/vit-base-patch16-224")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("test/images/lion.jpg")
|
||||
output = io("images/lion.jpg")
|
||||
self.assertGreater(output['lion'], 0.5)
|
||||
|
||||
def test_translation_model(self):
|
||||
@ -107,7 +175,7 @@ class TestCallingLoadInterface(unittest.TestCase):
|
||||
|
||||
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
|
||||
io = gr.Interface(**interface_info)
|
||||
output = io("test/images/lion.jpg")
|
||||
output = io("images/lion.jpg")
|
||||
assertIsFile(output)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
29
test/test_processing_utils.py
Normal file
29
test/test_processing_utils.py
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user