mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-01 11:45:36 +08:00
fix : queue could be blocked (#2288)
* fix: queue could be blocked * add: send error message * fix * fix * Add tests Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
This commit is contained in:
parent
b8fc551807
commit
0c4b13620e
@ -220,49 +220,61 @@ class Queue:
|
||||
return response
|
||||
|
||||
async def process_event(self, event: Event) -> None:
|
||||
client_awake = await self.gather_event_data(event)
|
||||
if not client_awake:
|
||||
return
|
||||
client_awake = await self.send_message(event, {"msg": "process_starts"})
|
||||
if not client_awake:
|
||||
return
|
||||
begin_time = time.time()
|
||||
response = await self.call_prediction(event)
|
||||
if response.json.get("is_generating", False):
|
||||
while response.json.get("is_generating", False):
|
||||
old_response = response
|
||||
try:
|
||||
client_awake = await self.gather_event_data(event)
|
||||
if not client_awake:
|
||||
return
|
||||
client_awake = await self.send_message(event, {"msg": "process_starts"})
|
||||
if not client_awake:
|
||||
return
|
||||
begin_time = time.time()
|
||||
response = await self.call_prediction(event)
|
||||
if response.has_exception:
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_generating",
|
||||
"msg": "process_completed",
|
||||
"output": {"error": str(response.exception)},
|
||||
"success": False,
|
||||
},
|
||||
)
|
||||
elif response.json.get("is_generating", False):
|
||||
while response.json.get("is_generating", False):
|
||||
old_response = response
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_generating",
|
||||
"output": old_response.json,
|
||||
"success": old_response.status == 200,
|
||||
},
|
||||
)
|
||||
response = await self.call_prediction(event)
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": old_response.json,
|
||||
"success": old_response.status == 200,
|
||||
},
|
||||
)
|
||||
response = await self.call_prediction(event)
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": old_response.json,
|
||||
"success": old_response.status == 200,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": response.json,
|
||||
"success": response.status == 200,
|
||||
},
|
||||
)
|
||||
end_time = time.time()
|
||||
if response.status == 200:
|
||||
self.update_estimation(end_time - begin_time)
|
||||
|
||||
await event.disconnect()
|
||||
await self.clean_event(event)
|
||||
else:
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": response.json,
|
||||
"success": response.status == 200,
|
||||
},
|
||||
)
|
||||
end_time = time.time()
|
||||
if response.status == 200:
|
||||
self.update_estimation(end_time - begin_time)
|
||||
finally:
|
||||
try:
|
||||
await event.disconnect()
|
||||
finally:
|
||||
await self.clean_event(event)
|
||||
|
||||
async def send_message(self, event, data: Dict) -> bool:
|
||||
try:
|
||||
|
@ -158,10 +158,79 @@ class TestQueueProcessEvents:
|
||||
queue.send_message.return_value = True
|
||||
queue.call_prediction = AsyncMock()
|
||||
queue.call_prediction.return_value = MagicMock()
|
||||
queue.call_prediction.return_value.has_exception = False
|
||||
queue.call_prediction.return_value.json = {"is_generating": False}
|
||||
mock_event.disconnect = AsyncMock()
|
||||
queue.clean_event = AsyncMock()
|
||||
await queue.process_event(mock_event)
|
||||
|
||||
assert queue.call_prediction.called
|
||||
assert mock_event.disconnect.called
|
||||
queue.call_prediction.assert_called_once()
|
||||
mock_event.disconnect.assert_called_once()
|
||||
queue.clean_event.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event_handles_error_when_gathering_data(
|
||||
self, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
mock_event.websocket.send_json.side_effect = ValueError("Can't connect")
|
||||
queue.call_prediction = AsyncMock()
|
||||
mock_event.disconnect = AsyncMock()
|
||||
queue.clean_event = AsyncMock()
|
||||
mock_event.data = None
|
||||
await queue.process_event(mock_event)
|
||||
assert not queue.call_prediction.called
|
||||
mock_event.disconnect.assert_called_once()
|
||||
assert queue.clean_event.call_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event_handles_error_sending_process_start_msg(
|
||||
self, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
mock_event.websocket.send_json.side_effect = ["2", ValueError("Can't connect")]
|
||||
queue.call_prediction = AsyncMock()
|
||||
mock_event.disconnect = AsyncMock()
|
||||
queue.clean_event = AsyncMock()
|
||||
mock_event.data = None
|
||||
await queue.process_event(mock_event)
|
||||
assert not queue.call_prediction.called
|
||||
mock_event.disconnect.assert_called_once()
|
||||
assert queue.clean_event.call_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event_handles_exception_in_call_prediction_request(
|
||||
self, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.disconnect = AsyncMock()
|
||||
queue.gather_event_data = AsyncMock(return_value=True)
|
||||
queue.clean_event = AsyncMock()
|
||||
queue.send_message = AsyncMock(return_value=True)
|
||||
queue.call_prediction = AsyncMock(
|
||||
return_value=MagicMock(has_exception=True, exception=ValueError("foo"))
|
||||
)
|
||||
await queue.process_event(mock_event)
|
||||
queue.call_prediction.assert_called_once()
|
||||
mock_event.disconnect.assert_called_once()
|
||||
assert queue.clean_event.call_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_event_handles_error_sending_process_completed_msg(
|
||||
self, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
mock_event.websocket.send_json.side_effect = [
|
||||
"2",
|
||||
"3",
|
||||
ValueError("Can't connect"),
|
||||
]
|
||||
queue.call_prediction = AsyncMock(
|
||||
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
|
||||
)
|
||||
mock_event.disconnect = AsyncMock()
|
||||
queue.clean_event = AsyncMock()
|
||||
mock_event.data = None
|
||||
await queue.process_event(mock_event)
|
||||
queue.call_prediction.assert_called_once()
|
||||
mock_event.disconnect.assert_called_once()
|
||||
assert queue.clean_event.call_count >= 1
|
||||
|
Loading…
Reference in New Issue
Block a user