Correct message when is_generating hits 500 code (#2889)

This commit is contained in:
Jay Smith 2022-12-27 09:31:27 -06:00 committed by GitHub
parent 21820f47ab
commit 571e5eb66c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 2 deletions

View File

@ -5,6 +5,8 @@
## Bug Fixes:
* Fixed bug where setting `default_enabled=False` made it so that the entire queue did not start by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 2876](https://github.com/gradio-app/gradio/pull/2876)
* Fixed bug where an error raised after yielding iterative output would not be displayed in the browser by
[@JaySmithWpg](https://github.com/JaySmithWpg) in [PR 2889](https://github.com/gradio-app/gradio/pull/2889)
## Documentation Changes:
* Added a Guide on using Google Sheets to create a real-time dashboard with Gradio's `DataFrame` and `LinePlot` component, by [@abidlabs](https://github.com/abidlabs) in [PR 2816](https://github.com/gradio-app/gradio/pull/2816)

View File

@ -330,12 +330,17 @@ class Queue:
return
response = await self.call_prediction(awake_events, batch)
for event in awake_events:
if response.status != 200:
relevant_response = response
else:
relevant_response = old_response
await self.send_message(
event,
{
"msg": "process_completed",
"output": old_response.json,
"success": old_response.status == 200,
"output": relevant_response.json,
"success": relevant_response.status == 200,
},
)
else:

View File

@ -223,6 +223,38 @@ class TestQueueProcessEvents:
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.asyncio
async def test_process_event_handles_exception_in_is_generating_request(
self, queue: Queue, mock_event: Event
):
# We need to return a good response with is_generating=True first,
# setting up the function to expect further iterative responses.
# Then we provide a 500 response.
side_effects = [
MagicMock(has_exception=False, status=200, json=dict(is_generating=True)),
MagicMock(has_exception=False, status=500, json=dict(error="Foo")),
]
mock_event.disconnect = AsyncMock()
queue.gather_event_data = AsyncMock(return_value=True)
queue.clean_event = AsyncMock()
queue.send_message = AsyncMock(return_value=True)
queue.call_prediction = AsyncMock(side_effect=side_effects)
queue.active_jobs = [[mock_event]]
await queue.process_events([mock_event], batch=False)
queue.send_message.assert_called_with(
mock_event,
{
"msg": "process_completed",
"output": {"error": "Foo"},
"success": False,
},
)
assert queue.call_prediction.call_count == 2
mock_event.disconnect.assert_called_once()
assert queue.clean_event.call_count >= 1
@pytest.mark.asyncio
async def test_process_event_handles_error_sending_process_completed_msg(
self, queue: Queue, mock_event: Event