mirror of
https://github.com/gradio-app/gradio.git
synced 2025-01-24 10:54:04 +08:00
Use asyncio.Event to stop stream in heartbeat route (#7932)
This commit is contained in:
parent
f2a1a859ae
commit
b78129d90f
5
.changeset/tiny-friends-stick.md
Normal file
5
.changeset/tiny-friends-stick.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": patch
|
||||
---
|
||||
|
||||
feat:Use asyncio.Event to stop stream in heartbeat route
|
@ -2480,6 +2480,7 @@ Received outputs:
|
||||
self._queue.close()
|
||||
# set this before closing server to shut down heartbeats
|
||||
self.is_running = False
|
||||
self.app.stop_event.set()
|
||||
if self.server:
|
||||
self.server.close()
|
||||
# So that the startup events (starting the queue)
|
||||
|
@ -8,6 +8,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from collections import deque
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from dataclasses import dataclass as python_dataclass
|
||||
@ -792,6 +793,11 @@ async def _delete_state(app: App):
|
||||
@asynccontextmanager
|
||||
async def _delete_state_handler(app: App):
|
||||
"""When the server launches, regularly delete expired state."""
|
||||
# The stop event needs to get the current event loop for python 3.8
|
||||
# but the loop parameter is deprecated for 3.8+
|
||||
if sys.version_info < (3, 9):
|
||||
loop = asyncio.get_running_loop()
|
||||
app.stop_event = asyncio.Event(loop=loop)
|
||||
asyncio.create_task(_delete_state(app))
|
||||
yield
|
||||
|
||||
|
@ -156,6 +156,7 @@ class App(FastAPI):
|
||||
self.iterators: dict[str, AsyncIterator] = {}
|
||||
self.iterators_to_reset: set[str] = set()
|
||||
self.lock = utils.safe_get_lock()
|
||||
self.stop_event = utils.safe_get_stop_event()
|
||||
self.cookie_id = secrets.token_urlsafe(32)
|
||||
self.queue_token = secrets.token_urlsafe(32)
|
||||
self.startup_events_triggered = False
|
||||
@ -606,8 +607,7 @@ class App(FastAPI):
|
||||
return "wait"
|
||||
|
||||
async def stop_stream():
|
||||
while app.get_blocks().is_running:
|
||||
await asyncio.sleep(0.25)
|
||||
await app.stop_event.wait()
|
||||
return "stop"
|
||||
|
||||
async def iterator():
|
||||
|
@ -89,6 +89,14 @@ def safe_get_lock() -> asyncio.Lock:
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
def safe_get_stop_event() -> asyncio.Event:
|
||||
try:
|
||||
asyncio.get_event_loop()
|
||||
return asyncio.Event()
|
||||
except RuntimeError:
|
||||
return None # type: ignore
|
||||
|
||||
|
||||
class BaseReloader(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
|
@ -29,6 +29,6 @@ test("when a user closes the page, the unload event should be triggered", async
|
||||
expect(data).toContain("incremented 1");
|
||||
expect(data).toContain("incremented 2");
|
||||
expect(data).toContain("incremented 3");
|
||||
expect(data).toContain("deleted 4");
|
||||
expect(data).toContain("unloading");
|
||||
expect(data).toContain("deleted 4");
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user