Mirror of https://github.com/gradio-app/gradio.git (synced 2024-12-21 02:19:59 +08:00)
Merge pull request #551 from gradio-app/Accelerate-Tests: Accelerate Tests

Commit f932e3fe71
@@ -40,10 +40,10 @@ jobs:
       - run:
           command: |
             . venv/bin/activate
-            coverage run -m unittest
+            coverage run -m pytest
             coverage xml
       - codecov/upload:
-          file: 'coverage.xml'
+          file: 'coverage.xml'
       - store_artifacts:
           path: /home/circleci/project/test/tmp
           destination: screenshots
@@ -31,15 +31,7 @@ bash scripts/build_frontend.sh
 bash scripts/install_test_requirements.sh
 ```
 
-* Install chrome driver and chrome for selenium (necessary for tests)
-
-```
-https://sites.google.com/chromium.org/driver/
-```
-
-```
-https://www.google.com/chrome/
-```
+* Install [chrome driver](https://sites.google.com/chromium.org/driver/) and [chrome](https://www.google.com/chrome/) for selenium (necessary for tests)
 
 * Run the tests
 
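For contributors setting this up locally, a minimal sketch (not part of the diff) of how the selenium-based tests end up using chrome and chromedriver; it assumes both are installed and that a gradio demo is already running on the default port:

```python
# Hypothetical smoke check, not from this PR: verifies chrome + chromedriver
# are usable by selenium before running the UI tests.
from selenium import webdriver

driver = webdriver.Chrome()           # raises here if chromedriver/chrome are missing
driver.get("http://localhost:7860")   # assumes a gradio demo was launched locally
print(driver.title)
driver.quit()
```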
@@ -1,9 +1,9 @@
 import pkg_resources
 
 from gradio.routes import get_state, set_state
-from gradio.flagging import *
-from gradio.interface import *
-from gradio.mix import *
+from gradio.flagging import FlaggingCallback, SimpleCSVLogger, CSVLogger, HuggingFaceDatasetSaver
+from gradio.interface import Interface, close_all, reset_all
+from gradio.mix import Parallel, Series
 
 current_pkg_version = pkg_resources.require("gradio")[0].version
 __version__ = current_pkg_version
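The switch from wildcard to explicit imports should leave the public top-level API unchanged; a rough sketch (not from the diff) of the names that stay importable from the package root:

```python
# Sketch of the top-level names re-exported by gradio/__init__.py above.
import gradio as gr

demo = gr.Interface(fn=lambda x: x, inputs="text", outputs="text")
chained = gr.Series(demo, demo)   # gr.Parallel is exported the same way
logger = gr.CSVLogger()           # flagging callbacks remain top-level
gr.close_all()
print(gr.__version__)
```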
@@ -94,6 +94,12 @@ class Interface:
         pipeline (transformers.Pipeline):
         Returns:
         (gradio.Interface): a Gradio Interface object from the given Pipeline
+
+        Example usage:
+            import gradio as gr
+            from transformers import pipeline
+            pipe = pipeline(model="lysandre/tiny-vit-random")
+            gr.Interface.from_pipeline(pipe).launch()
         """
         interface_info = load_from_pipeline(pipeline)
         kwargs = dict(interface_info, **kwargs)
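Expanding slightly on the docstring example above, a hedged sketch of constructing (without launching) an Interface from the same tiny pipeline; downloading the checkpoint still requires network access:

```python
import gradio as gr
from transformers import pipeline

# Tiny randomly-initialised model, as in the docstring example above.
pipe = pipeline(model="lysandre/tiny-vit-random")

# from_pipeline() maps the pipeline's inputs/outputs onto gradio components.
io = gr.Interface.from_pipeline(pipe)
print(type(io))  # expected: a gradio Interface instance
# io.launch() would start the demo, as shown in the docstring.
```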
@@ -3,7 +3,7 @@ if [ -z "$(ls | grep CONTRIBUTING.md)" ]; then
   echo "Please run the script from repo directory"
   exit -1
 else
-  echo "Creating requirements under test/requirements.txt using requirements.in"
+  echo "Creating requirements under test/requirements.txt using requirements.in. Please run this script from unix or wsl!"
   cd test
   pip install --upgrade pip-tools
   pip-compile
@@ -4,5 +4,5 @@ if [ -z "$(ls | grep CONTRIBUTING.md)" ]; then
   exit -1
 else
   echo "Running the tests"
-  python -m unittest
+  python -m pytest --cov=gradio --durations=20 --durations-min=1 test
 fi
@@ -4,5 +4,5 @@ if [ -z "$(ls | grep CONTRIBUTING.md)" ]; then
   exit -1
 else
   echo "Running the tests"
-  python -m pytest --durations=0 test
+  python -m pytest --cov=gradio --durations=20 --durations-min=0.1 test/local
 fi
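As a rough programmatic equivalent of the pytest invocation above (a sketch, assuming pytest and pytest-cov from test/requirements.in are installed):

```python
# Sketch: the same run driven from Python instead of the shell script.
import sys
import pytest

exit_code = pytest.main([
    "--cov=gradio",         # coverage via pytest-cov
    "--durations=20",       # report the 20 slowest tests...
    "--durations-min=0.1",  # ...but only those taking at least 0.1s
    "test/local",
])
sys.exit(exit_code)
```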
@@ -1,3 +1,4 @@
+# Don't forget to run bash scripts/create_test_requirements.sh from unix or wsl when you update this file.
 IPython
 comet_ml
 coverage
@@ -10,3 +11,8 @@ shap
 pytest
 wandb
 huggingface_hub
+pytest-cov
+black
+isort
+flake8
+torch
@@ -8,10 +8,8 @@ absl-py==1.0.0
     # via
     #   tensorboard
     #   tensorflow
-alembic==1.7.5
+alembic==1.7.6
     # via mlflow
-appnope==0.1.2
-    # via ipython
 asttokens==2.0.5
     # via stack-data
 astunparse==1.6.3
@@ -22,8 +20,10 @@ attrs==21.4.0
     #   pytest
 backcall==0.2.0
     # via ipython
-black~=22.1
-    # via ipython
+black==22.1.0
+    # via
+    #   -r requirements.in
+    #   ipython
 cachetools==5.0.0
     # via google-auth
 certifi==2021.10.8
@@ -34,7 +34,7 @@ certifi==2021.10.8
     #   urllib3
 cffi==1.15.0
     # via cryptography
-charset-normalizer==2.0.10
+charset-normalizer==2.0.11
     # via requests
 click==8.0.3
     # via
@@ -48,19 +48,19 @@ cloudpickle==2.0.0
     # via
     #   mlflow
     #   shap
-comet-ml==3.24.2
+comet-ml==3.25.0
     # via -r requirements.in
 configobj==5.0.6
     # via everett
-configparser==5.2.0
-    # via wandb
-coverage==6.2
-    # via -r requirements.in
+coverage[toml]==6.3.1
+    # via
+    #   -r requirements.in
+    #   pytest-cov
 cryptography==36.0.1
     # via
     #   pyopenssl
     #   urllib3
-databricks-cli==0.16.2
+databricks-cli==0.16.4
     # via mlflow
 decorator==5.1.1
     # via ipython
@@ -68,9 +68,9 @@ docker==5.0.3
     # via mlflow
 docker-pycreds==0.4.0
     # via wandb
-dulwich==0.20.30
+dulwich==0.20.32
     # via comet-ml
-entrypoints==0.3
+entrypoints==0.4
     # via mlflow
 everett[ini]==3.0.0
     # via comet-ml
@@ -80,13 +80,15 @@ filelock==3.4.2
     # via
     #   huggingface-hub
     #   transformers
+flake8==4.0.1
+    # via -r requirements.in
 flask==2.0.2
     # via
     #   mlflow
     #   prometheus-flask-exporter
 flatbuffers==2.0
     # via tensorflow
-gast==0.4.0
+gast==0.5.3
     # via tensorflow
 gitdb==4.0.9
     # via gitpython
@@ -94,7 +96,7 @@ gitpython==3.1.26
     # via
     #   mlflow
     #   wandb
-google-auth==2.4.0
+google-auth==2.6.0
     # via
     #   google-auth-oauthlib
     #   tensorboard
@@ -113,7 +115,9 @@ gunicorn==20.1.0
 h5py==3.6.0
     # via tensorflow
 huggingface-hub==0.4.0
-    # via transformers
+    # via
+    #   -r requirements.in
+    #   transformers
 idna==3.3
     # via
     #   requests
@@ -128,6 +132,8 @@ iniconfig==1.1.1
     # via pytest
 ipython==8.0.1
     # via -r requirements.in
+isort==5.10.1
+    # via -r requirements.in
 itsdangerous==2.0.1
     # via flask
 jedi==0.18.1
@@ -140,11 +146,11 @@ joblib==1.1.0
     #   scikit-learn
 jsonschema==4.4.0
     # via comet-ml
-keras==2.7.0
+keras==2.8.0
     # via tensorflow
 keras-preprocessing==1.1.2
     # via tensorflow
-libclang==12.0.0
+libclang==13.0.0
     # via tensorflow
 llvmlite==0.38.0
     # via numba
@@ -158,13 +164,15 @@ markupsafe==2.0.1
     #   mako
 matplotlib-inline==0.1.3
     # via ipython
-mlflow==1.23.0
+mccabe==0.6.1
+    # via flake8
+mlflow==1.23.1
     # via -r requirements.in
 mypy-extensions==0.4.3
     # via black
 networkx==2.6.3
     # via scikit-image
-numba==0.55.0
+numba==0.55.1
     # via shap
 numpy==1.21.5
     # via
@@ -186,7 +194,7 @@ numpy==1.21.5
     #   transformers
 nvidia-ml-py3==7.352.0
     # via comet-ml
-oauthlib==3.1.1
+oauthlib==3.2.0
     # via requests-oauthlib
 opt-einsum==3.3.0
     # via tensorflow
@@ -198,7 +206,7 @@ packaging==21.3
     #   scikit-image
     #   shap
     #   transformers
-pandas==1.3.5
+pandas==1.4.0
     # via
     #   mlflow
     #   shap
@@ -212,7 +220,7 @@ pexpect==4.8.0
     # via ipython
 pickleshare==0.7.5
     # via ipython
-pillow==9.0.0
+pillow==9.0.1
     # via
     #   imageio
     #   scikit-image
@@ -220,15 +228,15 @@ platformdirs==2.4.1
     # via black
 pluggy==1.0.0
     # via pytest
-prometheus-client==0.12.0
+prometheus-client==0.13.1
     # via prometheus-flask-exporter
 prometheus-flask-exporter==0.18.7
     # via mlflow
 promise==2.3
     # via wandb
-prompt-toolkit==3.0.24
+prompt-toolkit==3.0.26
     # via ipython
-protobuf==3.19.3
+protobuf==3.19.4
     # via
     #   mlflow
     #   tensorboard
@@ -248,17 +256,25 @@ pyasn1==0.4.8
     #   rsa
 pyasn1-modules==0.2.8
     # via google-auth
+pycodestyle==2.8.0
+    # via flake8
 pycparser==2.21
     # via cffi
+pyflakes==2.4.0
+    # via flake8
 pygments==2.11.2
     # via ipython
-pyopenssl==21.0.0
+pyopenssl==22.0.0
     # via urllib3
 pyparsing==3.0.7
     # via packaging
 pyrsistent==0.18.1
     # via jsonschema
-pytest==6.2.5
+pytest==7.0.0
+    # via
+    #   -r requirements.in
+    #   pytest-cov
+pytest-cov==3.0.0
     # via -r requirements.in
 python-dateutil==2.8.2
     # via
@@ -294,7 +310,7 @@ requests==2.27.1
     #   tensorboard
     #   transformers
     #   wandb
-requests-oauthlib==1.3.0
+requests-oauthlib==1.3.1
     # via google-auth-oauthlib
 requests-toolbelt==0.9.1
     # via comet-ml
@@ -306,7 +322,7 @@ scikit-image==0.19.1
     # via -r requirements.in
 scikit-learn==1.0.2
     # via shap
-scipy==1.7.3
+scipy==1.8.0
     # via
     #   mlflow
     #   scikit-image
@@ -314,9 +330,9 @@ scipy==1.7.3
     #   shap
 selenium==4.0.0a6.post2
     # via -r requirements.in
-semantic-version==2.8.5
+semantic-version==2.9.0
     # via comet-ml
-sentry-sdk==1.5.3
+sentry-sdk==1.5.4
     # via wandb
 shap==0.40.0
     # via -r requirements.in
@@ -336,7 +352,6 @@ six==1.16.0
     #   grpcio
     #   keras-preprocessing
     #   promise
-    #   pyopenssl
     #   python-dateutil
     #   querystring-parser
     #   sacremoses
@@ -354,8 +369,6 @@ sqlparse==0.4.2
     # via mlflow
 stack-data==0.1.4
     # via ipython
-subprocess32==3.5.4
-    # via wandb
 tabulate==0.8.9
     # via databricks-cli
 tensorboard==2.8.0
@@ -364,26 +377,29 @@ tensorboard-data-server==0.6.1
     # via tensorboard
 tensorboard-plugin-wit==1.8.1
     # via tensorboard
-tensorflow==2.7.0
+tensorflow==2.8.0
     # via -r requirements.in
-tensorflow-estimator==2.7.0
-    # via tensorflow
-tensorflow-io-gcs-filesystem==0.23.1
+tensorflow-io-gcs-filesystem==0.24.0
     # via tensorflow
 termcolor==1.1.0
     # via
     #   tensorflow
     #   yaspin
-threadpoolctl==3.0.0
+tf-estimator-nightly==2.8.0.dev2021122109
+    # via tensorflow
+threadpoolctl==3.1.0
     # via scikit-learn
-tifffile==2021.11.2
+tifffile==2022.2.2
     # via scikit-image
-tokenizers==0.10.3
+tokenizers==0.11.4
     # via transformers
-toml==0.10.2
-    # via pytest
-tomli==1.2.3
-    # via black
+tomli==2.0.0
+    # via
+    #   black
+    #   coverage
+    #   pytest
+torch==1.10.2
+    # via -r requirements.in
 tqdm==4.62.3
     # via
     #   huggingface-hub
@@ -394,20 +410,21 @@ traitlets==5.1.1
     # via
     #   ipython
     #   matplotlib-inline
-transformers==4.15.0
+transformers==4.16.2
     # via -r requirements.in
-typing-extensions==3.10.0.2
+typing-extensions==4.0.1
     # via
     #   black
     #   huggingface-hub
     #   tensorflow
+    #   torch
 urllib3[secure]==1.26.8
     # via
     #   dulwich
     #   requests
     #   selenium
     #   sentry-sdk
-wandb==0.12.9
+wandb==0.12.10
     # via -r requirements.in
 wcwidth==0.2.5
     # via prompt-toolkit
@@ -423,7 +440,6 @@ wheel==0.37.1
     # via
     #   astunparse
     #   tensorboard
     #   tensorflow
 wrapt==1.13.3
     # via
     #   comet-ml
@@ -30,13 +30,11 @@ class TestHuggingFaceModelAPI(unittest.TestCase):
     def test_question_answering(self):
         model_type = "question-answering"
         interface_info = gr.external.get_huggingface_interface(
-            "deepset/roberta-base-squad2", api_key=None, alias=model_type
+            "lysandre/tiny-vit-random", api_key=None, alias=model_type
         )
         self.assertEqual(interface_info["fn"].__name__, model_type)
-        self.assertIsInstance(interface_info["inputs"][0], gr.inputs.Textbox)
-        self.assertIsInstance(interface_info["inputs"][1], gr.inputs.Textbox)
-        self.assertIsInstance(interface_info["outputs"][0], gr.outputs.Textbox)
-        self.assertIsInstance(interface_info["outputs"][1], gr.outputs.Label)
+        self.assertIsInstance(interface_info["inputs"], gr.inputs.Image)
+        self.assertIsInstance(interface_info["outputs"], gr.outputs.Label)
 
     def test_text_generation(self):
         model_type = "text_generation"
@@ -246,11 +244,9 @@ class TestLoadInterface(unittest.TestCase):
 
 class TestLoadFromPipeline(unittest.TestCase):
     def test_question_answering(self):
-        p = transformers.pipeline("question-answering")
-        io = gr.Interface.from_pipeline(p)
-        output = io(
-            "My name is Sylvain and I work at Hugging Face in Brooklyn",
-            "Where do I work?",
+        pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
+        output = pipe(
+            "My name is Sylvain and I work at Hugging Face in Brooklyn"
         )
         self.assertIsNotNone(output)
 
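The pattern behind these test changes: swap full-size hosted models for tiny randomly-initialised checkpoints so the loading path is still exercised without large downloads. A hedged sketch (not verbatim from the diff) of a test written that way:

```python
import unittest

import transformers
import gradio as gr


class TestTinyPipeline(unittest.TestCase):
    def test_from_pipeline_with_tiny_model(self):
        # Tiny random checkpoint keeps download and inference fast.
        pipe = transformers.pipeline(model="sshleifer/bart-tiny-random")
        io = gr.Interface.from_pipeline(pipe)
        self.assertIsNotNone(io)
        output = pipe("My name is Sylvain and I work at Hugging Face in Brooklyn")
        self.assertIsNotNone(output)


if __name__ == "__main__":
    unittest.main()
```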