# gradio/test/test_queueing.py (source-viewer metadata: 419 lines, 15 KiB, Python)

import os
import sys
from collections import deque
from typing import Iterator
from unittest.mock import MagicMock, patch

import pytest

from gradio.queueing import Event, Queue
from gradio.utils import AsyncRequest
# Disable Gradio analytics for the whole test module.
# NOTE(review): this runs after the gradio imports above; confirm gradio reads
# the variable lazily rather than at import time.
os.environ.update(GRADIO_ANALYTICS_ENABLED="False")
class AsyncMock(MagicMock):
    """A MagicMock whose calls must be awaited.

    Calling an instance returns a coroutine. Awaiting it performs the ordinary
    MagicMock call, which records the call and produces ``return_value`` or
    applies ``side_effect``. Nothing is recorded until the coroutine is
    awaited.

    NOTE(review): presumably kept instead of ``unittest.mock.AsyncMock``
    (added in Python 3.8) for Python 3.7 support; this file still
    special-cases 3.7 in ``skipif`` markers.
    """

    async def __call__(self, *args, **kwargs):
        # Zero-argument super() is the Python 3 idiom. MagicMock instantiates
        # a per-instance subclass of this class, which super() handles.
        return super().__call__(*args, **kwargs)
@pytest.fixture()
def queue() -> Iterator[Queue]:
    """Yield a single-worker Queue with no block dependencies, then close it.

    This is a yield fixture, so the return type is an iterator of Queue rather
    than Queue itself.
    """
    queue_object = Queue(
        live_updates=True,
        concurrency_count=1,
        update_intervals=1,
        max_size=None,
        blocks_dependencies=[],
    )
    yield queue_object
    # pytest runs the code after ``yield`` as teardown when the test finishes.
    queue_object.close()
@pytest.fixture()
def mock_event() -> Iterator[Event]:
    """Yield an Event for session "test", fn_index 0, backed by a MagicMock websocket.

    Tests inspect ``event.websocket`` (e.g. ``receive_json``/``send_json``)
    and may replace its methods or the event's ``disconnect`` with mocks.
    """
    websocket = MagicMock()
    event = Event(websocket=websocket, session_hash="test", fn_index=0)
    yield event
class TestQueueMethods:
    """Unit tests for Queue lifecycle, websocket messaging and queue bookkeeping."""

    @pytest.mark.asyncio
    async def test_start(self, queue: Queue):
        """start() leaves the queue running with no active workers yet."""
        await queue.start()
        assert queue.stopped is False
        assert queue.get_active_worker_count() == 0

    @pytest.mark.asyncio
    async def test_stop_resume(self, queue: Queue):
        """close() marks the queue stopped and resume() clears the flag."""
        await queue.start()
        queue.close()
        assert queue.stopped
        queue.resume()
        assert queue.stopped is False

    @pytest.mark.asyncio
    async def test_receive(self, queue: Queue, mock_event: Event):
        """get_message reads through the event's websocket.receive_json."""
        await queue.get_message(mock_event)
        assert mock_event.websocket.receive_json.called

    @pytest.mark.asyncio
    async def test_send(self, queue: Queue, mock_event: Event):
        """send_message writes through the event's websocket.send_json."""
        await queue.send_message(mock_event, {})
        assert mock_event.websocket.send_json.called

    @pytest.mark.asyncio
    async def test_add_to_queue(self, queue: Queue, mock_event: Event):
        """push() adds the event to event_queue."""
        queue.push(mock_event)
        assert len(queue.event_queue) == 1

    @pytest.mark.asyncio
    async def test_add_to_queue_with_max_size(self, queue: Queue, mock_event: Event):
        """Once max_size events are queued, a further push does not grow the queue."""
        queue.max_size = 1
        queue.push(mock_event)
        assert len(queue.event_queue) == 1
        queue.push(mock_event)
        assert len(queue.event_queue) == 1

    @pytest.mark.asyncio
    async def test_clean_event(self, queue: Queue, mock_event: Event):
        """clean_event removes a previously pushed event from event_queue."""
        queue.push(mock_event)
        await queue.clean_event(mock_event)
        assert len(queue.event_queue) == 0

    @pytest.mark.asyncio
    async def test_gather_event_data(self, queue: Queue, mock_event: Event):
        """The first call requests data and stores it on the event; a second call sends nothing."""
        queue.send_message = AsyncMock()
        queue.get_message = AsyncMock()
        queue.send_message.return_value = True
        queue.get_message.return_value = {"data": ["test"], "fn": 0}
        assert await queue.gather_event_data(mock_event)
        assert queue.send_message.called
        assert mock_event.data == {"data": ["test"], "fn": 0}

        # reset_mock() is the supported way to clear call history. It keeps the
        # configured return_value, so send_message would still return True.
        queue.send_message.reset_mock()
        # NOTE(review): presumably nothing is sent because mock_event.data is
        # already populated from the first call.
        assert await queue.gather_event_data(mock_event)
        assert not queue.send_message.called
