Use asyncio.Event to stop stream in heartbeat route (#7932)

This commit is contained in:
Freddy Boulton 2024-04-05 11:57:49 -07:00 committed by GitHub
parent f2a1a859ae
commit b78129d90f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 23 additions and 3 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": patch
---
feat:Use asyncio.Event to stop stream in heartbeat route

View File

@ -2480,6 +2480,7 @@ Received outputs:
self._queue.close()
# set this before closing server to shut down heartbeats
self.is_running = False
self.app.stop_event.set()
if self.server:
self.server.close()
# So that the startup events (starting the queue)

View File

@ -8,6 +8,7 @@ import json
import os
import re
import shutil
import sys
from collections import deque
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass as python_dataclass
@ -792,6 +793,11 @@ async def _delete_state(app: App):
@asynccontextmanager
async def _delete_state_handler(app: App):
"""When the server launches, regularly delete expired state."""
# The stop event needs to get the current event loop for python 3.8
# but the loop parameter is deprecated for 3.8+
if sys.version_info < (3, 9):
loop = asyncio.get_running_loop()
app.stop_event = asyncio.Event(loop=loop)
asyncio.create_task(_delete_state(app))
yield

View File

@ -156,6 +156,7 @@ class App(FastAPI):
self.iterators: dict[str, AsyncIterator] = {}
self.iterators_to_reset: set[str] = set()
self.lock = utils.safe_get_lock()
self.stop_event = utils.safe_get_stop_event()
self.cookie_id = secrets.token_urlsafe(32)
self.queue_token = secrets.token_urlsafe(32)
self.startup_events_triggered = False
@ -606,8 +607,7 @@ class App(FastAPI):
return "wait"
async def stop_stream():
while app.get_blocks().is_running:
await asyncio.sleep(0.25)
await app.stop_event.wait()
return "stop"
async def iterator():

View File

@ -89,6 +89,14 @@ def safe_get_lock() -> asyncio.Lock:
return None # type: ignore
def safe_get_stop_event() -> asyncio.Event:
try:
asyncio.get_event_loop()
return asyncio.Event()
except RuntimeError:
return None # type: ignore
class BaseReloader(ABC):
@property
@abstractmethod

View File

@ -29,6 +29,6 @@ test("when a user closes the page, the unload event should be triggered", async
expect(data).toContain("incremented 1");
expect(data).toContain("incremented 2");
expect(data).toContain("incremented 3");
expect(data).toContain("deleted 4");
expect(data).toContain("unloading");
expect(data).toContain("deleted 4");
});