diff --git a/.circleci/config.yml b/.circleci/config.yml index 5ada1de676..4c34e2d08e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -18,7 +18,7 @@ jobs: . venv/bin/activate pip install --upgrade pip pip install -r gradio.egg-info/requires.txt - pip install shap IPython comet_ml wandb mlflow tensorflow transformers + pip install shap IPython comet_ml wandb mlflow tensorflow transformers huggingface_hub pip install selenium==4.0.0a6.post2 coverage scikit-image - run: command: | diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0f67b26ceb..500ae57043 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,17 +1,12 @@ # Description -Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. +Please include: +* relevant motivation +* a summary of the change +* which issue is fixed. +* any additional dependencies that are required for this change. -Fixes # (issue) - -## Type of change - -Please delete options that are not relevant. - -- [ ] Bug fix (non-breaking change which fixes an issue) -- [ ] New feature (non-breaking change which adds functionality) -- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) -- [ ] This change requires a documentation update +Fixes: # (issue) # Checklist: @@ -19,5 +14,5 @@ Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have commented my code in hard-to-understand areas - [ ] I have made corresponding changes to the documentation -- [ ] New and existing unit tests pass locally with my changes - [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] New and existing unit tests pass locally with my changes diff --git a/gradio/__init__.py b/gradio/__init__.py index 8d6c44fd7b..3c3ea32779 100644 --- a/gradio/__init__.py +++ b/gradio/__init__.py @@ -1,8 +1,8 @@ import pkg_resources -from gradio.app import get_state, set_state +from gradio.routes import get_state, set_state from gradio.flagging import * -from gradio.interface import * # This makes it possible to import `Interface` as `gradio.Interface`. +from gradio.interface import * from gradio.mix import * current_pkg_version = pkg_resources.require("gradio")[0].version diff --git a/gradio/inputs.py b/gradio/inputs.py index b3fc7c2a38..61ae6643fa 100644 --- a/gradio/inputs.py +++ b/gradio/inputs.py @@ -6,6 +6,7 @@ automatically added to a registry, which allows them to be easily referenced in from __future__ import annotations +import os import json import math import tempfile @@ -1182,8 +1183,7 @@ class Audio(InputComponent): return self def tokenize(self, x): - file_obj = processing_utils.decode_base64_to_file(x) - sample_rate, data = processing_utils.audio_from_file(x) + sample_rate, data = processing_utils.audio_from_file(x["name"]) leave_one_out_sets = [] tokens = [] masks = [] @@ -1193,20 +1193,27 @@ class Audio(InputComponent): for index in range(len(boundaries) - 1): start, stop = boundaries[index], boundaries[index + 1] masks.append((start, stop)) + # Handle the leave one outs leave_one_out_data = np.copy(data) leave_one_out_data[start:stop] = 0 - file = tempfile.NamedTemporaryFile(delete=False) + file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") processing_utils.audio_to_file(sample_rate, leave_one_out_data, file.name) out_data = processing_utils.encode_file_to_base64(file.name) leave_one_out_sets.append(out_data) + file.close() + os.unlink(file.name) + # Handle the tokens token = np.copy(data) token[0:start] = 0 token[stop:] = 0 - file = tempfile.NamedTemporaryFile(delete=False) + file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav") processing_utils.audio_to_file(sample_rate, token, file.name) token_data = processing_utils.encode_file_to_base64(file.name) + file.close() + os.unlink(file.name) + tokens.append(token_data) return tokens, leave_one_out_sets, masks @@ -1215,7 +1222,7 @@ class Audio(InputComponent): x = tokens[0] file_obj = processing_utils.decode_base64_to_file(x) sample_rate, data = processing_utils.audio_from_file(file_obj.name) - zero_input = np.zeros_like(data, dtype=int) + zero_input = np.zeros_like(data, dtype='int16') # decode all of the tokens token_data = [] for token in tokens: @@ -1229,8 +1236,10 @@ class Audio(InputComponent): for t, b in zip(token_data, binary_mask_vector): masked_input = masked_input + t * int(b) file = tempfile.NamedTemporaryFile(delete=False) - processing_utils.audio_to_file(sample_rate, masked_input, file_obj.name) + processing_utils.audio_to_file(sample_rate, masked_input, file.name) masked_data = processing_utils.encode_file_to_base64(file.name) + file.close() + os.unlink(file.name) masked_inputs.append(masked_data) return masked_inputs diff --git a/gradio/interface.py b/gradio/interface.py index 2c7e9c78bd..25f1c988c4 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -17,7 +17,7 @@ import webbrowser from logging import warning from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple -import markdown2 # type: ignore +import markdown2 from gradio import (encryptor, interpretation, networking, # type: ignore queueing, strings, utils) diff --git a/gradio/networking.py b/gradio/networking.py index 2beba0ca7b..5991d275b5 100644 --- a/gradio/networking.py +++ b/gradio/networking.py @@ -19,7 +19,7 @@ import requests import uvicorn from gradio import queueing -from gradio.app import app +from gradio.routes import app from gradio.tunneling import create_tunnel if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). @@ -73,29 +73,6 @@ def get_first_available_port(initial: int, final: int) -> int: ) -def queue_thread(path_to_local_server, test_mode=False): - while True: - try: - next_job = queueing.pop() - if next_job is not None: - _, hash, input_data, task_type = next_job - queueing.start_job(hash) - response = requests.post( - path_to_local_server + "api/" + task_type + "/", json=input_data - ) - if response.status_code == 200: - queueing.pass_job(hash, response.json()) - else: - queueing.fail_job(hash, response.text) - else: - time.sleep(1) - except Exception as e: - time.sleep(1) - pass - if test_mode: - break - - def start_server( interface: Interface, server_name: Optional[str] = None, @@ -147,7 +124,7 @@ def start_server( raise ValueError("Cannot queue with encryption or authentication enabled.") queueing.init() app.queue_thread = threading.Thread( - target=queue_thread, args=(path_to_local_server,) + target=queueing.queue_thread, args=(path_to_local_server,) ) app.queue_thread.start() if interface.save_to is not None: # Used for selenium tests diff --git a/gradio/process_examples.py b/gradio/process_examples.py index f12bd42eaf..087855a8f4 100644 --- a/gradio/process_examples.py +++ b/gradio/process_examples.py @@ -1,15 +1,27 @@ +""" +Defines helper methods useful for loading and caching Interface examples. +""" +from __future__ import annotations + import csv import os import shutil -from typing import Any, List +from typing import Any, List, Tuple, TYPE_CHECKING from gradio.flagging import CSVLogger +if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). + from gradio import Interface + CACHED_FOLDER = "gradio_cached_examples" CACHE_FILE = os.path.join(CACHED_FOLDER, "log.csv") -def process_example(interface, example_id: int): +def process_example( + interface: Interface, + example_id: int +) -> Tuple[List[Any], List[float]]: + """Loads an example from the interface and returns its prediction.""" example_set = interface.examples[example_id] raw_input = [ interface.input_components[i].preprocess_example(example) @@ -19,7 +31,10 @@ def process_example(interface, example_id: int): return prediction, durations -def cache_interface_examples(interface) -> None: +def cache_interface_examples( + interface: Interface +) -> None: + """Caches all of the examples from an interface.""" if os.path.exists(CACHE_FILE): print( f"Using cache from '{os.path.abspath(CACHED_FOLDER)}/' directory. If method or examples have changed since last caching, delete this folder to clear cache." @@ -39,7 +54,11 @@ def cache_interface_examples(interface) -> None: raise e -def load_from_cache(interface, example_id: int) -> List[Any]: +def load_from_cache( + interface: Interface, + example_id: int +) -> List[Any]: + """Loads a particular cached example for the interface.""" with open(CACHE_FILE) as cache: examples = list(csv.reader(cache)) example = examples[example_id + 1] # +1 to adjust for header diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index b59cf5f83d..59f55a0756 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -25,14 +25,6 @@ def decode_base64_to_image(encoding): return Image.open(BytesIO(base64.b64decode(image_encoded))) -def get_url_or_file_as_bytes(path): - try: - return requests.get(path).content - except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema): - with open(path, "rb") as f: - return f.read() - - def encode_url_or_file_to_base64(path): try: requests.get(path) @@ -144,7 +136,7 @@ def audio_to_file(sample_rate, data, filename): sample_width=data.dtype.itemsize, channels=(1 if len(data.shape) == 1 else data.shape[1]), ) - audio.export(filename, format="wav") + audio.export(filename, format="wav").close() ################## diff --git a/gradio/queueing.py b/gradio/queueing.py index c196cc3155..257bc6c31c 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -1,12 +1,39 @@ import json import os import sqlite3 +import time +from typing import Dict, Tuple import uuid +import requests + DB_FILE = "gradio_queue.db" -def generate_hash(): +def queue_thread( + path_to_local_server: str +) -> None: + while True: + try: + next_job = pop() + if next_job is not None: + _, hash, input_data, task_type = next_job + start_job(hash) + response = requests.post( + path_to_local_server + "api/" + task_type + "/", json=input_data + ) + if response.status_code == 200: + pass_job(hash, response.json()) + else: + fail_job(hash, response.text) + else: + time.sleep(1) + except Exception as e: + time.sleep(1) + pass + + +def generate_hash() -> str: generate = True conn = sqlite3.connect(DB_FILE) c = conn.cursor() @@ -24,7 +51,7 @@ def generate_hash(): return hash -def init(): +def init() -> None: if os.path.exists(DB_FILE): os.remove(DB_FILE) conn = sqlite3.connect(DB_FILE) @@ -51,12 +78,12 @@ def init(): conn.commit() -def close(): +def close() -> None: if os.path.exists(DB_FILE): os.remove(DB_FILE) -def pop(): +def pop() -> Tuple[int, str, Dict, str]: conn = sqlite3.connect(DB_FILE) c = conn.cursor() c.execute("BEGIN EXCLUSIVE") @@ -81,7 +108,10 @@ def pop(): return result[0], result[1], json.loads(result[2]), result[3] -def push(input_data, action): +def push( + input_data: Dict, + action: str +) -> Tuple[str, int]: input_data = json.dumps(input_data) hash = generate_hash() conn = sqlite3.connect(DB_FILE) @@ -104,20 +134,19 @@ def push(input_data, action): if queue_position is None: conn.commit() raise ValueError("Hash not found.") - elif queue_position == 0: - c.execute( - """ - SELECT COUNT(*) FROM jobs WHERE status = "PENDING"; + c.execute( """ - ) - result = c.fetchone() - if result[0] == 0: - queue_position -= 1 + SELECT COUNT(*) FROM jobs WHERE status = "PENDING"; + """ + ) + result = c.fetchone() + if not(result[0] == 0): + queue_position += 1 conn.commit() return hash, queue_position -def get_status(hash): +def get_status(hash: str) -> Tuple[str, int]: conn = sqlite3.connect(DB_FILE) c = conn.cursor() c.execute( @@ -169,20 +198,19 @@ def get_status(hash): ) result = c.fetchone() queue_position = result[0] - if queue_position == 0: - c.execute( - """ - SELECT COUNT(*) FROM jobs WHERE status = "PENDING"; + c.execute( """ - ) - result = c.fetchone() - if result[0] == 0: - queue_position -= 1 + SELECT COUNT(*) FROM jobs WHERE status = "PENDING"; + """ + ) + result = c.fetchone() + if not(result[0] == 0): + queue_position += 1 conn.commit() return "QUEUED", queue_position -def start_job(hash): +def start_job(hash: str) -> None: conn = sqlite3.connect(DB_FILE) c = conn.cursor() c.execute("BEGIN EXCLUSIVE") @@ -201,7 +229,10 @@ def start_job(hash): conn.commit() -def fail_job(hash, error_message): +def fail_job( + hash: str, + error_message: str +) -> None: conn = sqlite3.connect(DB_FILE) c = conn.cursor() c.execute( @@ -216,7 +247,10 @@ def fail_job(hash, error_message): conn.commit() -def pass_job(hash, output_data): +def pass_job( + hash: str, + output_data: Dict +) -> None: output_data = json.dumps(output_data) conn = sqlite3.connect(DB_FILE) c = conn.cursor() diff --git a/gradio/app.py b/gradio/routes.py similarity index 97% rename from gradio/app.py rename to gradio/routes.py index b1b793569f..1266c23377 100644 --- a/gradio/app.py +++ b/gradio/routes.py @@ -129,15 +129,17 @@ def static_resource(path: str): @app.get("/file/{path:path}", dependencies=[Depends(login_check)]) def file(path): - if app.interface.encrypt and isinstance( - app.interface.examples, str) and path.startswith( - app.interface.examples): + if ( + app.interface.encrypt + and isinstance(app.interface.examples, str) + and path.startswith(app.interface.examples) + ): with open(safe_join(app.cwd, path), "rb") as encrypted_file: encrypted_data = encrypted_file.read() - file_data = encryptor.decrypt( - app.interface.encryption_key, encrypted_data) + file_data = encryptor.decrypt(app.interface.encryption_key, encrypted_data) return FileResponse( - io.BytesIO(file_data), attachment_filename=os.path.basename(path)) + io.BytesIO(file_data), attachment_filename=os.path.basename(path) + ) else: return FileResponse(safe_join(app.cwd, path)) diff --git a/test/test_flagging.py b/test/test_flagging.py index 1533949c0f..b3636ff0fc 100644 --- a/test/test_flagging.py +++ b/test/test_flagging.py @@ -1,12 +1,16 @@ +import os import tempfile import unittest +from unittest.mock import MagicMock + +import huggingface_hub import gradio as gr from gradio import flagging class TestDefaultFlagging(unittest.TestCase): - def test_default_flagging_handler(self): + def test_default_flagging_callback(self): with tempfile.TemporaryDirectory() as tmpdirname: io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname) io.launch(prevent_thread_lock=True) @@ -18,7 +22,7 @@ class TestDefaultFlagging(unittest.TestCase): class TestSimpleFlagging(unittest.TestCase): - def test_simple_csv_flagging_handler(self): + def test_simple_csv_flagging_callback(self): with tempfile.TemporaryDirectory() as tmpdirname: io = gr.Interface( lambda x: x, @@ -29,11 +33,39 @@ class TestSimpleFlagging(unittest.TestCase): ) io.launch(prevent_thread_lock=True) row_count = io.flagging_callback.flag(io, ["test"], ["test"]) - self.assertEqual(row_count, 0) # no header + self.assertEqual(row_count, 0) # no header in SimpleCSVLogger row_count = io.flagging_callback.flag(io, ["test"], ["test"]) - self.assertEqual(row_count, 1) # no header + self.assertEqual(row_count, 1) # no header in SimpleCSVLogger io.close() +class TestHuggingFaceDatasetSaver(unittest.TestCase): + def test_saver_setup(self): + huggingface_hub.create_repo = MagicMock() + huggingface_hub.Repository = MagicMock() + flagger = flagging.HuggingFaceDatasetSaver("test", "test") + with tempfile.TemporaryDirectory() as tmpdirname: + flagger.setup(tmpdirname) + huggingface_hub.create_repo.assert_called_once() + + def test_saver_flag(self): + huggingface_hub.create_repo = MagicMock() + huggingface_hub.Repository = MagicMock() + with tempfile.TemporaryDirectory() as tmpdirname: + io = gr.Interface( + lambda x: x, + "text", + "text", + flagging_dir=tmpdirname, + flagging_callback=flagging.HuggingFaceDatasetSaver("test", "test"), + ) + os.mkdir(os.path.join(tmpdirname, "test")) + io.launch(prevent_thread_lock=True) + row_count = io.flagging_callback.flag(io, ["test"], ["test"]) + self.assertEqual(row_count, 1) # 2 rows written including header + row_count = io.flagging_callback.flag(io, ["test"], ["test"]) + self.assertEqual(row_count, 2) # 3 rows written including header + + if __name__ == "__main__": unittest.main() diff --git a/test/test_inputs.py b/test/test_inputs.py index 41218bb02e..c914c29d67 100644 --- a/test/test_inputs.py +++ b/test/test_inputs.py @@ -1,3 +1,4 @@ +from difflib import SequenceMatcher import json import os import tempfile @@ -565,21 +566,14 @@ class TestAudio(unittest.TestCase): x_wav = gr.processing_utils.audio_from_file("test/test_files/audio_sample.wav") self.assertIsInstance(audio_input.serialize(x_wav, False), dict) - # def test_in_interface(self): - # x_wav = gr.test_data.BASE64_AUDIO - # def max_amplitude_from_wav_file(wav_file): - # audio_segment = AudioSegment.from_file(wav_file.name) - # data = np.array(audio_segment.get_array_of_samples()) - # return np.max(data) - # iface = gr.Interface( - # max_amplitude_from_wav_file, - # gr.inputs.Audio(type="file"), - # "number", interpretation="default") - # # TODO(aliabd): investigate why this sometimes fails (returns 5239 or 576) - # self.assertEqual(iface.process([x_wav])[0], [576]) - # scores, alternative_outputs = iface.interpret([x_wav]) - # self.assertEqual(scores, ... ) - # self.assertEqual(alternative_outputs, ...) + def test_tokenize(self): + x_wav = gr.test_data.BASE64_AUDIO + audio_input = gr.inputs.Audio() + tokens, _, _ = audio_input.tokenize(x_wav) + self.assertEquals(len(tokens), audio_input.interpretation_segments) + x_new = audio_input.get_masked_inputs(tokens, [[1]*len(tokens)])[0] + similarity = SequenceMatcher(a=x_wav["data"], b=x_new).ratio() + self.assertGreater(similarity, 0.9) class TestFile(unittest.TestCase): diff --git a/test/test_networking.py b/test/test_networking.py index 562d5cbaa8..3ae08c3dc7 100644 --- a/test/test_networking.py +++ b/test/test_networking.py @@ -9,7 +9,7 @@ import warnings import aiohttp from fastapi.testclient import TestClient -from gradio import Interface, flagging, networking, reset_all, utils +from gradio import Interface, flagging, networking, reset_all, queueing os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -35,77 +35,6 @@ class TestPort(unittest.TestCase): warnings.warn("Unable to test, no ports available") -class TestRoutes(unittest.TestCase): - def setUp(self) -> None: - self.io = Interface(lambda x: x, "text", "text") - self.app, _, _ = self.io.launch(prevent_thread_lock=True) - self.client = TestClient(self.app) - - def test_get_main_route(self): - response = self.client.get("/") - self.assertEqual(response.status_code, 200) - - def test_get_api_route(self): - response = self.client.get("/api/") - self.assertEqual(response.status_code, 200) - - def test_static_files_served_safely(self): - # Make sure things outside the static folder are not accessible - response = self.client.get(r"/static/..%2findex.html") - self.assertEqual(response.status_code, 404) - response = self.client.get(r"/static/..%2f..%2fapi_docs.html") - self.assertEqual(response.status_code, 404) - - def test_get_config_route(self): - response = self.client.get("/config/") - self.assertEqual(response.status_code, 200) - - def test_predict_route(self): - response = self.client.post("/api/predict/", json={"data": ["test"]}) - self.assertEqual(response.status_code, 200) - output = dict(response.json()) - self.assertEqual(output["data"], ["test"]) - self.assertTrue("durations" in output) - self.assertTrue("avg_durations" in output) - - # def test_queue_push_route(self): - # networking.queue.push = mock.MagicMock(return_value=(None, None)) - # response = self.client.post('/api/queue/push/', json={"data": "test", "action": "test"}) - # self.assertEqual(response.status_code, 200) - - # def test_queue_push_route(self): - # networking.queue.get_status = mock.MagicMock(return_value=(None, None)) - # response = self.client.post('/api/queue/status/', json={"hash": "test"}) - # self.assertEqual(response.status_code, 200) - - def tearDown(self) -> None: - self.io.close() - reset_all() - - -class TestAuthenticatedRoutes(unittest.TestCase): - def setUp(self) -> None: - self.io = Interface(lambda x: x, "text", "text") - self.app, _, _ = self.io.launch( - auth=("test", "correct_password"), prevent_thread_lock=True - ) - self.client = TestClient(self.app) - - def test_post_login(self): - response = self.client.post( - "/login", data=dict(username="test", password="correct_password") - ) - self.assertEqual(response.status_code, 302) - response = self.client.post( - "/login", data=dict(username="test", password="incorrect_password") - ) - self.assertEqual(response.status_code, 400) - - def tearDown(self) -> None: - self.io.close() - reset_all() - - class TestInterfaceCustomParameters(unittest.TestCase): def test_show_error(self): io = Interface(lambda x: 1 / x, "number", "number") @@ -180,26 +109,5 @@ class TestURLs(unittest.TestCase): self.assertTrue(res) -# class TestQueuing(unittest.TestCase): -# def test_queueing(self): -# # mock queue methods and post method -# networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict')) -# networking.queue.pass_job = mock.MagicMock(return_value=(None, None)) -# networking.queue.fail_job = mock.MagicMock(return_value=(None, None)) -# networking.queue.start_job = mock.MagicMock(return_value=None) -# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200)) -# # execute queue action successfully -# networking.queue_thread('test_path', test_mode=True) -# networking.queue.pass_job.assert_called_once() -# # execute queue action unsuccessfully -# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500)) -# networking.queue_thread('test_path', test_mode=True) -# networking.queue.fail_job.assert_called_once() -# # no more things on the queue so methods shouldn't be called any more times -# networking.queue.pop = mock.MagicMock(return_value=None) -# networking.queue.pass_job.assert_called_once() -# networking.queue.fail_job.assert_called_once() - - if __name__ == "__main__": unittest.main() diff --git a/test/test_process_examples.py b/test/test_process_examples.py new file mode 100644 index 0000000000..b4b5d0da24 --- /dev/null +++ b/test/test_process_examples.py @@ -0,0 +1,27 @@ +import os +import unittest + +from gradio import Interface, process_examples + +os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" + + +class TestProcessExamples(unittest.TestCase): + def test_process_example(self): + io = Interface(lambda x: "Hello " + x, "text", "text", + examples=[["World"]]) + prediction, _ = process_examples.process_example(io, 0) + self.assertEquals(prediction[0], "Hello World") + + def test_caching(self): + io = Interface(lambda x: "Hello " + x, "text", "text", + examples=[["World"], ["Dunya"], ["Monde"]]) + io.launch(prevent_thread_lock=True) + process_examples.cache_interface_examples(io) + prediction = process_examples.load_from_cache(io, 1) + io.close() + self.assertEquals(prediction[0], "Hello Dunya") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_queuing.py b/test/test_queuing.py new file mode 100644 index 0000000000..fd1edd4dfa --- /dev/null +++ b/test/test_queuing.py @@ -0,0 +1,68 @@ +"""Contains tests for networking.py and app.py""" + +import os +import unittest +import unittest.mock as mock + +import requests + +from gradio import Interface, queueing + +os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" + + +class TestQueuingOpenClose(unittest.TestCase): + def test_init(self): + queueing.init() + self.assertTrue(os.path.exists(queueing.DB_FILE)) + os.remove(queueing.DB_FILE) + + def test_close(self): + queueing.close() + self.assertFalse(os.path.exists(queueing.DB_FILE)) + +class TestQueuingActions(unittest.TestCase): + def setUp(self): + queueing.init() + + def test_hashing(self): + hash1 = queueing.generate_hash() + hash2 = queueing.generate_hash() + self.assertNotEquals(hash1, hash2) + queueing.close() + + def test_push_pop_status(self): + hash1, position = queueing.push({"data": "test1"}, "predict") + self.assertEquals(position, 0) + hash2, position = queueing.push({"data": "test2"}, "predict") + self.assertEquals(position, 1) + status, position = queueing.get_status(hash2) + self.assertEquals(status, "QUEUED") + self.assertEquals(position, 1) + _, hash_popped, input_data, action = queueing.pop() + self.assertEquals(hash_popped, hash1) + self.assertEquals(input_data, {"data": "test1"}) + self.assertEquals(action, "predict") + + def test_jobs(self): + hash1, _ = queueing.push({"data": "test1"}, "predict") + hash2, position = queueing.push({"data": "test1"}, "predict") + self.assertEquals(position, 1) + + queueing.start_job(hash1) + _, position = queueing.get_status(hash2) + self.assertEquals(position, 1) + queueing.pass_job(hash1, {"data": "result"}) + _, position = queueing.get_status(hash2) + self.assertEquals(position, 0) + + queueing.start_job(hash2) + queueing.fail_job(hash2, "failure") + status, _ = queueing.get_status(hash2) + self.assertEquals(status, "FAILED") + + def tearDown(self): + queueing.close() + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_routes.py b/test/test_routes.py new file mode 100644 index 0000000000..5c3f96f93b --- /dev/null +++ b/test/test_routes.py @@ -0,0 +1,88 @@ +"""Contains tests for networking.py and app.py""" + +import os +import unittest +import unittest.mock as mock + +from fastapi.testclient import TestClient + +from gradio import Interface, flagging, networking, queueing, reset_all + +os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" + + +class TestRoutes(unittest.TestCase): + def setUp(self) -> None: + self.io = Interface(lambda x: x, "text", "text") + self.app, _, _ = self.io.launch(prevent_thread_lock=True) + self.client = TestClient(self.app) + + def test_get_main_route(self): + response = self.client.get("/") + self.assertEqual(response.status_code, 200) + + def test_get_api_route(self): + response = self.client.get("/api/") + self.assertEqual(response.status_code, 200) + + def test_static_files_served_safely(self): + # Make sure things outside the static folder are not accessible + response = self.client.get(r"/static/..%2findex.html") + self.assertEqual(response.status_code, 404) + response = self.client.get(r"/static/..%2f..%2fapi_docs.html") + self.assertEqual(response.status_code, 404) + + def test_get_config_route(self): + response = self.client.get("/config/") + self.assertEqual(response.status_code, 200) + + def test_predict_route(self): + response = self.client.post("/api/predict/", json={"data": ["test"]}) + self.assertEqual(response.status_code, 200) + output = dict(response.json()) + self.assertEqual(output["data"], ["test"]) + self.assertTrue("durations" in output) + self.assertTrue("avg_durations" in output) + + def test_queue_push_route(self): + queueing.push = mock.MagicMock(return_value=(None, None)) + response = self.client.post( + "/api/queue/push/", json={"data": "test", "action": "test"} + ) + self.assertEqual(response.status_code, 200) + + def test_queue_push_route(self): + queueing.get_status = mock.MagicMock(return_value=(None, None)) + response = self.client.post("/api/queue/status/", json={"hash": "test"}) + self.assertEqual(response.status_code, 200) + + def tearDown(self) -> None: + self.io.close() + reset_all() + + +class TestAuthenticatedRoutes(unittest.TestCase): + def setUp(self) -> None: + self.io = Interface(lambda x: x, "text", "text") + self.app, _, _ = self.io.launch( + auth=("test", "correct_password"), prevent_thread_lock=True + ) + self.client = TestClient(self.app) + + def test_post_login(self): + response = self.client.post( + "/login", data=dict(username="test", password="correct_password") + ) + self.assertEqual(response.status_code, 302) + response = self.client.post( + "/login", data=dict(username="test", password="incorrect_password") + ) + self.assertEqual(response.status_code, 400) + + def tearDown(self) -> None: + self.io.close() + reset_all() + + +if __name__ == "__main__": + unittest.main()