class TestQueueEstimation:
    """Tests for processing-time estimation, ETA messages and worker slots."""

    def test_get_update_estimation(self, queue: Queue):
        """avg_event_process_time is the running mean of recorded durations."""
        queue.update_estimation(5)
        estimation = queue.get_estimation()
        assert estimation.avg_event_process_time == 5
        queue.update_estimation(15)
        estimation = queue.get_estimation()
        assert estimation.avg_event_process_time == 10  # (5 + 15) / 2
        queue.update_estimation(100)
        estimation = queue.get_estimation()
        assert estimation.avg_event_process_time == 40  # (5 + 15 + 100) / 3

    @pytest.mark.asyncio
    async def test_send_estimation(self, queue: Queue, mock_event: Event):
        """send_estimation sends a message and returns an estimation carrying the given rank."""
        queue.send_message = AsyncMock()
        queue.send_message.return_value = True
        estimation = queue.get_estimation()
        estimation = await queue.send_estimation(mock_event, estimation, 1)
        assert queue.send_message.called
        assert estimation.rank == 1
        queue.update_estimation(5)
        estimation = await queue.send_estimation(mock_event, estimation, 2)
        assert estimation.rank == 2
        # NOTE(review): 15 matches (rank + 1) * avg_event_process_time with an
        # average of 5; confirm against Queue.send_estimation.
        assert estimation.rank_eta == 15

    @pytest.mark.asyncio
    async def test_queue_sets_concurrency_count(self):
        """The queue allocates one active_jobs slot per unit of concurrency_count.

        The method needs the ``test_`` prefix; without it pytest never
        collects it and the check silently never runs.
        """
        # Same constructor arguments as the ``queue`` fixture, apart from
        # concurrency_count.
        queue_object = Queue(
            live_updates=True,
            concurrency_count=5,
            update_intervals=1,
            max_size=None,
            blocks_dependencies=[],
        )
        try:
            assert len(queue_object.active_jobs) == 5
        finally:
            # Close the queue even when the assertion fails.
            queue_object.close()
