fix : queue could be blocked (#2288)

* fix: queue could be blocked

* add: send error message

* fix

* fix

* Add tests

Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
This commit is contained in:
SkyTNT 2022-09-21 02:34:08 +08:00 committed by GitHub
parent b8fc551807
commit 0c4b13620e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 119 additions and 38 deletions

View File

@ -220,49 +220,61 @@ class Queue:
return response
async def process_event(self, event: Event) -> None:
client_awake = await self.gather_event_data(event)
if not client_awake:
return
client_awake = await self.send_message(event, {"msg": "process_starts"})
if not client_awake:
return
begin_time = time.time()
response = await self.call_prediction(event)
if response.json.get("is_generating", False):
while response.json.get("is_generating", False):
old_response = response
try:
client_awake = await self.gather_event_data(event)
if not client_awake:
return
client_awake = await self.send_message(event, {"msg": "process_starts"})
if not client_awake:
return
begin_time = time.time()
response = await self.call_prediction(event)
if response.has_exception:
await self.send_message(
event,
{
"msg": "process_generating",
"msg": "process_completed",
"output": {"error": str(response.exception)},
"success": False,
},
)
elif response.json.get("is_generating", False):
while response.json.get("is_generating", False):
old_response = response
await self.send_message(
event,
{
"msg": "process_generating",
"output": old_response.json,
"success": old_response.status == 200,
},
)
response = await self.call_prediction(event)
await self.send_message(
event,
{
"msg": "process_completed",
"output": old_response.json,
"success": old_response.status == 200,
},
)
response = await self.call_prediction(event)
await self.send_message(
event,
{
"msg": "process_completed",
"output": old_response.json,
"success": old_response.status == 200,
},
)
else:
await self.send_message(
event,
{
"msg": "process_completed",
"output": response.json,
"success": response.status == 200,
},
)
end_time = time.time()
if response.status == 200:
self.update_estimation(end_time - begin_time)
await event.disconnect()
await self.clean_event(event)
else:
await self.send_message(
event,
{
"msg": "process_completed",
"output": response.json,
"success": response.status == 200,
},
)
end_time = time.time()
if response.status == 200:
self.update_estimation(end_time - begin_time)
finally:
try:
await event.disconnect()
finally:
await self.clean_event(event)
async def send_message(self, event, data: Dict) -> bool:
try:

View File

@ -158,10 +158,79 @@ class TestQueueProcessEvents:
queue.send_message.return_value = True
queue.call_prediction = AsyncMock()
queue.call_prediction.return_value = MagicMock()
queue.call_prediction.return_value.has_exception = False
queue.call_prediction.return_value.json = {"is_generating": False}
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
await queue.process_event(mock_event)
assert queue.call_prediction.called
assert mock_event.disconnect.called
queue.call_prediction.assert_called_once()
mock_event.disconnect.assert_called_once()
queue.clean_event.assert_called_once()
@pytest.mark.asyncio
async def test_process_event_handles_error_when_gathering_data(
self, queue: Queue, mock_event: Event
):
mock_event.websocket.send_json = AsyncMock()
mock_event.websocket.send_json.side_effect = ValueError("Can't connect")
queue.call_prediction = AsyncMock()
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
assert not queue.call_prediction.called
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.asyncio
async def test_process_event_handles_error_sending_process_start_msg(
self, queue: Queue, mock_event: Event
):
mock_event.websocket.send_json = AsyncMock()
mock_event.websocket.send_json.side_effect = ["2", ValueError("Can't connect")]
queue.call_prediction = AsyncMock()
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
assert not queue.call_prediction.called
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.asyncio
async def test_process_event_handles_exception_in_call_prediction_request(
self, queue: Queue, mock_event: Event
):
mock_event.disconnect = AsyncMock()
queue.gather_event_data = AsyncMock(return_value=True)
queue.clean_event = AsyncMock()
queue.send_message = AsyncMock(return_value=True)
queue.call_prediction = AsyncMock(
return_value=MagicMock(has_exception=True, exception=ValueError("foo"))
)
await queue.process_event(mock_event)
queue.call_prediction.assert_called_once()
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.asyncio
async def test_process_event_handles_error_sending_process_completed_msg(
self, queue: Queue, mock_event: Event
):
mock_event.websocket.send_json = AsyncMock()
mock_event.websocket.send_json.side_effect = [
"2",
"3",
ValueError("Can't connect"),
]
queue.call_prediction = AsyncMock(
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
)
mock_event.disconnect = AsyncMock()
queue.clean_event = AsyncMock()
mock_event.data = None
await queue.process_event(mock_event)
queue.call_prediction.assert_called_once()
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1