mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Merge pull request #493 from gradio-app/backend-tests
Improve backend tests
This commit is contained in:
commit
9c478e2d50
@ -18,7 +18,7 @@ jobs:
|
||||
. venv/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install -r gradio.egg-info/requires.txt
|
||||
pip install shap IPython comet_ml wandb mlflow tensorflow transformers
|
||||
pip install shap IPython comet_ml wandb mlflow tensorflow transformers huggingface_hub
|
||||
pip install selenium==4.0.0a6.post2 coverage scikit-image
|
||||
- run:
|
||||
command: |
|
||||
|
@ -1,8 +1,8 @@
|
||||
import pkg_resources
|
||||
|
||||
from gradio.app import get_state, set_state
|
||||
from gradio.routes import get_state, set_state
|
||||
from gradio.flagging import *
|
||||
from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`.
|
||||
from gradio.interface import *
|
||||
from gradio.mix import *
|
||||
|
||||
current_pkg_version = pkg_resources.require("gradio")[0].version
|
||||
|
@ -6,6 +6,7 @@ automatically added to a registry, which allows them to be easily referenced in
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import tempfile
|
||||
@ -1182,8 +1183,7 @@ class Audio(InputComponent):
|
||||
return self
|
||||
|
||||
def tokenize(self, x):
|
||||
file_obj = processing_utils.decode_base64_to_file(x)
|
||||
sample_rate, data = processing_utils.audio_from_file(x)
|
||||
sample_rate, data = processing_utils.audio_from_file(x["name"])
|
||||
leave_one_out_sets = []
|
||||
tokens = []
|
||||
masks = []
|
||||
@ -1193,20 +1193,27 @@ class Audio(InputComponent):
|
||||
for index in range(len(boundaries) - 1):
|
||||
start, stop = boundaries[index], boundaries[index + 1]
|
||||
masks.append((start, stop))
|
||||
|
||||
# Handle the leave one outs
|
||||
leave_one_out_data = np.copy(data)
|
||||
leave_one_out_data[start:stop] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name)
|
||||
out_data = processing_utils.encode_file_to_base64(file.name)
|
||||
leave_one_out_sets.append(out_data)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
|
||||
# Handle the tokens
|
||||
token = np.copy(data)
|
||||
token[0:start] = 0
|
||||
token[stop:] = 0
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
|
||||
processing_utils.audio_to_file(sample_rate, token, file.name)
|
||||
token_data = processing_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
|
||||
tokens.append(token_data)
|
||||
return tokens, leave_one_out_sets, masks
|
||||
|
||||
@ -1215,7 +1222,7 @@ class Audio(InputComponent):
|
||||
x = tokens[0]
|
||||
file_obj = processing_utils.decode_base64_to_file(x)
|
||||
sample_rate, data = processing_utils.audio_from_file(file_obj.name)
|
||||
zero_input = np.zeros_like(data, dtype=int)
|
||||
zero_input = np.zeros_like(data, dtype='int16')
|
||||
# decode all of the tokens
|
||||
token_data = []
|
||||
for token in tokens:
|
||||
@ -1229,8 +1236,10 @@ class Audio(InputComponent):
|
||||
for t, b in zip(token_data, binary_mask_vector):
|
||||
masked_input = masked_input + t * int(b)
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
processing_utils.audio_to_file(sample_rate, masked_input, file_obj.name)
|
||||
processing_utils.audio_to_file(sample_rate, masked_input, file.name)
|
||||
masked_data = processing_utils.encode_file_to_base64(file.name)
|
||||
file.close()
|
||||
os.unlink(file.name)
|
||||
masked_inputs.append(masked_data)
|
||||
return masked_inputs
|
||||
|
||||
|
@ -17,7 +17,7 @@ import webbrowser
|
||||
from logging import warning
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
import markdown2 # type: ignore
|
||||
import markdown2
|
||||
|
||||
from gradio import (encryptor, interpretation, networking, # type: ignore
|
||||
queueing, strings, utils)
|
||||
|
@ -19,7 +19,7 @@ import requests
|
||||
import uvicorn
|
||||
|
||||
from gradio import queueing
|
||||
from gradio.app import app
|
||||
from gradio.routes import app
|
||||
from gradio.tunneling import create_tunnel
|
||||
|
||||
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
|
||||
@ -73,29 +73,6 @@ def get_first_available_port(initial: int, final: int) -> int:
|
||||
)
|
||||
|
||||
|
||||
def queue_thread(path_to_local_server, test_mode=False):
|
||||
while True:
|
||||
try:
|
||||
next_job = queueing.pop()
|
||||
if next_job is not None:
|
||||
_, hash, input_data, task_type = next_job
|
||||
queueing.start_job(hash)
|
||||
response = requests.post(
|
||||
path_to_local_server + "api/" + task_type + "/", json=input_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
queueing.pass_job(hash, response.json())
|
||||
else:
|
||||
queueing.fail_job(hash, response.text)
|
||||
else:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
time.sleep(1)
|
||||
pass
|
||||
if test_mode:
|
||||
break
|
||||
|
||||
|
||||
def start_server(
|
||||
interface: Interface,
|
||||
server_name: Optional[str] = None,
|
||||
@ -147,7 +124,7 @@ def start_server(
|
||||
raise ValueError("Cannot queue with encryption or authentication enabled.")
|
||||
queueing.init()
|
||||
app.queue_thread = threading.Thread(
|
||||
target=queue_thread, args=(path_to_local_server,)
|
||||
target=queueing.queue_thread, args=(path_to_local_server,)
|
||||
)
|
||||
app.queue_thread.start()
|
||||
if interface.save_to is not None: # Used for selenium tests
|
||||
|
@ -25,14 +25,6 @@ def decode_base64_to_image(encoding):
|
||||
return Image.open(BytesIO(base64.b64decode(image_encoded)))
|
||||
|
||||
|
||||
def get_url_or_file_as_bytes(path):
|
||||
try:
|
||||
return requests.get(path).content
|
||||
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def encode_url_or_file_to_base64(path):
|
||||
try:
|
||||
requests.get(path)
|
||||
@ -144,7 +136,7 @@ def audio_to_file(sample_rate, data, filename):
|
||||
sample_width=data.dtype.itemsize,
|
||||
channels=(1 if len(data.shape) == 1 else data.shape[1]),
|
||||
)
|
||||
audio.export(filename, format="wav")
|
||||
audio.export(filename, format="wav").close()
|
||||
|
||||
|
||||
##################
|
||||
|
@ -1,12 +1,39 @@
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Dict, Tuple
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
DB_FILE = "gradio_queue.db"
|
||||
|
||||
|
||||
def generate_hash():
|
||||
def queue_thread(
|
||||
path_to_local_server: str
|
||||
) -> None:
|
||||
while True:
|
||||
try:
|
||||
next_job = pop()
|
||||
if next_job is not None:
|
||||
_, hash, input_data, task_type = next_job
|
||||
start_job(hash)
|
||||
response = requests.post(
|
||||
path_to_local_server + "api/" + task_type + "/", json=input_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
pass_job(hash, response.json())
|
||||
else:
|
||||
fail_job(hash, response.text)
|
||||
else:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
time.sleep(1)
|
||||
pass
|
||||
|
||||
|
||||
def generate_hash() -> str:
|
||||
generate = True
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
@ -24,7 +51,7 @@ def generate_hash():
|
||||
return hash
|
||||
|
||||
|
||||
def init():
|
||||
def init() -> None:
|
||||
if os.path.exists(DB_FILE):
|
||||
os.remove(DB_FILE)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
@ -51,12 +78,12 @@ def init():
|
||||
conn.commit()
|
||||
|
||||
|
||||
def close():
|
||||
def close() -> None:
|
||||
if os.path.exists(DB_FILE):
|
||||
os.remove(DB_FILE)
|
||||
|
||||
|
||||
def pop():
|
||||
def pop() -> Tuple[int, str, Dict, str]:
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
@ -81,7 +108,10 @@ def pop():
|
||||
return result[0], result[1], json.loads(result[2]), result[3]
|
||||
|
||||
|
||||
def push(input_data, action):
|
||||
def push(
|
||||
input_data: Dict,
|
||||
action: str
|
||||
) -> Tuple[str, int]:
|
||||
input_data = json.dumps(input_data)
|
||||
hash = generate_hash()
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
@ -104,20 +134,19 @@ def push(input_data, action):
|
||||
if queue_position is None:
|
||||
conn.commit()
|
||||
raise ValueError("Hash not found.")
|
||||
elif queue_position == 0:
|
||||
c.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
|
||||
c.execute(
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if result[0] == 0:
|
||||
queue_position -= 1
|
||||
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if not(result[0] == 0):
|
||||
queue_position += 1
|
||||
conn.commit()
|
||||
return hash, queue_position
|
||||
|
||||
|
||||
def get_status(hash):
|
||||
def get_status(hash: str) -> Tuple[str, int]:
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
@ -169,20 +198,19 @@ def get_status(hash):
|
||||
)
|
||||
result = c.fetchone()
|
||||
queue_position = result[0]
|
||||
if queue_position == 0:
|
||||
c.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
|
||||
c.execute(
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if result[0] == 0:
|
||||
queue_position -= 1
|
||||
SELECT COUNT(*) FROM jobs WHERE status = "PENDING";
|
||||
"""
|
||||
)
|
||||
result = c.fetchone()
|
||||
if not(result[0] == 0):
|
||||
queue_position += 1
|
||||
conn.commit()
|
||||
return "QUEUED", queue_position
|
||||
|
||||
|
||||
def start_job(hash):
|
||||
def start_job(hash: str) -> None:
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute("BEGIN EXCLUSIVE")
|
||||
@ -201,7 +229,10 @@ def start_job(hash):
|
||||
conn.commit()
|
||||
|
||||
|
||||
def fail_job(hash, error_message):
|
||||
def fail_job(
|
||||
hash: str,
|
||||
error_message: str
|
||||
) -> None:
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
@ -216,7 +247,10 @@ def fail_job(hash, error_message):
|
||||
conn.commit()
|
||||
|
||||
|
||||
def pass_job(hash, output_data):
|
||||
def pass_job(
|
||||
hash: str,
|
||||
output_data: Dict
|
||||
) -> None:
|
||||
output_data = json.dumps(output_data)
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
c = conn.cursor()
|
||||
|
@ -129,15 +129,17 @@ def static_resource(path: str):
|
||||
|
||||
@app.get("/file/{path:path}", dependencies=[Depends(login_check)])
|
||||
def file(path):
|
||||
if app.interface.encrypt and isinstance(
|
||||
app.interface.examples, str) and path.startswith(
|
||||
app.interface.examples):
|
||||
if (
|
||||
app.interface.encrypt
|
||||
and isinstance(app.interface.examples, str)
|
||||
and path.startswith(app.interface.examples)
|
||||
):
|
||||
with open(safe_join(app.cwd, path), "rb") as encrypted_file:
|
||||
encrypted_data = encrypted_file.read()
|
||||
file_data = encryptor.decrypt(
|
||||
app.interface.encryption_key, encrypted_data)
|
||||
file_data = encryptor.decrypt(app.interface.encryption_key, encrypted_data)
|
||||
return FileResponse(
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path))
|
||||
io.BytesIO(file_data), attachment_filename=os.path.basename(path)
|
||||
)
|
||||
else:
|
||||
return FileResponse(safe_join(app.cwd, path))
|
||||
|
@ -1,12 +1,16 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
import gradio as gr
|
||||
from gradio import flagging
|
||||
|
||||
|
||||
class TestDefaultFlagging(unittest.TestCase):
|
||||
def test_default_flagging_handler(self):
|
||||
def test_default_flagging_callback(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
@ -18,7 +22,7 @@ class TestDefaultFlagging(unittest.TestCase):
|
||||
|
||||
|
||||
class TestSimpleFlagging(unittest.TestCase):
|
||||
def test_simple_csv_flagging_handler(self):
|
||||
def test_simple_csv_flagging_callback(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
@ -29,11 +33,39 @@ class TestSimpleFlagging(unittest.TestCase):
|
||||
)
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 0) # no header
|
||||
self.assertEqual(row_count, 0) # no header in SimpleCSVLogger
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 1) # no header
|
||||
self.assertEqual(row_count, 1) # no header in SimpleCSVLogger
|
||||
io.close()
|
||||
|
||||
|
||||
class TestHuggingFaceDatasetSaver(unittest.TestCase):
|
||||
def test_saver_setup(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
flagger = flagging.HuggingFaceDatasetSaver("test", "test")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
flagger.setup(tmpdirname)
|
||||
huggingface_hub.create_repo.assert_called_once()
|
||||
|
||||
def test_saver_flag(self):
|
||||
huggingface_hub.create_repo = MagicMock()
|
||||
huggingface_hub.Repository = MagicMock()
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
io = gr.Interface(
|
||||
lambda x: x,
|
||||
"text",
|
||||
"text",
|
||||
flagging_dir=tmpdirname,
|
||||
flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"),
|
||||
)
|
||||
os.mkdir(os.path.join(tmpdirname, "test"))
|
||||
io.launch(prevent_thread_lock=True)
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 1) # 2 rows written including header
|
||||
row_count = io.flagging_callback.flag(io, ["test"], ["test"])
|
||||
self.assertEqual(row_count, 2) # 3 rows written including header
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,3 +1,4 @@
|
||||
from difflib import SequenceMatcher
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@ -565,21 +566,14 @@ class TestAudio(unittest.TestCase):
|
||||
x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav")
|
||||
self.assertIsInstance(audio_input.serialize(x_wav, False), dict)
|
||||
|
||||
# def test_in_interface(self):
|
||||
# x_wav = gr.test_data.BASE64_AUDIO
|
||||
# def max_amplitude_from_wav_file(wav_file):
|
||||
# audio_segment = AudioSegment.from_file(wav_file.name)
|
||||
# data = np.array(audio_segment.get_array_of_samples())
|
||||
# return np.max(data)
|
||||
# iface = gr.Interface(
|
||||
# max_amplitude_from_wav_file,
|
||||
# gr.inputs.Audio(type="file"),
|
||||
# "number", interpretation="default")
|
||||
# # TODO(aliabd): investigate why this sometimes fails (returns 5239 or 576)
|
||||
# self.assertEqual(iface.process([x_wav])[0], [576])
|
||||
# scores, alternative_outputs = iface.interpret([x_wav])
|
||||
# self.assertEqual(scores, ... )
|
||||
# self.assertEqual(alternative_outputs, ...)
|
||||
def test_tokenize(self):
|
||||
x_wav = gr.test_data.BASE64_AUDIO
|
||||
audio_input = gr.inputs.Audio()
|
||||
tokens, _, _ = audio_input.tokenize(x_wav)
|
||||
self.assertEquals(len(tokens), audio_input.interpretation_segments)
|
||||
x_new = audio_input.get_masked_inputs(tokens, [[1]*len(tokens)])[0]
|
||||
similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio()
|
||||
self.assertGreater(similarity, 0.9)
|
||||
|
||||
|
||||
class TestFile(unittest.TestCase):
|
||||
|
@ -9,7 +9,7 @@ import warnings
|
||||
import aiohttp
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from gradio import Interface, flagging, networking, reset_all, utils
|
||||
from gradio import Interface, flagging, networking, reset_all, queueing
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
@ -35,77 +35,6 @@ class TestPort(unittest.TestCase):
|
||||
warnings.warn("Unable to test, no ports available")
|
||||
|
||||
|
||||
class TestRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_get_main_route(self):
|
||||
response = self.client.get("/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_get_api_route(self):
|
||||
response = self.client.get("/api/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_static_files_served_safely(self):
|
||||
# Make sure things outside the static folder are not accessible
|
||||
response = self.client.get(r"/static/..%2findex.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_get_config_route(self):
|
||||
response = self.client.get("/config/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_predict_route(self):
|
||||
response = self.client.post("/api/predict/", json={"data": ["test"]})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["test"])
|
||||
self.assertTrue("durations" in output)
|
||||
self.assertTrue("avg_durations" in output)
|
||||
|
||||
# def test_queue_push_route(self):
|
||||
# networking.queue.push = mock.MagicMock(return_value=(None, None))
|
||||
# response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"})
|
||||
# self.assertEqual(response.status_code, 200)
|
||||
|
||||
# def test_queue_push_route(self):
|
||||
# networking.queue.get_status = mock.MagicMock(return_value=(None, None))
|
||||
# response = self.client.post('/api/queue/status/', json={"hash": "test"})
|
||||
# self.assertEqual(response.status_code, 200)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(
|
||||
auth=("test", "correct_password"), prevent_thread_lock=True
|
||||
)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_post_login(self):
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="correct_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="incorrect_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestInterfaceCustomParameters(unittest.TestCase):
|
||||
def test_show_error(self):
|
||||
io = Interface(lambda x: 1 / x, "number", "number")
|
||||
@ -180,26 +109,5 @@ class TestURLs(unittest.TestCase):
|
||||
self.assertTrue(res)
|
||||
|
||||
|
||||
# class TestQueuing(unittest.TestCase):
|
||||
# def test_queueing(self):
|
||||
# # mock queue methods and post method
|
||||
# networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict'))
|
||||
# networking.queue.pass_job = mock.MagicMock(return_value=(None, None))
|
||||
# networking.queue.fail_job = mock.MagicMock(return_value=(None, None))
|
||||
# networking.queue.start_job = mock.MagicMock(return_value=None)
|
||||
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200))
|
||||
# # execute queue action successfully
|
||||
# networking.queue_thread('test_path', test_mode=True)
|
||||
# networking.queue.pass_job.assert_called_once()
|
||||
# # execute queue action unsuccessfully
|
||||
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500))
|
||||
# networking.queue_thread('test_path', test_mode=True)
|
||||
# networking.queue.fail_job.assert_called_once()
|
||||
# # no more things on the queue so methods shouldn't be called any more times
|
||||
# networking.queue.pop = mock.MagicMock(return_value=None)
|
||||
# networking.queue.pass_job.assert_called_once()
|
||||
# networking.queue.fail_job.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
68
test/test_queuing.py
Normal file
68
test/test_queuing.py
Normal file
@ -0,0 +1,68 @@
|
||||
"""Contains tests for networking.py and app.py"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
import requests
|
||||
|
||||
from gradio import Interface, queueing
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestQueuingOpenClose(unittest.TestCase):
|
||||
def test_init(self):
|
||||
queueing.init()
|
||||
self.assertTrue(os.path.exists(queueing.DB_FILE))
|
||||
os.remove(queueing.DB_FILE)
|
||||
|
||||
def test_close(self):
|
||||
queueing.close()
|
||||
self.assertFalse(os.path.exists(queueing.DB_FILE))
|
||||
|
||||
class TestQueuingActions(unittest.TestCase):
|
||||
def setUp(self):
|
||||
queueing.init()
|
||||
|
||||
def test_hashing(self):
|
||||
hash1 = queueing.generate_hash()
|
||||
hash2 = queueing.generate_hash()
|
||||
self.assertNotEquals(hash1, hash2)
|
||||
queueing.close()
|
||||
|
||||
def test_push_pop_status(self):
|
||||
hash1, position = queueing.push({"data": "test1"}, "predict")
|
||||
self.assertEquals(position, 0)
|
||||
hash2, position = queueing.push({"data": "test2"}, "predict")
|
||||
self.assertEquals(position, 1)
|
||||
status, position = queueing.get_status(hash2)
|
||||
self.assertEquals(status, "QUEUED")
|
||||
self.assertEquals(position, 1)
|
||||
_, hash_popped, input_data, action = queueing.pop()
|
||||
self.assertEquals(hash_popped, hash1)
|
||||
self.assertEquals(input_data, {"data": "test1"})
|
||||
self.assertEquals(action, "predict")
|
||||
|
||||
def test_jobs(self):
|
||||
hash1, _ = queueing.push({"data": "test1"}, "predict")
|
||||
hash2, position = queueing.push({"data": "test1"}, "predict")
|
||||
self.assertEquals(position, 1)
|
||||
|
||||
queueing.start_job(hash1)
|
||||
_, position = queueing.get_status(hash2)
|
||||
self.assertEquals(position, 1)
|
||||
queueing.pass_job(hash1, {"data": "result"})
|
||||
_, position = queueing.get_status(hash2)
|
||||
self.assertEquals(position, 0)
|
||||
|
||||
queueing.start_job(hash2)
|
||||
queueing.fail_job(hash2, "failure")
|
||||
status, _ = queueing.get_status(hash2)
|
||||
self.assertEquals(status, "FAILED")
|
||||
|
||||
def tearDown(self):
|
||||
queueing.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
88
test/test_routes.py
Normal file
88
test/test_routes.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""Contains tests for networking.py and app.py"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from gradio import Interface, flagging, networking, queueing, reset_all
|
||||
|
||||
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
||||
|
||||
|
||||
class TestRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(prevent_thread_lock=True)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_get_main_route(self):
|
||||
response = self.client.get("/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_get_api_route(self):
|
||||
response = self.client.get("/api/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_static_files_served_safely(self):
|
||||
# Make sure things outside the static folder are not accessible
|
||||
response = self.client.get(r"/static/..%2findex.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
response = self.client.get(r"/static/..%2f..%2fapi_docs.html")
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
def test_get_config_route(self):
|
||||
response = self.client.get("/config/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_predict_route(self):
|
||||
response = self.client.post("/api/predict/", json={"data": ["test"]})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
output = dict(response.json())
|
||||
self.assertEqual(output["data"], ["test"])
|
||||
self.assertTrue("durations" in output)
|
||||
self.assertTrue("avg_durations" in output)
|
||||
|
||||
def test_queue_push_route(self):
|
||||
queueing.push = mock.MagicMock(return_value=(None, None))
|
||||
response = self.client.post(
|
||||
"/api/queue/push/", json={"data": "test", "action": "test"}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_queue_push_route(self):
|
||||
queueing.get_status = mock.MagicMock(return_value=(None, None))
|
||||
response = self.client.post("/api/queue/status/", json={"hash": "test"})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
class TestAuthenticatedRoutes(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.io = Interface(lambda x: x, "text", "text")
|
||||
self.app, _, _ = self.io.launch(
|
||||
auth=("test", "correct_password"), prevent_thread_lock=True
|
||||
)
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
def test_post_login(self):
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="correct_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
response = self.client.post(
|
||||
"/login", data=dict(username="test", password="incorrect_password")
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.io.close()
|
||||
reset_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user