class TestQueueProcessEvents:
    """Tests for Queue.process_events on one non-batched event.

    Covers the success path and failures at each stage: data gathering, the
    process-start message, the prediction call, iterative (is_generating)
    responses, the completion message and disconnect. Several tests rely on
    the order in which process_events calls websocket.send_json, via
    ``side_effect`` lists.
    """

    # Skipped on Python 3.7 per the reason string: patching AsyncRequest with
    # an async mock does not work there.
    @pytest.mark.skipif(
        sys.version_info < (3, 8),
        reason="Mocks of async context manager don't work for 3.7",
    )
    @pytest.mark.asyncio
    @patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
    async def test_process_event(self, mock_request, queue: Queue, mock_event: Event):
        """Success path: one prediction call and one disconnect/clean for the event.

        A POST to ``{server_path}reset`` is issued through the patched
        AsyncRequest.
        """
        queue.gather_event_data = AsyncMock()
        queue.gather_event_data.return_value = True
        queue.send_message = AsyncMock()
        queue.send_message.return_value = True
        queue.call_prediction = AsyncMock()
        queue.call_prediction.return_value = MagicMock()
        queue.call_prediction.return_value.has_exception = False
        # A non-generating response means a single prediction call suffices.
        queue.call_prediction.return_value.json = {"is_generating": False}
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        # NOTE(review): presumably process_events looks up this active_jobs
        # entry to release the worker slot afterwards.
        queue.active_jobs = [[mock_event]]
        await queue.process_events([mock_event], batch=False)
        queue.call_prediction.assert_called_once()
        mock_event.disconnect.assert_called_once()
        queue.clean_event.assert_called_once()
        mock_request.assert_called_with(
            method=AsyncRequest.Method.POST,
            url=f"{queue.server_path}reset",
            json={
                "session_hash": mock_event.session_hash,
                "fn_index": mock_event.fn_index,
            },
            client=None,
        )

    @pytest.mark.asyncio
    async def test_process_event_handles_error_when_gathering_data(
        self, queue: Queue, mock_event: Event
    ):
        """If every websocket send fails, no prediction runs and the event is still cleaned."""
        mock_event.websocket.send_json = AsyncMock()
        # Every send attempt raises.
        mock_event.websocket.send_json.side_effect = ValueError("Can't connect")
        queue.call_prediction = AsyncMock()
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        # NOTE(review): presumably forces gather_event_data to request data
        # over the websocket rather than reuse existing data.
        mock_event.data = None
        queue.active_jobs = [[mock_event]]
        await queue.process_events([mock_event], batch=False)
        assert not queue.call_prediction.called
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_error_sending_process_start_msg(
        self, queue: Queue, mock_event: Event
    ):
        """If the second websocket send fails, no prediction runs and the event is still cleaned.

        Per the test name, the second send is the process-start message.
        """
        mock_event.websocket.send_json = AsyncMock()
        # The first send succeeds; the second raises.
        mock_event.websocket.send_json.side_effect = ["2", ValueError("Can't connect")]
        queue.call_prediction = AsyncMock()
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        mock_event.data = None
        queue.active_jobs = [[mock_event]]
        await queue.process_events([mock_event], batch=False)
        assert not queue.call_prediction.called
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_exception_in_call_prediction_request(
        self, queue: Queue, mock_event: Event
    ):
        """A prediction result flagged has_exception still disconnects and cleans the event."""
        mock_event.disconnect = AsyncMock()
        queue.gather_event_data = AsyncMock(return_value=True)
        queue.clean_event = AsyncMock()
        queue.send_message = AsyncMock(return_value=True)
        queue.call_prediction = AsyncMock(
            return_value=MagicMock(has_exception=True, exception=ValueError("foo"))
        )
        queue.active_jobs = [[mock_event]]
        await queue.process_events([mock_event], batch=False)
        queue.call_prediction.assert_called_once()
        mock_event.disconnect.assert_called_once()
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_exception_in_is_generating_request(
        self, queue: Queue, mock_event: Event
    ):
        """A 500 response during generation ends with a failed process_completed message."""
        # We need to return a good response with is_generating=True first,
        # setting up the function to expect further iterative responses.
        # Then we provide a 500 response.
        side_effects = [
            MagicMock(has_exception=False, status=200, json=dict(is_generating=True)),
            MagicMock(has_exception=False, status=500, json=dict(error="Foo")),
        ]
        mock_event.disconnect = AsyncMock()
        queue.gather_event_data = AsyncMock(return_value=True)
        queue.clean_event = AsyncMock()
        queue.send_message = AsyncMock(return_value=True)
        queue.call_prediction = AsyncMock(side_effect=side_effects)
        queue.active_jobs = [[mock_event]]
        await queue.process_events([mock_event], batch=False)
        # The last message sent carries the error payload from the 500 response.
        queue.send_message.assert_called_with(
            mock_event,
            {
                "msg": "process_completed",
                "output": {"error": "Foo"},
                "success": False,
            },
        )
        assert queue.call_prediction.call_count == 2
        mock_event.disconnect.assert_called_once()
        assert queue.clean_event.call_count >= 1

    @pytest.mark.asyncio
    async def test_process_event_handles_error_sending_process_completed_msg(
        self, queue: Queue, mock_event: Event
    ):
        """If the third websocket send fails, the event is still disconnected and cleaned.

        Per the test name, the third send is the process_completed message.
        """
        mock_event.websocket.send_json = AsyncMock()
        # The first two sends succeed; the third raises.
        mock_event.websocket.send_json.side_effect = [
            "2",
            "3",
            ValueError("Can't connect"),
        ]
        queue.call_prediction = AsyncMock(
            return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
        )
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        mock_event.data = None
        queue.active_jobs = [[mock_event]]
        await queue.process_events([mock_event], batch=False)
        queue.call_prediction.assert_called_once()
        mock_event.disconnect.assert_called_once()
        assert queue.clean_event.call_count >= 1

    # Skipped on Python 3.7 for the same AsyncRequest patching reason as above.
    @pytest.mark.skipif(
        sys.version_info < (3, 8),
        reason="Mocks of async context manager don't work for 3.7",
    )
    @pytest.mark.asyncio
    @patch("gradio.queueing.AsyncRequest", new_callable=AsyncMock)
    async def test_process_event_handles_exception_during_disconnect(
        self, mock_request, queue: Queue, mock_event: Event
    ):
        """An exception from event.disconnect does not escape process_events.

        The reset request is still issued.
        """
        mock_event.websocket.send_json = AsyncMock()
        queue.call_prediction = AsyncMock(
            return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
        )
        # No exception should be raised during `process_event`
        mock_event.disconnect = AsyncMock(side_effect=ValueError("..."))
        queue.clean_event = AsyncMock()
        mock_event.data = None
        queue.active_jobs = [[mock_event]]
        await queue.process_events([mock_event], batch=False)
        mock_request.assert_called_with(
            method=AsyncRequest.Method.POST,
            url=f"{queue.server_path}reset",
            json={
                "session_hash": mock_event.session_hash,
                "fn_index": mock_event.fn_index,
            },
            client=None,
        )
