gradio/test/test_queuing.py
aliabid94 b0306e716d
Restore queue in Blocks (#1137)
* first commit

* format

* fix tests

Co-authored-by: Ali Abid <aliabid94@gmail.com>
2022-05-02 16:17:53 -07:00

81 lines
2.4 KiB
Python

"""Contains tests for networking.py and app.py"""
import os
import unittest
from gradio import queueing
from gradio.routes import QueuePushBody
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestQueuingOpenClose(unittest.TestCase):
    """Tests for creating (init) and removing (close) the queue database file."""

    def tearDown(self):
        # Remove any database file left behind, even when an assertion failed
        # part-way through a test, so later tests start from a clean state.
        if os.path.exists(queueing.DB_FILE):
            os.remove(queueing.DB_FILE)

    def test_init(self):
        """init() creates the database file at queueing.DB_FILE."""
        queueing.init()
        self.assertTrue(os.path.exists(queueing.DB_FILE))

    def test_close(self):
        """close() removes the database file that init() created."""
        # Create the file first; otherwise the final assertion would pass
        # trivially and close() would not actually be exercised.
        queueing.init()
        self.assertTrue(os.path.exists(queueing.DB_FILE))
        queueing.close()
        self.assertFalse(os.path.exists(queueing.DB_FILE))
class TestQueuingActions(unittest.TestCase):
    """Tests for queue operations against a database created in setUp()."""

    def setUp(self):
        # Each test starts from a freshly initialised queue database.
        queueing.init()

    def tearDown(self):
        # The only place these tests call close(), so it runs once per test.
        queueing.close()

    @staticmethod
    def _predict_request(data):
        """Return a "predict" QueuePushBody for fn_index 0 carrying *data*."""
        return QueuePushBody(
            data=data, action="predict", fn_index=0, session_hash="-"
        )

    def test_hashing(self):
        """Two consecutive generate_hash() calls return different values."""
        hash1 = queueing.generate_hash()
        hash2 = queueing.generate_hash()
        self.assertNotEqual(hash1, hash2)

    def test_push_pop_status(self):
        """Pushed jobs get increasing positions; pop() returns the first one."""
        hash1, position = queueing.push(self._predict_request("test1"))
        self.assertEqual(position, 0)
        hash2, position = queueing.push(self._predict_request("test2"))
        self.assertEqual(position, 1)

        status, position = queueing.get_status(hash2)
        self.assertEqual(status, "QUEUED")
        self.assertEqual(position, 1)

        # pop() yields a 4-tuple; its first element is not checked here. The
        # input payload is expected to equal the pushed request's fields.
        _, hash_popped, input_data, action = queueing.pop()
        self.assertEqual(hash_popped, hash1)
        self.assertEqual(
            input_data,
            {"action": "predict", "data": "test1", "fn_index": 0, "session_hash": "-"},
        )
        self.assertEqual(action, "predict")

    def test_jobs(self):
        """start/pass advances the queue; a failed job reports FAILED."""
        request = self._predict_request("test1")
        hash1, _ = queueing.push(request)
        hash2, position = queueing.push(request)
        self.assertEqual(position, 1)

        # A job that has only started still sits ahead of hash2.
        queueing.start_job(hash1)
        _, position = queueing.get_status(hash2)
        self.assertEqual(position, 1)

        # Once hash1 completes, hash2 moves to the front.
        queueing.pass_job(hash1, {"data": "result"})
        _, position = queueing.get_status(hash2)
        self.assertEqual(position, 0)

        queueing.start_job(hash2)
        queueing.fail_job(hash2, "failure")
        status, _ = queueing.get_status(hash2)
        self.assertEqual(status, "FAILED")
if __name__ == "__main__":
unittest.main()