mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-18 10:44:33 +08:00
Add timeouts to queue messages (#3196)
* Fix + test * Remove print statements + fix import for 3.7 * CHANGELOG * Remove more print statements * Add 60 second timeout for uploading data * Fix test --------- Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
fa094a03e2
commit
5df113a4d6
@ -8,10 +8,10 @@
|
||||
## Bug Fixes:
|
||||
- Ensure `mirror_webcam` is always respected by [@pngwn](https://github.com/pngwn) in [PR 3245](https://github.com/gradio-app/gradio/pull/3245)
|
||||
- Fix issue where updated markdown links were not being opened in a new tab by [@gante](https://github.com/gante) in [PR 3236](https://github.com/gradio-app/gradio/pull/3236)
|
||||
- Added a timeout to queue messages as some demos were experiencing infinitely growing queues from active jobs waiting forever for clients to respond by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3196](https://github.com/gradio-app/gradio/pull/3196)
|
||||
- Fixes the height of rendered LaTeX images so that they match the height of surrounding text by [@abidlabs](https://github.com/abidlabs) in [PR 3258](https://github.com/gradio-app/gradio/pull/3258)
|
||||
- Fix bug where matplotlib images where always too small on the front end by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3274](https://github.com/gradio-app/gradio/pull/3274)
|
||||
|
||||
|
||||
## Documentation Changes:
|
||||
No changes to highlight.
|
||||
|
||||
@ -35,6 +35,7 @@ No changes to highlight.
|
||||
|
||||
## Bug Fixes:
|
||||
- UI fixes including footer and API docs by [@aliabid94](https://github.com/aliabid94) in [PR 3242](https://github.com/gradio-app/gradio/pull/3242)
|
||||
- Updated image upload component to accept all image formats, including lossless formats like .webp by [@fienestar](https://github.com/fienestar) in [PR 3225](https://github.com/gradio-app/gradio/pull/3225)
|
||||
|
||||
## Documentation Changes:
|
||||
No changes to highlight.
|
||||
|
@ -4,6 +4,7 @@ import asyncio
|
||||
import copy
|
||||
import sys
|
||||
import time
|
||||
from asyncio import TimeoutError as AsyncTimeOutError
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Tuple
|
||||
|
||||
@ -205,7 +206,7 @@ class Queue:
|
||||
if self.live_updates:
|
||||
await self.broadcast_estimations()
|
||||
|
||||
async def gather_event_data(self, event: Event) -> bool:
|
||||
async def gather_event_data(self, event: Event, receive_timeout=60) -> bool:
|
||||
"""
|
||||
Gather data for the event
|
||||
|
||||
@ -216,7 +217,20 @@ class Queue:
|
||||
client_awake = await self.send_message(event, {"msg": "send_data"})
|
||||
if not client_awake:
|
||||
return False
|
||||
event.data = await self.get_message(event)
|
||||
data, client_awake = await self.get_message(event, timeout=receive_timeout)
|
||||
if not client_awake:
|
||||
# In the event, we timeout due to large data size
|
||||
# Let the client know, otherwise will hang
|
||||
await self.send_message(
|
||||
event,
|
||||
{
|
||||
"msg": "process_completed",
|
||||
"output": {"error": "Time out uploading data to server"},
|
||||
"success": False,
|
||||
},
|
||||
)
|
||||
return False
|
||||
event.data = data
|
||||
return True
|
||||
|
||||
async def notify_clients(self) -> None:
|
||||
@ -424,21 +438,25 @@ class Queue:
|
||||
# to start "from scratch"
|
||||
await self.reset_iterators(event.session_hash, event.fn_index)
|
||||
|
||||
async def send_message(self, event, data: Dict) -> bool:
|
||||
async def send_message(self, event, data: Dict, timeout: float | int = 1) -> bool:
|
||||
try:
|
||||
await event.websocket.send_json(data=data)
|
||||
await asyncio.wait_for(
|
||||
event.websocket.send_json(data=data), timeout=timeout
|
||||
)
|
||||
return True
|
||||
except:
|
||||
await self.clean_event(event)
|
||||
return False
|
||||
|
||||
async def get_message(self, event) -> PredictBody | None:
|
||||
async def get_message(self, event, timeout=5) -> Tuple[PredictBody | None, bool]:
|
||||
try:
|
||||
data = await event.websocket.receive_json()
|
||||
return PredictBody(**data)
|
||||
except:
|
||||
data = await asyncio.wait_for(
|
||||
event.websocket.receive_json(), timeout=timeout
|
||||
)
|
||||
return PredictBody(**data), True
|
||||
except AsyncTimeOutError:
|
||||
await self.clean_event(event)
|
||||
return None
|
||||
return None, False
|
||||
|
||||
async def reset_iterators(self, session_hash: str, fn_index: int):
|
||||
await AsyncRequest(
|
||||
|
@ -12,6 +12,7 @@ import posixpath
|
||||
import secrets
|
||||
import tempfile
|
||||
import traceback
|
||||
from asyncio import TimeoutError as AsyncTimeOutError
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
@ -479,8 +480,20 @@ class App(FastAPI):
|
||||
await websocket.accept()
|
||||
# In order to cancel jobs, we need the session_hash and fn_index
|
||||
# to create a unique id for each job
|
||||
await websocket.send_json({"msg": "send_hash"})
|
||||
session_info = await websocket.receive_json()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
websocket.send_json({"msg": "send_hash"}), timeout=1
|
||||
)
|
||||
except AsyncTimeOutError:
|
||||
return
|
||||
|
||||
try:
|
||||
session_info = await asyncio.wait_for(
|
||||
websocket.receive_json(), timeout=1
|
||||
)
|
||||
except AsyncTimeOutError:
|
||||
return
|
||||
|
||||
event = Event(
|
||||
websocket, session_info["session_hash"], session_info["fn_index"]
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections import deque
|
||||
@ -31,7 +32,7 @@ def queue() -> Queue:
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_event() -> Event:
|
||||
websocket = MagicMock()
|
||||
websocket = AsyncMock()
|
||||
event = Event(websocket=websocket, session_hash="test", fn_index=0)
|
||||
yield event
|
||||
|
||||
@ -53,9 +54,20 @@ class TestQueueMethods:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive(self, queue: Queue, mock_event: Event):
|
||||
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
|
||||
await queue.get_message(mock_event)
|
||||
assert mock_event.websocket.receive_json.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_timeout(self, queue: Queue, mock_event: Event):
|
||||
async def take_too_long():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
mock_event.websocket.receive_json = take_too_long
|
||||
data, is_awake = await queue.get_message(mock_event, timeout=0.5)
|
||||
assert data is None
|
||||
assert not is_awake
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send(self, queue: Queue, mock_event: Event):
|
||||
await queue.send_message(mock_event, {})
|
||||
@ -85,7 +97,7 @@ class TestQueueMethods:
|
||||
queue.send_message = AsyncMock()
|
||||
queue.get_message = AsyncMock()
|
||||
queue.send_message.return_value = True
|
||||
queue.get_message.return_value = {"data": ["test"], "fn": 0}
|
||||
queue.get_message.return_value = {"data": ["test"], "fn": 0}, True
|
||||
|
||||
assert await queue.gather_event_data(mock_event)
|
||||
assert queue.send_message.called
|
||||
@ -95,6 +107,25 @@ class TestQueueMethods:
|
||||
assert await queue.gather_event_data(mock_event)
|
||||
assert not (queue.send_message.called)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gather_event_data_timeout(self, queue: Queue, mock_event: Event):
|
||||
async def take_too_long():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
queue.send_message = AsyncMock()
|
||||
queue.send_message.return_value = True
|
||||
|
||||
mock_event.websocket.receive_json = take_too_long
|
||||
is_awake = await queue.gather_event_data(mock_event, receive_timeout=0.5)
|
||||
assert not is_awake
|
||||
|
||||
# Have to use awful [1][0][1] syntax cause of python 3.7
|
||||
assert queue.send_message.call_args_list[1][0][1] == {
|
||||
"msg": "process_completed",
|
||||
"output": {"error": "Time out uploading data to server"},
|
||||
"success": False,
|
||||
}
|
||||
|
||||
|
||||
class TestQueueEstimation:
|
||||
def test_get_update_estimation(self, queue: Queue):
|
||||
@ -193,6 +224,8 @@ class TestQueueProcessEvents:
|
||||
self, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
|
||||
|
||||
mock_event.websocket.send_json.side_effect = ["2", ValueError("Can't connect")]
|
||||
queue.call_prediction = AsyncMock()
|
||||
mock_event.disconnect = AsyncMock()
|
||||
@ -260,6 +293,7 @@ class TestQueueProcessEvents:
|
||||
async def test_process_event_handles_error_sending_process_completed_msg(
|
||||
self, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
mock_event.websocket.send_json.side_effect = [
|
||||
"2",
|
||||
@ -289,6 +323,7 @@ class TestQueueProcessEvents:
|
||||
async def test_process_event_handles_exception_during_disconnect(
|
||||
self, mock_request, queue: Queue, mock_event: Event
|
||||
):
|
||||
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
|
||||
mock_event.websocket.send_json = AsyncMock()
|
||||
queue.call_prediction = AsyncMock(
|
||||
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
|
||||
|
Loading…
Reference in New Issue
Block a user