gradio/test/test_queueing.py
Freddy Boulton 5df113a4d6
Add timeouts to queue messages (#3196)
* Fix + test

* Remove print statements + fix import for 3.7

* CHANGELOG

* Remove more print statements

* Add 60 second timeout for uploading data

* Fix test

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
2023-02-21 16:44:18 -05:00

454 lines
16 KiB
Python

import asyncio
import os
import sys
from collections import deque
from unittest.mock import MagicMock, patch
import pytest
from gradio.queueing import Event, Queue
from gradio.utils import AsyncRequest
# Set at import time, i.e. before any test in this module runs.
# NOTE(review): presumably this turns off gradio's analytics calls while the
# tests execute -- confirm against how gradio reads this variable.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class AsyncMock(MagicMock):
    """A ``MagicMock`` variant whose call must be awaited.

    Calling an instance returns a coroutine. Awaiting that coroutine performs
    the regular ``MagicMock`` call (recording the arguments) and produces the
    value the plain mock call would have returned.
    """

    async def __call__(self, *args, **kwargs):
        outcome = super().__call__(*args, **kwargs)
        return outcome
@pytest.fixture()
def queue() -> Queue:
    """Yield a ``Queue`` built with ``concurrency_count=1`` and live updates on.

    The queue is closed once the test using the fixture has finished.
    """
    settings = dict(
        live_updates=True,
        concurrency_count=1,
        update_intervals=1,
        max_size=None,
        blocks_dependencies=[],
    )
    instance = Queue(**settings)
    yield instance
    instance.close()
@pytest.fixture()
def mock_event() -> Event:
    """Yield an ``Event`` (session "test", fn_index 0) with an ``AsyncMock`` websocket."""
    yield Event(websocket=AsyncMock(), session_hash="test", fn_index=0)
class TestQueueMethods:
    """Tests for the basic ``Queue`` methods: lifecycle, messaging and event storage."""

    @pytest.mark.asyncio
    async def test_start(self, queue: Queue):
        """After ``start`` the queue is not stopped and has no active workers."""
        await queue.start()
        stopped_flag = queue.stopped
        assert stopped_flag is False
        active_workers = queue.get_active_worker_count()
        assert active_workers == 0

    @pytest.mark.asyncio
    async def test_stop_resume(self, queue: Queue):
        """``close`` marks the queue as stopped and ``resume`` clears the flag again."""
        await queue.start()

        queue.close()
        assert queue.stopped

        queue.resume()
        assert queue.stopped is False

    @pytest.mark.asyncio
    async def test_receive(self, queue: Queue, mock_event: Event):
        """``get_message`` reads from the event's websocket via ``receive_json``."""
        receive_json = mock_event.websocket.receive_json
        receive_json.return_value = {"data": ["test"], "fn": 0}
        await queue.get_message(mock_event)
        assert mock_event.websocket.receive_json.called

    @pytest.mark.asyncio
    async def test_receive_timeout(self, queue: Queue, mock_event: Event):
        """A receive slower than the timeout yields no data and a not-awake flag."""

        async def slow_receive():
            # Sleeps longer than the 0.5 s timeout passed below.
            await asyncio.sleep(1)

        mock_event.websocket.receive_json = slow_receive
        result = await queue.get_message(mock_event, timeout=0.5)
        data, is_awake = result
        assert data is None
        assert not is_awake

    @pytest.mark.asyncio
    async def test_send(self, queue: Queue, mock_event: Event):
        """``send_message`` forwards to the websocket's ``send_json``."""
        empty_message = {}
        await queue.send_message(mock_event, empty_message)
        assert mock_event.websocket.send_json.called

    @pytest.mark.asyncio
    async def test_add_to_queue(self, queue: Queue, mock_event: Event):
        """Pushing one event leaves exactly one entry in ``event_queue``."""
        queue.push(mock_event)
        pending = len(queue.event_queue)
        assert pending == 1

    @pytest.mark.asyncio
    async def test_add_to_queue_with_max_size(self, queue: Queue, mock_event: Event):
        """With ``max_size`` of 1, a second push does not grow the queue."""
        queue.max_size = 1

        # First push is accepted.
        queue.push(mock_event)
        assert len(queue.event_queue) == 1

        # Second push must leave the length unchanged.
        queue.push(mock_event)
        assert len(queue.event_queue) == 1

    @pytest.mark.asyncio
    async def test_clean_event(self, queue: Queue, mock_event: Event):
        """``clean_event`` removes a previously pushed event from ``event_queue``."""
        queue.push(mock_event)
        await queue.clean_event(mock_event)
        remaining = len(queue.event_queue)
        assert remaining == 0

    @pytest.mark.asyncio
    async def test_gather_event_data(self, queue: Queue, mock_event: Event):
        """Data is requested from the client once and stored on the event."""
        queue.send_message = AsyncMock(return_value=True)
        queue.get_message = AsyncMock(
            return_value=({"data": ["test"], "fn": 0}, True)
        )

        # First call: a message is sent and the received data lands on the event.
        assert await queue.gather_event_data(mock_event)
        assert queue.send_message.called
        assert mock_event.data == {"data": ["test"], "fn": 0}

        # Second call: reset the flag and verify that nothing is sent this time.
        queue.send_message.called = False
        assert await queue.gather_event_data(mock_event)
        assert not queue.send_message.called

    @pytest.mark.asyncio
    async def test_gather_event_data_timeout(self, queue: Queue, mock_event: Event):
        """A timed-out upload reports not-awake and sends a failure message."""

        async def slow_receive():
            # Sleeps longer than the 0.5 s receive timeout passed below.
            await asyncio.sleep(1)

        queue.send_message = AsyncMock(return_value=True)
        mock_event.websocket.receive_json = slow_receive

        is_awake = await queue.gather_event_data(mock_event, receive_timeout=0.5)
        assert not is_awake

        # Index-based access into the call record is used because of Python 3.7:
        # [1] -> second call, [0] -> positional args, [1] -> the message argument.
        second_call = queue.send_message.call_args_list[1]
        sent_message = second_call[0][1]
        assert sent_message == {
            "msg": "process_completed",
            "output": {"error": "Time out uploading data to server"},
            "success": False,
        }
class TestQueueEstimation:
    """Tests for the queue's processing-time estimation and worker-slot setup."""

    def test_get_update_estimation(self, queue: Queue):
        """``avg_event_process_time`` is the mean of every duration reported so far."""
        queue.update_estimation(5)
        estimation = queue.get_estimation()
        assert estimation.avg_event_process_time == 5

        queue.update_estimation(15)
        estimation = queue.get_estimation()
        assert estimation.avg_event_process_time == 10  # (5 + 15) / 2

        queue.update_estimation(100)
        estimation = queue.get_estimation()
        assert estimation.avg_event_process_time == 40  # (5 + 15 + 100) / 3

    @pytest.mark.asyncio
    async def test_send_estimation(self, queue: Queue, mock_event: Event):
        """``send_estimation`` sends a message and fills in rank and rank ETA."""
        queue.send_message = AsyncMock()
        queue.send_message.return_value = True

        estimation = queue.get_estimation()
        estimation = await queue.send_estimation(mock_event, estimation, 1)
        assert queue.send_message.called
        assert estimation.rank == 1

        # Once a process time of 5 is known, rank 2 is expected to give an ETA of 15.
        queue.update_estimation(5)
        estimation = await queue.send_estimation(mock_event, estimation, 2)
        assert estimation.rank == 2
        assert estimation.rank_eta == 15

    @pytest.mark.asyncio
    async def test_queue_sets_concurrency_count(self):
        """``active_jobs`` has one slot per unit of ``concurrency_count``.

        The method name carries the ``test_`` prefix so that pytest collects
        it, and the ``Queue`` is constructed with the same keyword arguments as
        the ``queue`` fixture (including ``blocks_dependencies``).
        """
        queue_object = Queue(
            live_updates=True,
            concurrency_count=5,
            update_intervals=1,
            max_size=None,
            blocks_dependencies=[],
        )
        assert len(queue_object.active_jobs) == 5
        queue_object.close()
class TestQueueProcessEvents:
    """Tests for ``Queue.process_events`` with a single, non-batched event."""

    @pytest.mark.skipif(
        sys.version_info < (3, 8),
        reason="Mocks of async context manager don't work for 3.7",
    )
    @pytest.mark.asyncio
    @patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
    async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
        """A successful run predicts once, disconnects, cleans up and posts a reset."""
        prediction = MagicMock()
        queue.gather_event_data = AsyncMock(return_value=True)
        queue.send_message = AsyncMock(return_value=True)
        queue.call_prediction = AsyncMock()
        queue.call_prediction.return_value = prediction
        prediction.has_exception = False
        prediction.json = dict(is_generating=False)
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        queue.active_jobs = [[mock_event]]

        await queue.process_events([mock_event], batch=False)

        queue.call_prediction.assert_called_once()
        mock_event.disconnect.assert_called_once()
        queue.clean_event.assert_called_once()
        reset_payload = {
            "session_hash": mock_event.session_hash,
            "fn_index": mock_event.fn_index,
        }
        mock_request.assert_called_with(
            method=AsyncRequest.Method.POST,
            url=f"{queue.server_path}reset",
            json=reset_payload,
            client=None,
        )

    @pytest.mark.asyncio
    async def test_process_event_handles_error_when_gathering_data(
        self, queue: Queue, mock_event: Event
    ):
        """If every websocket send raises, no prediction is made but cleanup runs."""
        mock_event.websocket.send_json = AsyncMock(
            side_effect=ValueError("Can't connect")
        )
        queue.call_prediction = AsyncMock()
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        mock_event.data = None
        queue.active_jobs = [[mock_event]]

        await queue.process_events([mock_event], batch=False)

        assert not queue.call_prediction.called
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_error_sending_process_start_msg(
        self, queue: Queue, mock_event: Event
    ):
        """If the second websocket send raises, no prediction is made but cleanup runs."""
        # The first send succeeds, the second one raises.
        send_outcomes = ["2", ValueError("Can't connect")]
        mock_event.websocket.send_json = AsyncMock()
        mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
        mock_event.websocket.send_json.side_effect = send_outcomes
        queue.call_prediction = AsyncMock()
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        mock_event.data = None
        queue.active_jobs = [[mock_event]]

        await queue.process_events([mock_event], batch=False)

        assert not queue.call_prediction.called
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_exception_in_call_prediction_request(
        self, queue: Queue, mock_event: Event
    ):
        """A prediction response flagged with an exception still disconnects and cleans up."""
        failed_response = MagicMock(has_exception=True, exception=ValueError("foo"))
        mock_event.disconnect = AsyncMock()
        queue.gather_event_data = AsyncMock(return_value=True)
        queue.clean_event = AsyncMock()
        queue.send_message = AsyncMock(return_value=True)
        queue.call_prediction = AsyncMock(return_value=failed_response)
        queue.active_jobs = [[mock_event]]

        await queue.process_events([mock_event], batch=False)

        queue.call_prediction.assert_called_once()
        mock_event.disconnect.assert_called_once()
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_exception_in_is_generating_request(
        self, queue: Queue, mock_event: Event
    ):
        """A 500 response during iterative generation is reported as a failure."""
        # The first response is a good one with is_generating=True, which sets
        # the function up to expect further iterative responses.
        still_generating = MagicMock(
            has_exception=False, status=200, json={"is_generating": True}
        )
        # The follow-up response is a 500 carrying an error.
        server_error = MagicMock(
            has_exception=False, status=500, json={"error": "Foo"}
        )
        side_effects = [still_generating, server_error]

        mock_event.disconnect = AsyncMock()
        queue.gather_event_data = AsyncMock(return_value=True)
        queue.clean_event = AsyncMock()
        queue.send_message = AsyncMock(return_value=True)
        queue.call_prediction = AsyncMock(side_effect=side_effects)
        queue.active_jobs = [[mock_event]]

        await queue.process_events([mock_event], batch=False)

        failure_message = {
            "msg": "process_completed",
            "output": {"error": "Foo"},
            "success": False,
        }
        queue.send_message.assert_called_with(mock_event, failure_message)
        assert queue.call_prediction.call_count == 2
        mock_event.disconnect.assert_called_once()
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_error_sending_process_completed_msg(
        self, queue: Queue, mock_event: Event
    ):
        """If the third websocket send raises, the event is still disconnected and cleaned."""
        mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
        # Two sends succeed, the third one raises.
        mock_event.websocket.send_json = AsyncMock(
            side_effect=["2", "3", ValueError("Can't connect")]
        )
        finished_response = MagicMock(
            has_exception=False, json={"is_generating": False}
        )
        queue.call_prediction = AsyncMock(return_value=finished_response)
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        mock_event.data = None
        queue.active_jobs = [[mock_event]]

        await queue.process_events([mock_event], batch=False)

        queue.call_prediction.assert_called_once()
        mock_event.disconnect.assert_called_once()
        assert queue.clean_event.call_count >= 1

    @pytest.mark.skipif(
        sys.version_info < (3, 8),
        reason="Mocks of async context manager don't work for 3.7",
    )
    @pytest.mark.asyncio
    @patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
    async def test_process_event_handles_exception_during_disconnect(
        self, mock_request, queue: Queue, mock_event: Event
    ):
        """An error raised by ``disconnect`` does not stop the reset request."""
        mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
        mock_event.websocket.send_json = AsyncMock()
        finished_response = MagicMock(
            has_exception=False, json={"is_generating": False}
        )
        queue.call_prediction = AsyncMock(return_value=finished_response)
        # `process_events` itself must not raise even though disconnect does.
        mock_event.disconnect = AsyncMock(side_effect=ValueError("..."))
        queue.clean_event = AsyncMock()
        mock_event.data = None
        queue.active_jobs = [[mock_event]]

        await queue.process_events([mock_event], batch=False)

        reset_payload = {
            "session_hash": mock_event.session_hash,
            "fn_index": mock_event.fn_index,
        }
        mock_request.assert_called_with(
            method=AsyncRequest.Method.POST,
            url=f"{queue.server_path}reset",
            json=reset_payload,
            client=None,
        )
class TestQueueBatch:
    """Tests for ``Queue.process_events`` when several events are run as one batch."""

    @pytest.mark.asyncio
    async def test_process_event(self, queue: Queue, mock_event: Event):
        """Two batched events share one prediction call; each is disconnected and cleaned."""
        queue.gather_event_data = AsyncMock()
        queue.gather_event_data.return_value = True
        queue.send_message = AsyncMock()
        queue.send_message.return_value = True
        queue.call_prediction = AsyncMock()
        queue.call_prediction.return_value = MagicMock()
        queue.call_prediction.return_value.has_exception = False
        queue.call_prediction.return_value.json = {
            "is_generating": False,
            "data": [[1, 2]],
        }
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()

        # Second event of the batch, built the same way as the fixture's event.
        websocket = MagicMock()
        mock_event2 = Event(websocket=websocket, session_hash="test", fn_index=0)
        mock_event2.disconnect = AsyncMock()
        queue.active_jobs = [[mock_event, mock_event2]]

        await queue.process_events([mock_event, mock_event2], batch=True)

        queue.call_prediction.assert_called_once()  # called once for both events
        mock_event.disconnect.assert_called_once()
        mock_event2.disconnect.assert_called_once()
        # This comparison must be asserted; as a bare expression it checks nothing.
        assert queue.clean_event.call_count == 2
class TestGetEventsInBatch:
    """Tests for how ``Queue.get_events_in_batch`` groups pending events."""

    @staticmethod
    def _load_events(queue, fn_indices):
        """Replace ``queue.event_queue`` with one new ``Event`` per given fn_index."""
        queue.event_queue = deque(
            Event(websocket=MagicMock(), session_hash="test", fn_index=index)
            for index in fn_indices
        )

    @staticmethod
    def _fn_indices(events):
        """Return the ``fn_index`` of every event, in order."""
        return [event.fn_index for event in events]

    def test_empty_event_queue(self, queue: Queue):
        """With nothing queued, the returned events value is ``None``."""
        queue.event_queue = deque()
        events, _ = queue.get_events_in_batch()
        assert events is None

    def test_single_type_of_event(self, queue: Queue):
        """Four events for one batched fn come out as a batch of three, then one."""
        queue.blocks_dependencies = [{"batch": True, "max_batch_size": 3}]
        self._load_events(queue, [0, 0, 0, 0])

        for expected in ([0, 0, 0], [0]):
            events, batch = queue.get_events_in_batch()
            assert batch
            assert self._fn_indices(events) == expected

    def test_multiple_batch_events(self, queue: Queue):
        """Interleaved events of two batched fns are grouped per fn_index."""
        queue.blocks_dependencies = [
            {"batch": True, "max_batch_size": 3},
            {"batch": True, "max_batch_size": 2},
        ]
        self._load_events(queue, [0, 1, 0, 1, 0, 0])

        for expected in ([0, 0, 0], [1, 1], [0]):
            events, batch = queue.get_events_in_batch()
            assert batch
            assert self._fn_indices(events) == expected

    def test_both_types_of_event(self, queue: Queue):
        """Batched events are grouped; a non-batched event is returned on its own."""
        queue.blocks_dependencies = [
            {"batch": True, "max_batch_size": 3},
            {"batch": False},
        ]
        self._load_events(queue, [0, 1, 0, 1, 1])

        # fn 0 is batched, so both of its events are returned together.
        events, batch = queue.get_events_in_batch()
        assert batch
        assert self._fn_indices(events) == [0, 0]

        # fn 1 is not batched, so only a single event is returned.
        events, batch = queue.get_events_in_batch()
        assert not batch
        assert self._fn_indices(events) == [1]