mirror of
https://github.com/gradio-app/gradio.git
synced 2024-12-21 02:19:59 +08:00
a1c391668a
* examples as component * renamed examples * simplify internal logic * fix tests * cleanup * fixed parallel and series * cleaning up examples * examples * formatting * fixes * added unique ids * added demo * formatting * fixed test_examples * fixed test_interfaces * fixed tests * removed test from now * raise ValueError for bad parameter values * fixing series * fixed series * formatting * speed up by preprocessing examples * fixed parameter validation logic
231 lines
9.6 KiB
Python
231 lines
9.6 KiB
Python
import os
|
|
import pathlib
|
|
import unittest
|
|
|
|
import transformers
|
|
|
|
import gradio as gr
|
|
from gradio.external import TooManyRequestsError
|
|
|
|
"""
|
|
WARNING: These tests have an external dependency: namely that Hugging Face's
|
|
Hub and Space APIs do not change, and that they keep their most famous models up.
|
|
So if, e.g., Spaces is down, then these tests will not pass.
|
|
|
|
These tests actually test gr.Interface.load() and gr.Blocks.load() but are
|
|
included in a separate file because of the above-mentioned dependency.
|
|
"""
|
|
|
|
# Switch Gradio's analytics flag off before any test in this module runs.
os.environ.update({"GRADIO_ANALYTICS_ENABLED": "False"})
|
|
|
|
|
|
class TestLoadInterface(unittest.TestCase):
    """Check ``gr.Interface.load`` / ``gr.Blocks.load`` against live HF models and Spaces.

    The first group of tests only checks the ``__name__`` (the ``alias``
    passed to ``load``) and the component types of the loaded interface.
    The second group calls the loaded interface. Those tests skip themselves,
    rather than silently passing, when the call raises
    ``TooManyRequestsError``.
    """

    def _assert_interface(self, interface, name, inputs, outputs):
        """Assert the loaded interface's name and component types.

        Args:
            interface: object returned by ``gr.Interface.load``/``gr.Blocks.load``.
            name: expected ``interface.__name__``, or ``None`` to skip that check.
            inputs: component classes expected, in order, at the start of
                ``interface.input_components``.
            outputs: component classes expected, in order, at the start of
                ``interface.output_components``.
        """
        if name is not None:
            self.assertEqual(interface.__name__, name)
        # Index explicitly (rather than zip) so a missing component fails with
        # IndexError instead of being silently skipped.
        for index, component_cls in enumerate(inputs):
            self.assertIsInstance(interface.input_components[index], component_cls)
        for index, component_cls in enumerate(outputs):
            self.assertIsInstance(interface.output_components[index], component_cls)

    def _call_or_skip(self, io, *args):
        """Return ``io(*args)``; skip the test if it raises ``TooManyRequestsError``."""
        try:
            return io(*args)
        except TooManyRequestsError:
            # skipTest raises unittest.SkipTest, so nothing after this runs.
            self.skipTest("Hugging Face API raised TooManyRequestsError")

    def test_audio_to_audio(self):
        model_type = "audio-to-audio"
        interface = gr.Interface.load(
            name="speechbrain/mtl-mimic-voicebank",
            src="models",
            alias=model_type,
        )
        self._assert_interface(
            interface, model_type, [gr.components.Audio], [gr.components.Audio]
        )

    def test_question_answering(self):
        # NOTE(review): despite its name, this loads an image-classification
        # model through gr.Blocks.load; it does not cover question answering.
        # TODO: point it at a question-answering model.
        model_type = "image-classification"
        interface = gr.Blocks.load(
            name="lysandre/tiny-vit-random", src="models", alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Image], [gr.components.Label]
        )

    def test_text_generation(self):
        # Hyphenated alias, matching the style of the other aliases in this class.
        model_type = "text-generation"
        interface = gr.Interface.load("models/gpt2", alias=model_type)
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Textbox]
        )

    def test_summarization(self):
        model_type = "summarization"
        interface = gr.Interface.load(
            "models/facebook/bart-large-cnn", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Textbox]
        )

    def test_translation(self):
        # Use a translation model (the same one test_translation_model calls)
        # instead of the summarization model, so this is not a copy of
        # test_summarization.
        model_type = "translation"
        interface = gr.Interface.load(
            "models/t5-base", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Textbox]
        )

    def test_text2text_generation(self):
        model_type = "text2text-generation"
        interface = gr.Interface.load(
            "models/sshleifer/tiny-mbart", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Textbox]
        )

    def test_text_classification(self):
        model_type = "text-classification"
        interface = gr.Interface.load(
            "models/distilbert-base-uncased-finetuned-sst-2-english",
            api_key=None,
            alias=model_type,
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Label]
        )

    def test_fill_mask(self):
        model_type = "fill-mask"
        interface = gr.Interface.load(
            "models/bert-base-uncased", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Label]
        )

    def test_zero_shot_classification(self):
        model_type = "zero-shot-classification"
        interface = gr.Interface.load(
            "models/facebook/bart-large-mnli", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface,
            model_type,
            [gr.components.Textbox, gr.components.Textbox, gr.components.Checkbox],
            [gr.components.Label],
        )

    def test_automatic_speech_recognition(self):
        model_type = "automatic-speech-recognition"
        interface = gr.Interface.load(
            "models/facebook/wav2vec2-base-960h", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Audio], [gr.components.Textbox]
        )

    def test_image_classification(self):
        model_type = "image-classification"
        interface = gr.Interface.load(
            "models/google/vit-base-patch16-224", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Image], [gr.components.Label]
        )

    def test_feature_extraction(self):
        model_type = "feature-extraction"
        interface = gr.Interface.load(
            "models/sentence-transformers/distilbert-base-nli-mean-tokens",
            api_key=None,
            alias=model_type,
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Dataframe]
        )

    def test_sentence_similarity(self):
        # NOTE(review): this is currently identical to test_text_to_speech (same
        # model, alias and assertions); no sentence-similarity model is loaded.
        # TODO: load a sentence-similarity model and assert its components.
        model_type = "text-to-speech"
        interface = gr.Interface.load(
            "models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
            api_key=None,
            alias=model_type,
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Audio]
        )

    def test_text_to_speech(self):
        model_type = "text-to-speech"
        interface = gr.Interface.load(
            "models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
            api_key=None,
            alias=model_type,
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Audio]
        )

    def test_text_to_image(self):
        model_type = "text-to-image"
        interface = gr.Interface.load(
            "models/osanseviero/BigGAN-deep-128", api_key=None, alias=model_type
        )
        self._assert_interface(
            interface, model_type, [gr.components.Textbox], [gr.components.Image]
        )

    def test_english_to_spanish(self):
        interface = gr.Interface.load("spaces/abidlabs/english_to_spanish")
        # No alias is passed, so only the component types are checked.
        self._assert_interface(
            interface, None, [gr.components.Textbox], [gr.components.Textbox]
        )

    def test_sentiment_model(self):
        io = gr.Interface.load("models/distilbert-base-uncased-finetuned-sst-2-english")
        output = self._call_or_skip(io, "I am happy, I love you")
        self.assertGreater(output["POSITIVE"], 0.5)

    def test_image_classification_model(self):
        io = gr.Blocks.load(name="models/google/vit-base-patch16-224")
        # NOTE(review): relative path; assumes the tests run from the repo root.
        output = self._call_or_skip(io, "gradio/test_data/lion.jpg")
        self.assertGreater(output["lion"], 0.5)

    def test_translation_model(self):
        io = gr.Blocks.load(name="models/t5-base")
        output = self._call_or_skip(io, "My name is Sarah and I live in London")
        self.assertEqual(output, "Mein Name ist Sarah und ich lebe in London")

    def test_numerical_to_label_space(self):
        io = gr.Interface.load("spaces/abidlabs/titanic-survival")
        output = self._call_or_skip(io, "male", 77, 10)
        self.assertLess(output["Survives"], 0.5)

    def test_speech_recognition_model(self):
        io = gr.Interface.load("models/facebook/wav2vec2-base-960h")
        # NOTE(review): relative path; assumes the tests run from the repo root.
        output = self._call_or_skip(io, "gradio/test_data/test_audio.wav")
        self.assertIsNotNone(output)

    def test_text_to_image_model(self):
        io = gr.Interface.load("models/osanseviero/BigGAN-deep-128")
        filename = self._call_or_skip(io, "chest")
        self.assertTrue(filename.endswith((".jpg", ".jpeg")))
|
|
|
|
|
|
class TestLoadFromPipeline(unittest.TestCase):
    """Smoke test: build a ``transformers`` pipeline for a tiny model and run it."""

    def test_text_to_text_model_from_pipeline(self):
        # NOTE(review): only transformers.pipeline is exercised here; no gradio
        # API is called. Confirm whether that is intended.
        model_id = "sshleifer/bart-tiny-random"
        sentence = "My name is Sylvain and I work at Hugging Face in Brooklyn"
        result = transformers.pipeline(model=model_id)(sentence)
        self.assertIsNotNone(result)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): hand off to unittest's command-line runner.
    unittest.main()
|