gradio/test/test_external.py

254 lines
11 KiB
Python
Raw Normal View History

import os
2022-01-14 22:28:11 +08:00
import pathlib
import unittest
2021-12-21 06:04:37 +08:00
import transformers
2022-01-14 22:28:11 +08:00
import gradio as gr
2021-10-17 15:25:04 +08:00
2021-10-19 13:59:28 +08:00
"""
WARNING: These tests have an external dependency: namely, that Hugging Face's
Hub and Spaces APIs do not change, and that the models and Spaces used below stay up.
2022-01-14 22:29:08 +08:00
So if, e.g., Spaces is down, these tests will not pass.
2021-10-19 13:59:28 +08:00
"""
2021-11-13 14:33:59 +08:00
# Presumably switches off gradio's analytics reporting while these tests run.
# NOTE(review): this executes after ``import gradio``; it only takes effect if
# gradio reads the variable lazily (e.g. when an Interface is built) — confirm.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")
class TestHuggingFaceModelAPI(unittest.TestCase):
    """Check interfaces built by ``gr.external.get_huggingface_interface``.

    Each model test loads one Hub model with ``alias`` set to the task name,
    asserts that the returned ``fn`` carries that alias as its ``__name__``,
    and then checks that the input/output components fit the task. These
    tests call the live Hugging Face APIs.
    """

    def _get_interface(self, model_name, model_type):
        """Load ``model_name`` aliased as ``model_type`` and check the fn name.

        Returns the interface-info dict so callers can make task-specific
        assertions about its ``"inputs"`` and ``"outputs"`` entries.
        """
        interface_info = gr.external.get_huggingface_interface(
            model_name, api_key=None, alias=model_type
        )
        self.assertEqual(interface_info["fn"].__name__, model_type)
        return interface_info

    def test_audio_to_audio(self):
        interface_info = self._get_interface(
            "speechbrain/mtl-mimic-voicebank", "audio-to-audio"
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)

    def test_question_answering(self):
        # Use a real extractive-QA checkpoint so the question-answering
        # interface (two text inputs -> answer text + score label) is what
        # gets exercised, rather than an image classifier.
        # NOTE(review): component layout mirrors gradio's question-answering
        # pipeline mapping at time of writing — confirm if that mapping changes.
        interface_info = self._get_interface(
            "deepset/roberta-base-squad2", "question-answering"
        )
        self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
        self.assertIsInstance(interface_info["outputs"][1], gr.outputs.Label)

    def test_text_generation(self):
        # Hyphenated like every other task name in this class.
        interface_info = self._get_interface("gpt2", "text-generation")
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)

    def test_summarization(self):
        interface_info = self._get_interface(
            "facebook/bart-large-cnn", "summarization"
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)

    def test_translation(self):
        # t5-base is the translation model also used by TestLoadInterface;
        # a summarization model here would not exercise translation at all.
        interface_info = self._get_interface("t5-base", "translation")
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)

    def test_text2text_generation(self):
        interface_info = self._get_interface(
            "sshleifer/tiny-mbart", "text2text-generation"
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)

    def test_text_classification(self):
        interface_info = self._get_interface(
            "distilbert-base-uncased-finetuned-sst-2-english", "text-classification"
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)

    def test_fill_mask(self):
        interface_info = self._get_interface("bert-base-uncased", "fill-mask")
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)

    def test_zero_shot_classification(self):
        interface_info = self._get_interface(
            "facebook/bart-large-mnli", "zero-shot-classification"
        )
        self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["inputs"][2], gr.inputs.Checkbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)

    def test_automatic_speech_recognition(self):
        interface_info = self._get_interface(
            "facebook/wav2vec2-base-960h", "automatic-speech-recognition"
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Audio)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Textbox)

    def test_image_classification(self):
        interface_info = self._get_interface(
            "google/vit-base-patch16-224", "image-classification"
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)

    def test_feature_extraction(self):
        interface_info = self._get_interface(
            "sentence-transformers/distilbert-base-nli-mean-tokens",
            "feature-extraction",
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Dataframe)

    def test_sentence_similarity(self):
        # Previously a verbatim copy of test_text_to_speech; now loads an
        # actual sentence-similarity model.
        # NOTE(review): expects a source sentence + sentences-to-compare
        # textbox pair and a Label of scores, per gradio's sentence-similarity
        # pipeline mapping at time of writing — confirm.
        interface_info = self._get_interface(
            "sentence-transformers/all-MiniLM-L6-v2", "sentence-similarity"
        )
        self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)

    def test_text_to_speech(self):
        interface_info = self._get_interface(
            "julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
            "text-to-speech",
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Audio)

    def test_text_to_image(self):
        interface_info = self._get_interface(
            "osanseviero/BigGAN-deep-128", "text-to-image"
        )
        self.assertIsInstance(interface_info["inputs"], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"], gr.outputs.Image)

    def test_english_to_spanish(self):
        # Spaces-backed interface, so it goes through get_spaces_interface.
        interface_info = gr.external.get_spaces_interface(
            "abidlabs/english_to_spanish", api_key=None, alias=None
        )
        self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
        self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)