class TestQueueBatch:
    """Tests for processing several queued events together as one batch."""

    @pytest.mark.asyncio
    async def test_process_event(self, queue: Queue, mock_event: Event):
        """Two batched events share one prediction call.

        Each event is disconnected once and cleaned once.
        """
        queue.gather_event_data = AsyncMock()
        queue.gather_event_data.return_value = True
        queue.send_message = AsyncMock()
        queue.send_message.return_value = True
        queue.call_prediction = AsyncMock()
        queue.call_prediction.return_value = MagicMock()
        queue.call_prediction.return_value.has_exception = False
        # One output row per batched event.
        queue.call_prediction.return_value.json = {
            "is_generating": False,
            "data": [[1, 2]],
        }
        mock_event.disconnect = AsyncMock()
        queue.clean_event = AsyncMock()
        websocket = MagicMock()
        mock_event2 = Event(websocket=websocket, session_hash="test", fn_index=0)
        mock_event2.disconnect = AsyncMock()
        queue.active_jobs = [[mock_event, mock_event2]]
        await queue.process_events([mock_event, mock_event2], batch=True)
        queue.call_prediction.assert_called_once()  # called once for both events
        mock_event.disconnect.assert_called_once()
        mock_event2.disconnect.assert_called_once()
        # One clean_event call per batched event. This must be asserted: a bare
        # comparison expression is evaluated and discarded.
        assert queue.clean_event.call_count == 2
class TestGetEventsInBatch:
    """Tests for how Queue.get_events_in_batch groups queued events.

    Each call appears to take the leading run of same-fn events, capped at
    that dependency's max_batch_size. It returns those events plus a flag
    telling whether the dependency is batched.
    """

    @staticmethod
    def _make_events(*fn_indices: int):
        """Return a list with one Event per fn_index, each with its own MagicMock websocket."""
        return [
            Event(websocket=MagicMock(), session_hash="test", fn_index=fn_index)
            for fn_index in fn_indices
        ]

    def test_empty_event_queue(self, queue: Queue):
        """An empty queue yields no events."""
        queue.event_queue = deque()
        events, _ = queue.get_events_in_batch()
        assert events is None

    def test_single_type_of_event(self, queue: Queue):
        """Four events for a max_batch_size=3 dependency come out as batches of 3 and 1."""
        queue.blocks_dependencies = [{"batch": True, "max_batch_size": 3}]
        queue.event_queue = deque(self._make_events(0, 0, 0, 0))
        events, batch = queue.get_events_in_batch()
        assert batch
        assert [e.fn_index for e in events] == [0, 0, 0]
        events, batch = queue.get_events_in_batch()
        assert batch
        assert [e.fn_index for e in events] == [0]

    def test_multiple_batch_events(self, queue: Queue):
        """Interleaved events of two batched dependencies are grouped by fn_index.

        Each group is capped at that dependency's max_batch_size.
        """
        queue.blocks_dependencies = [
            {"batch": True, "max_batch_size": 3},
            {"batch": True, "max_batch_size": 2},
        ]
        queue.event_queue = deque(self._make_events(0, 1, 0, 1, 0, 0))
        events, batch = queue.get_events_in_batch()
        assert batch
        assert [e.fn_index for e in events] == [0, 0, 0]
        events, batch = queue.get_events_in_batch()
        assert batch
        assert [e.fn_index for e in events] == [1, 1]
        events, batch = queue.get_events_in_batch()
        assert batch
        assert [e.fn_index for e in events] == [0]

    def test_both_types_of_event(self, queue: Queue):
        """Batched events are grouped, while a non-batched dependency yields one event at a time."""
        queue.blocks_dependencies = [
            {"batch": True, "max_batch_size": 3},
            {"batch": False},
        ]
        queue.event_queue = deque(self._make_events(0, 1, 0, 1, 1))
        events, batch = queue.get_events_in_batch()
        assert batch
        assert [e.fn_index for e in events] == [0, 0]
        events, batch = queue.get_events_in_batch()
        assert not batch
        assert [e.fn_index for e in events] == [1]