mirror of
https://github.com/gradio-app/gradio.git
synced 2025-04-06 12:30:29 +08:00
Queuing tests
This commit is contained in:
parent
19d95c460d
commit
3261722ac7
@ -73,29 +73,6 @@ def get_first_available_port(initial: int, final: int) -> int:
|
||||
)
|
||||
|
||||
|
||||
def queue_thread(path_to_local_server, test_mode=False):
|
||||
while True:
|
||||
try:
|
||||
next_job = queueing.pop()
|
||||
if next_job is not None:
|
||||
_, hash, input_data, task_type = next_job
|
||||
queueing.start_job(hash)
|
||||
response = requests.post(
|
||||
path_to_local_server + "api/" + task_type + "/", json=input_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
queueing.pass_job(hash, response.json())
|
||||
else:
|
||||
queueing.fail_job(hash, response.text)
|
||||
else:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
time.sleep(1)
|
||||
pass
|
||||
if test_mode:
|
||||
break
|
||||
|
||||
|
||||
def start_server(
|
||||
interface: Interface,
|
||||
server_name: Optional[str] = None,
|
||||
@ -147,7 +124,7 @@ def start_server(
|
||||
raise ValueError("Cannot queue with encryption or authentication enabled.")
|
||||
queueing.init()
|
||||
app.queue_thread = threading.Thread(
|
||||
target=queue_thread, args=(path_to_local_server,)
|
||||
target=queueing.queue_thread, args=(path_to_local_server,)
|
||||
)
|
||||
app.queue_thread.start()
|
||||
if interface.save_to is not None: # Used for selenium tests
|
||||
|
@ -25,14 +25,6 @@ def decode_base64_to_image(encoding):
|
||||
return Image.open(BytesIO(base64.b64decode(image_encoded)))
|
||||
|
||||
|
||||
def get_url_or_file_as_bytes(path):
|
||||
try:
|
||||
return requests.get(path).content
|
||||
except (requests.exceptions.MissingSchema, requests.exceptions.InvalidSchema):
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def encode_url_or_file_to_base64(path):
|
||||
try:
|
||||
requests.get(path)
|
||||
|
@ -1,11 +1,37 @@
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
DB_FILE = "gradio_queue.db"
|
||||
|
||||
|
||||
def queue_thread(path_to_local_server, test_mode=False):
|
||||
while True:
|
||||
try:
|
||||
next_job = pop()
|
||||
if next_job is not None:
|
||||
_, hash, input_data, task_type = next_job
|
||||
start_job(hash)
|
||||
response = requests.post(
|
||||
path_to_local_server + "api/" + task_type + "/", json=input_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
pass_job(hash, response.json())
|
||||
else:
|
||||
fail_job(hash, response.text)
|
||||
else:
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
time.sleep(1)
|
||||
pass
|
||||
if test_mode:
|
||||
break
|
||||
|
||||
|
||||
def generate_hash():
|
||||
generate = True
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
|
@ -109,26 +109,5 @@ class TestURLs(unittest.TestCase):
|
||||
self.assertTrue(res)
|
||||
|
||||
|
||||
# class TestQueuing(unittest.TestCase):
|
||||
# def test_queueing(self):
|
||||
# # mock queue methods and post method
|
||||
# networking.queue.pop = mock.MagicMock(return_value=(None, None, None, 'predict'))
|
||||
# networking.queue.pass_job = mock.MagicMock(return_value=(None, None))
|
||||
# networking.queue.fail_job = mock.MagicMock(return_value=(None, None))
|
||||
# networking.queue.start_job = mock.MagicMock(return_value=None)
|
||||
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=200))
|
||||
# # execute queue action successfully
|
||||
# networking.queue_thread('test_path', test_mode=True)
|
||||
# networking.queue.pass_job.assert_called_once()
|
||||
# # execute queue action unsuccessfully
|
||||
# requests.post = mock.MagicMock(return_value=mock.MagicMock(status_code=500))
|
||||
# networking.queue_thread('test_path', test_mode=True)
|
||||
# networking.queue.fail_job.assert_called_once()
|
||||
# # no more things on the queue so methods shouldn't be called any more times
|
||||
# networking.queue.pop = mock.MagicMock(return_value=None)
|
||||
# networking.queue.pass_job.assert_called_once()
|
||||
# networking.queue.fail_job.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user