Add timeouts to queue messages (#3196)

* Fix + test

* Remove print statements + fix import for 3.7

* CHANGELOG

* Remove more print statements

* Add 60 second timeout for uploading data

* Fix test

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2023-02-21 16:44:18 -05:00 committed by GitHub
parent fa094a03e2
commit 5df113a4d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 81 additions and 14 deletions

View File

@ -8,10 +8,10 @@
## Bug Fixes:
- Ensure `mirror_webcam` is always respected by [@pngwn](https://github.com/pngwn) in [PR 3245](https://github.com/gradio-app/gradio/pull/3245)
- Fix issue where updated markdown links were not being opened in a new tab by [@gante](https://github.com/gante) in [PR 3236](https://github.com/gradio-app/gradio/pull/3236)
- Added a timeout to queue messages as some demos were experiencing infinitely growing queues from active jobs waiting forever for clients to respond by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3196](https://github.com/gradio-app/gradio/pull/3196)
- Fixes the height of rendered LaTeX images so that they match the height of surrounding text by [@abidlabs](https://github.com/abidlabs) in [PR 3258](https://github.com/gradio-app/gradio/pull/3258)
- Fix bug where matplotlib images where always too small on the front end by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3274](https://github.com/gradio-app/gradio/pull/3274)
## Documentation Changes:
No changes to highlight.
@ -35,6 +35,7 @@ No changes to highlight.
## Bug Fixes:
- UI fixes including footer and API docs by [@aliabid94](https://github.com/aliabid94) in [PR 3242](https://github.com/gradio-app/gradio/pull/3242)
- Updated image upload component to accept all image formats, including lossless formats like .webp by [@fienestar](https://github.com/fienestar) in [PR 3225](https://github.com/gradio-app/gradio/pull/3225)
## Documentation Changes:
No changes to highlight.

View File

@ -4,6 +4,7 @@ import asyncio
import copy
import sys
import time
from asyncio import TimeoutError as AsyncTimeOutError
from collections import deque
from typing import Any, Deque, Dict, List, Tuple
@ -205,7 +206,7 @@ class Queue:
if self.live_updates:
await self.broadcast_estimations()
async def gather_event_data(self, event: Event) -> bool:
async def gather_event_data(self, event: Event, receive_timeout=60) -> bool:
"""
Gather data for the event
@ -216,7 +217,20 @@ class Queue:
client_awake = await self.send_message(event, {"msg": "send_data"})
if not client_awake:
return False
event.data = await self.get_message(event)
data, client_awake = await self.get_message(event, timeout=receive_timeout)
if not client_awake:
# In the event, we timeout due to large data size
# Let the client know, otherwise will hang
await self.send_message(
event,
{
"msg": "process_completed",
"output": {"error": "Time out uploading data to server"},
"success": False,
},
)
return False
event.data = data
return True
async def notify_clients(self) -> None:
@ -424,21 +438,25 @@ class Queue:
# to start "from scratch"
await self.reset_iterators(event.session_hash, event.fn_index)
async def send_message(self, event, data: Dict) -> bool:
async def send_message(self, event, data: Dict, timeout: float | int = 1) -> bool:
try:
await event.websocket.send_json(data=data)
await asyncio.wait_for(
event.websocket.send_json(data=data), timeout=timeout
)
return True
except:
await self.clean_event(event)
return False
async def get_message(self, event) -> PredictBody | None:
async def get_message(self, event, timeout=5) -> Tuple[PredictBody | None, bool]:
try:
data = await event.websocket.receive_json()
return PredictBody(**data)
except:
data = await asyncio.wait_for(
event.websocket.receive_json(), timeout=timeout
)
return PredictBody(**data), True
except AsyncTimeOutError:
await self.clean_event(event)
return None
return None, False
async def reset_iterators(self, session_hash: str, fn_index: int):
await AsyncRequest(

View File

@ -12,6 +12,7 @@ import posixpath
import secrets
import tempfile
import traceback
from asyncio import TimeoutError as AsyncTimeOutError
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type
@ -479,8 +480,20 @@ class App(FastAPI):
await websocket.accept()
# In order to cancel jobs, we need the session_hash and fn_index
# to create a unique id for each job
await websocket.send_json({"msg": "send_hash"})
session_info = await websocket.receive_json()
try:
await asyncio.wait_for(
websocket.send_json({"msg": "send_hash"}), timeout=1
)
except AsyncTimeOutError:
return
try:
session_info = await asyncio.wait_for(
websocket.receive_json(), timeout=1
)
except AsyncTimeOutError:
return
event = Event(
websocket, session_info["session_hash"], session_info["fn_index"]
)

View File

@ -1,3 +1,4 @@
import asyncio
import os
import sys
from collections import deque
@ -31,7 +32,7 @@ def queue() -> Queue:
@pytest.fixture()
def mock_event() -> Event:
websocket = MagicMock()
websocket = AsyncMock()
event = Event(websocket=websocket, session_hash="test", fn_index=0)
yield event
@ -53,9 +54,20 @@ class TestQueueMethods:
@pytest.mark.asyncio
async def test_receive(self, queue: Queue, mock_event: Event):
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
await queue.get_message(mock_event)
assert mock_event.websocket.receive_json.called
@pytest.mark.asyncio
async def test_receive_timeout(self, queue: Queue, mock_event: Event):
async def take_too_long():
await asyncio.sleep(1)
mock_event.websocket.receive_json = take_too_long
data, is_awake = await queue.get_message(mock_event, timeout=0.5)
assert data is None
assert not is_awake
@pytest.mark.asyncio
async def test_send(self, queue: Queue, mock_event: Event):
await queue.send_message(mock_event, {})
@ -85,7 +97,7 @@ class TestQueueMethods:
queue.send_message = AsyncMock()
queue.get_message = AsyncMock()
queue.send_message.return_value = True
queue.get_message.return_value = {"data": ["test"], "fn": 0}
queue.get_message.return_value = {"data": ["test"], "fn": 0}, True
assert await queue.gather_event_data(mock_event)
assert queue.send_message.called
@ -95,6 +107,25 @@ class TestQueueMethods:
assert await queue.gather_event_data(mock_event)
assert not (queue.send_message.called)
@pytest.mark.asyncio
async def test_gather_event_data_timeout(self, queue: Queue, mock_event: Event):
async def take_too_long():
await asyncio.sleep(1)
queue.send_message = AsyncMock()
queue.send_message.return_value = True
mock_event.websocket.receive_json = take_too_long
is_awake = await queue.gather_event_data(mock_event, receive_timeout=0.5)
assert not is_awake
# Have to use awful [1][0][1] syntax cause of python 3.7
assert queue.send_message.call_args_list[1][0][1] == {
"msg": "process_completed",
"output": {"error": "Time out uploading data to server"},
"success": False,
}
class TestQueueEstimation:
def test_get_update_estimation(self, queue: Queue):
@ -193,6 +224,8 @@ class TestQueueProcessEvents:
self, queue: Queue, mock_event: Event
):
mock_event.websocket.send_json = AsyncMock()
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
mock_event.websocket.send_json.side_effect = ["2", ValueError("Can't connect")]
queue.call_prediction = AsyncMock()
mock_event.disconnect = AsyncMock()
@ -260,6 +293,7 @@ class TestQueueProcessEvents:
async def test_process_event_handles_error_sending_process_completed_msg(
self, queue: Queue, mock_event: Event
):
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
mock_event.websocket.send_json = AsyncMock()
mock_event.websocket.send_json.side_effect = [
"2",
@ -289,6 +323,7 @@ class TestQueueProcessEvents:
async def test_process_event_handles_exception_during_disconnect(
self, mock_request, queue: Queue, mock_event: Event
):
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
mock_event.websocket.send_json = AsyncMock()
queue.call_prediction = AsyncMock(
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))