class TestLoadInterface(unittest.TestCase):
def test_english_to_spanish(self):
interface_info = gr.external.load_interface(
"spaces/abidlabs/english_to_spanish"
)
2021-10-17 15:25:04 +08:00
self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
2021-11-03 05:22:52 +08:00
self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
2021-10-17 15:25:04 +08:00
def test_sentiment_model(self):
interface_info = gr.external.load_interface(
"models/distilbert-base-uncased-finetuned-sst-2-english",
alias="sentiment_classifier",
)
2021-11-10 02:30:59 +08:00
io = gr.Interface(**interface_info)
2021-12-16 23:43:31 +08:00
io.api_mode = True
output = io("I am happy, I love you.")
self.assertGreater(output["POSITIVE"], 0.5)
def test_image_classification_model(self):
interface_info = gr.external.load_interface(
"models/google/vit-base-patch16-224"
)
2021-11-10 02:30:59 +08:00
io = gr.Interface(**interface_info)
2021-12-16 23:43:31 +08:00
io.api_mode = True
2021-11-09 04:37:32 +08:00
output = io("test/test_data/lion.jpg")
self.assertGreater(output["lion"], 0.5)
def test_translation_model(self):
interface_info = gr.external.load_interface("models/t5-base")
2021-11-10 02:30:59 +08:00
io = gr.Interface(**interface_info)
2021-12-16 23:43:31 +08:00
io.api_mode = True
output = io("My name is Sarah and I live in London")
self.assertEquals(output, "Mein Name ist Sarah und ich lebe in London")
2021-10-22 19:50:26 +08:00
def test_numerical_to_label_space(self):
interface_info = gr.external.load_interface("spaces/abidlabs/titanic-survival")
2021-11-10 02:30:59 +08:00
io = gr.Interface(**interface_info)
2021-12-16 23:43:31 +08:00
io.api_mode = True
2021-10-22 19:50:26 +08:00
output = io("male", 77, 10)
self.assertLess(output["Survives"], 0.5)
2021-10-22 19:50:26 +08:00
2021-12-21 06:04:37 +08:00
def test_speech_recognition_model(self):
interface_info = gr.external.load_interface(
2022-03-08 03:13:17 +08:00
"models/facebook/wav2vec2-base-960h"
)
2021-12-21 06:04:37 +08:00
io = gr.Interface(**interface_info)
io.api_mode = True
output = io("test/test_data/test_audio.wav")
self.assertIsNotNone(output)
def test_text_to_image_model(self):
interface_info = gr.external.load_interface(
"models/osanseviero/BigGAN-deep-128"
)
io = gr.Interface(**interface_info)
io.api_mode = True
filename = io("chest")
self.assertTrue(filename.endswith(".jpg") or filename.endswith(".jpeg"))
def test_image_to_image_space(self):
def assertIsFile(path):
if not pathlib.Path(path).resolve().is_file():
raise AssertionError("File does not exist: %s" % str(path))
2021-10-17 15:25:04 +08:00
2021-10-22 19:50:26 +08:00
interface_info = gr.external.load_interface("spaces/abidlabs/image-identity")
2021-11-10 02:30:59 +08:00
io = gr.Interface(**interface_info)
2021-12-16 23:43:31 +08:00
io.api_mode = True
2021-11-09 04:37:32 +08:00
output = io("test/test_data/lion.jpg")
assertIsFile(output)
2021-10-17 15:25:04 +08:00
2021-12-21 06:04:37 +08:00
class TestLoadFromPipeline(unittest.TestCase):
    """Smoke test that a ``transformers`` pipeline can be built and invoked.

    NOTE(review): despite the class name, this test never calls a gradio API;
    it only checks that the pipeline itself returns a non-None result.
    Confirm whether it was meant to wrap the pipeline in a gradio interface.
    """

    def test_text_to_text_model_from_pipeline(self):
        text_pipeline = transformers.pipeline(model="sshleifer/bart-tiny-random")
        sample = "My name is Sylvain and I work at Hugging Face in Brooklyn"
        self.assertIsNotNone(text_pipeline(sample))

if __name__ == "__main__":
    # Allow running this file directly: run every TestCase defined in it.
    unittest.main(module="__main__")