"""Contains tests for networking.py and app.py""" import os import unittest from gradio import queueing from gradio.routes import QueuePushBody os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" class TestQueuingOpenClose(unittest.TestCase): def test_init(self): queueing.init() self.assertTrue(os.path.exists(queueing.DB_FILE)) os.remove(queueing.DB_FILE) def test_close(self): queueing.close() self.assertFalse(os.path.exists(queueing.DB_FILE)) class TestQueuingActions(unittest.TestCase): def setUp(self): queueing.init() def test_hashing(self): hash1 = queueing.generate_hash() hash2 = queueing.generate_hash() self.assertNotEquals(hash1, hash2) queueing.close() def test_push_pop_status(self): request = QueuePushBody( data="test1", action="predict", fn_index=0, session_hash="-" ) hash1, position = queueing.push(request) self.assertEquals(position, 0) request = QueuePushBody( data="test2", action="predict", fn_index=0, session_hash="-" ) hash2, position = queueing.push(request) self.assertEquals(position, 1) status, position = queueing.get_status(hash2) self.assertEquals(status, "QUEUED") self.assertEquals(position, 1) _, hash_popped, input_data, action = queueing.pop() self.assertEquals(hash_popped, hash1) self.assertEquals( input_data, {"action": "predict", "data": "test1", "fn_index": 0, "session_hash": "-"}, ) self.assertEquals(action, "predict") def test_jobs(self): request = QueuePushBody( data="test1", action="predict", fn_index=0, session_hash="-" ) hash1, _ = queueing.push(request) hash2, position = queueing.push(request) self.assertEquals(position, 1) queueing.start_job(hash1) _, position = queueing.get_status(hash2) self.assertEquals(position, 1) queueing.pass_job(hash1, {"data": "result"}) _, position = queueing.get_status(hash2) self.assertEquals(position, 0) queueing.start_job(hash2) queueing.fail_job(hash2, "failure") status, _ = queueing.get_status(hash2) self.assertEquals(status, "FAILED") def tearDown(self): queueing.close() if __name__ == "__main__": unittest.main()