Queuing tests

This commit is contained in:
Abubakar Abid 2022-01-25 16:52:09 -06:00
parent 19d95c460d
commit 3261722ac7
4 changed files with 27 additions and 53 deletions

View File

@ -73,29 +73,6 @@ def get_first_available_port(initial: int, final: int) -> int:
)
def queue_thread(path_to_local_server, test_mode=False):
while True:
try:
next_job = queueing.pop()
if next_job is not None:
_, hash, input_data, task_type = next_job
queueing.start_job(hash)
response = requests.post(
path_to_local_server + "api/" + task_type + "/", json=input_data
)
if response.status_code == 200:
queueing.pass_job(hash, response.json())
else:
queueing.fail_job(hash, response.text)
else:
time.sleep(1)
except Exception as e:
time.sleep(1)
pass
if test_mode:
break
def start_server(
interface: Interface,
server_name: Optional[str] = None,
@ -147,7 +124,7 @@ def start_server(
raise ValueError("Cannot queue with encryption or authentication enabled.")
queueing.init()
app.queue_thread = threading.Thread(
target=queue_thread, args=(path_to_local_server,)
target=queueing.queue_thread, args=(path_to_local_server,)
)
app.queue_thread.start()
if interface.save_to is not None: # Used for selenium tests

View File

@ -25,14 +25,6 @@ def decode_base64_to_image(encoding):
return Image.open(BytesIO(base64.b64decode(image_encoded)))
def get_url_or_file_as_bytes(path):
try:
return requests.get(path).content
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
with open(path, "rb") as f:
return f.read()
def encode_url_or_file_to_base64(path):
try:
requests.get(path)

View File

@ -1,11 +1,37 @@
import json
import os
import sqlite3
import time
import uuid
import requests
DB_FILE = "gradio_queue.db"
def queue_thread(path_to_local_server, test_mode=False):
while True:
try:
next_job = pop()
if next_job is not None:
_, hash, input_data, task_type = next_job
start_job(hash)
response = requests.post(
path_to_local_server + "api/" + task_type + "/", json=input_data
)
if response.status_code == 200:
pass_job(hash, response.json())
else:
fail_job(hash, response.text)
else:
time.sleep(1)
except Exception as e:
time.sleep(1)
pass
if test_mode:
break
def generate_hash():
generate = True
conn = sqlite3.connect(DB_FILE)

View File

@ -109,26 +109,5 @@ class TestURLs(unittest.TestCase):
self.assertTrue(res)
# class TestQueuing(unittest.TestCase):
# def test_queueing(self):
# # mock queue methods and post method
# networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict'))
# networking.queue.pass_job = mock.MagicMock(return_value=(None, None))
# networking.queue.fail_job = mock.MagicMock(return_value=(None, None))
# networking.queue.start_job = mock.MagicMock(return_value=None)
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200))
# # execute queue action successfully
# networking.queue_thread('test_path', test_mode=True)
# networking.queue.pass_job.assert_called_once()
# # execute queue action unsuccessfully
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500))
# networking.queue_thread('test_path', test_mode=True)
# networking.queue.fail_job.assert_called_once()
# # no more things on the queue so methods shouldn't be called any more times
# networking.queue.pop = mock.MagicMock(return_value=None)
# networking.queue.pass_job.assert_called_once()
# networking.queue.fail_job.assert_called_once()
if __name__ == "__main__":
unittest.main()