Delete user state when they close the tab. Add an unload event for the demo and a delete_callback on gr.State to let developers control how resources are cleaned up (#7829)

* Delete state

* add changeset

* Delete state

* WIP

* Add load event

* Working ttl

* unload e2e test

* Clean up

* add changeset

* Fix notebook

* add changeset

* Connect to heartbeat in python client

* 15 second heartbeat

* Demo for unload

* Add notebook

* add changeset

* Fix docs

* revert demo changes

* Add docstrings

* lint 🙄

* Edit

* handle shutdown issue

* state comments

* client test

* Fix:

* Fix e2e test

* 3.11 incompatibility

* delete after one hour

* lint + highlight

* Update .changeset/better-tires-shave.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update .changeset/better-tires-shave.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2024-04-01 15:31:56 -07:00 committed by GitHub
parent 83010a290a
commit 6a4bf7abe2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 482 additions and 45 deletions

View File

@ -0,0 +1,42 @@
---
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---
highlight:
#### Automatically delete state after user has disconnected from the webpage
Gradio now automatically deletes `gr.State` variables stored in the server's RAM when users close their browser tab.
The deletion will happen 60 minutes after the server detected a disconnect from the user's browser.
If the user connects again in that timeframe, their state will not be deleted.
Additionally, Gradio now includes a `Blocks.unload()` event, allowing you to run arbitrary cleanup functions when users disconnect (this does not have a 60 minute delay).
You can think of the `unload` event as the opposite of the `load` event.
```python
with gr.Blocks() as demo:
gr.Markdown(
"""# State Cleanup Demo
🖼️ Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.
""")
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
img = gr.Image(label="Generated Image", height=300, width=300)
with gr.Row():
gen = gr.Button(value="Generate")
with gr.Row():
history = gr.Gallery(label="Previous Generations", height=500, columns=10)
state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))
demo.load(generate_random_img, [state], [img, state, history])
gen.click(generate_random_img, [state], [img, state, history])
demo.unload(delete_directory)
demo.launch(auth=lambda user,pwd: True,
auth_message="Enter any username and password to continue")
```

View File

@ -64,7 +64,8 @@ function spawn_gradio_app(app, port, verbose) {
...process.env,
GRADIO_SERVER_PORT: `7879`,
PYTHONUNBUFFERED: "true",
GRADIO_ANALYTICS_ENABLED: "False"
GRADIO_ANALYTICS_ENABLED: "False",
GRADIO_IS_E2E_TEST: "1"
}
});
_process.stdout.setEncoding("utf8");

1
.gitignore vendored
View File

@ -45,6 +45,7 @@ demo/*/config.json
demo/annotatedimage_component/*.png
demo/fake_diffusion_with_gif/*.gif
demo/cancel_events/cancel_events_output_log.txt
demo/unload_event_test/output_log.txt
# Etc
.idea/*

View File

@ -357,6 +357,10 @@ export function api_factory(
);
const _config = await config_success(config);
// connect to the heartbeat endpoint via GET request
const heartbeat = new EventSource(
`${config.root}/heartbeat/${session_hash}`
);
res(_config);
} catch (e) {
console.error(e);

View File

@ -159,6 +159,7 @@ class Client:
self.sse_url = urllib.parse.urljoin(
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
)
self.heartbeat_url = urllib.parse.urljoin(self.src, utils.HEARTBEAT_URL)
self.sse_data_url = urllib.parse.urljoin(
self.src,
utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL,
@ -184,13 +185,43 @@ class Client:
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
threading.Thread(target=self._telemetry_thread).start()
threading.Thread(target=self._telemetry_thread, daemon=True).start()
self._refresh_heartbeat = threading.Event()
self._kill_heartbeat = threading.Event()
self.heartbeat = threading.Thread(target=self._stream_heartbeat, daemon=True)
self.heartbeat.start()
self.stream_open = False
self.streaming_future: Future | None = None
self.pending_messages_per_event: dict[str, list[Message | None]] = {}
self.pending_event_ids: set[str] = set()
def close(self):
self._kill_heartbeat.set()
self.heartbeat.join(timeout=1)
def _stream_heartbeat(self):
while True:
url = self.heartbeat_url.format(session_hash=self.session_hash)
try:
with httpx.stream(
"GET",
url,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
timeout=20,
) as response:
for _ in response.iter_lines():
if self._refresh_heartbeat.is_set():
self._refresh_heartbeat.clear()
break
if self._kill_heartbeat.is_set():
return
except httpx.TransportError:
return
async def stream_messages(
self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
) -> None:
@ -640,6 +671,7 @@ class Client:
def reset_session(self) -> None:
self.session_hash = str(uuid.uuid4())
self._refresh_heartbeat.set()
def _render_endpoints_info(
self,

View File

@ -42,6 +42,7 @@ RAW_API_INFO_URL = "info?serialize=False"
SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api"
RESET_URL = "reset"
SPACE_URL = "https://hf.space/{}"
HEARTBEAT_URL = "heartbeat/{session_hash}"
STATE_COMPONENT = "state"
INVALID_RUNTIME = [

View File

@ -75,10 +75,11 @@ def calculator_demo_with_defaults():
@pytest.fixture
def state_demo():
state = gr.State(delete_callback=lambda x: print("STATE DELETED"))
demo = gr.Interface(
lambda x, y: (x, y),
["textbox", "state"],
["textbox", "state"],
["textbox", state],
["textbox", state],
)
return demo

View File

@ -46,11 +46,7 @@ def connect(
# because we should set a timeout
# the tests that call .cancel() can get stuck
# waiting for the thread to join
if demo.enable_queue:
demo._queue.close()
demo.is_running = False
demo.server.should_exit = True
demo.server.thread.join(timeout=1)
demo.close()
class TestClientInitialization:
@ -608,6 +604,15 @@ class TestClientPredictions:
pred = client.predict(api_name="/predict")
assert pred[0] == data[0]
def test_state_reset_when_session_changes(self, capsys, state_demo, monkeypatch):
monkeypatch.setenv("GRADIO_IS_E2E_TEST", "1")
with connect(state_demo) as client:
client.predict("Hello", api_name="/predict")
client.reset_session()
time.sleep(5)
out = capsys.readouterr().out
assert "STATE DELETED" in out
class TestClientPredictionsWithKwargs:
def test_no_default_params(self, calculator_demo):

View File

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_cleanup"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from __future__ import annotations\n", "import gradio as gr\n", "import numpy as np\n", "from PIL import Image\n", "from pathlib import Path\n", "import secrets\n", "import shutil\n", "\n", "current_dir = Path(__file__).parent\n", "\n", "\n", "def generate_random_img(history: list[Image.Image], request: gr.Request):\n", " \"\"\"Generate a random red, green, blue, orange, yellor or purple image.\"\"\"\n", " colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)]\n", " color = colors[np.random.randint(0, len(colors))]\n", " img = Image.new('RGB', (100, 100), color)\n", " \n", " user_dir: Path = current_dir / request.username # type: ignore\n", " user_dir.mkdir(exist_ok=True)\n", " path = user_dir / f\"{secrets.token_urlsafe(8)}.webp\"\n", "\n", " img.save(path)\n", " history.append(img)\n", "\n", " return img, history, history\n", "\n", "def delete_directory(req: gr.Request):\n", " if not req.username:\n", " return\n", " user_dir: Path = current_dir / req.username\n", " shutil.rmtree(str(user_dir))\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"# State Cleanup Demo\n", " \ud83d\uddbc\ufe0f Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column(scale=1):\n", " with gr.Row():\n", " img = gr.Image(label=\"Generated Image\", height=300, width=300)\n", " with gr.Row():\n", " gen = gr.Button(value=\"Generate\")\n", " with gr.Row():\n", " history = gr.Gallery(label=\"Previous Generations\", height=500, columns=10)\n", " state = gr.State(value=[], delete_callback=lambda v: print(\"STATE DELETED\"))\n", "\n", " demo.load(generate_random_img, [state], [img, state, history]) \n", " gen.click(generate_random_img, [state], [img, state, history])\n", " demo.unload(delete_directory)\n", "\n", "\n", "demo.launch(auth=lambda user,pwd: True,\n", " auth_message=\"Enter any username and password to continue\")"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

53
demo/state_cleanup/run.py Normal file
View File

@ -0,0 +1,53 @@
from __future__ import annotations
import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path
import secrets
import shutil
current_dir = Path(__file__).parent
def generate_random_img(history: list[Image.Image], request: gr.Request):
"""Generate a random red, green, blue, orange, yellor or purple image."""
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)]
color = colors[np.random.randint(0, len(colors))]
img = Image.new('RGB', (100, 100), color)
user_dir: Path = current_dir / request.username # type: ignore
user_dir.mkdir(exist_ok=True)
path = user_dir / f"{secrets.token_urlsafe(8)}.webp"
img.save(path)
history.append(img)
return img, history, history
def delete_directory(req: gr.Request):
if not req.username:
return
user_dir: Path = current_dir / req.username
shutil.rmtree(str(user_dir))
with gr.Blocks() as demo:
gr.Markdown("""# State Cleanup Demo
🖼 Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.
""")
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
img = gr.Image(label="Generated Image", height=300, width=300)
with gr.Row():
gen = gr.Button(value="Generate")
with gr.Row():
history = gr.Gallery(label="Previous Generations", height=500, columns=10)
state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))
demo.load(generate_random_img, [state], [img, state, history])
gen.click(generate_random_img, [state], [img, state, history])
demo.unload(delete_directory)
demo.launch(auth=lambda user,pwd: True,
auth_message="Enter any username and password to continue")

View File

@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: unload_event_test"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["\"\"\"This demo is only meant to test the unload event.\n", "It will write to a file when the unload event is triggered.\n", "May not work as expected if multiple people are using it.\n", "\"\"\"\n", "import gradio as gr\n", "from pathlib import Path\n", "\n", "log_file = (Path(__file__).parent / \"output_log.txt\").resolve()\n", "\n", "\n", "def test_fn(x):\n", " with open(log_file, \"a\") as f:\n", " f.write(f\"incremented {x}\\n\")\n", " return x + 1, x + 1\n", "\n", "def delete_fn(v):\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"deleted {v}\\n\")\n", "\n", "def unload_fn():\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"unloading\\n\")\n", "\n", "with gr.Blocks() as demo:\n", " n1 = gr.Number(value=0, label=\"Number\")\n", " state = gr.State(value=0, delete_callback=delete_fn)\n", " button = gr.Button(\"Increment\")\n", " button.click(test_fn, [state], [n1, state], api_name=\"increment\")\n", " demo.unload(unload_fn)\n", " demo.load(lambda: log_file.write_text(\"\"))\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

View File

@ -0,0 +1,33 @@
"""This demo is only meant to test the unload event.
It will write to a file when the unload event is triggered.
May not work as expected if multiple people are using it.
"""
import gradio as gr
from pathlib import Path
log_file = (Path(__file__).parent / "output_log.txt").resolve()
def test_fn(x):
with open(log_file, "a") as f:
f.write(f"incremented {x}\n")
return x + 1, x + 1
def delete_fn(v):
with log_file.open("a") as f:
f.write(f"deleted {v}\n")
def unload_fn():
with log_file.open("a") as f:
f.write(f"unloading\n")
with gr.Blocks() as demo:
n1 = gr.Number(value=0, label="Number")
state = gr.State(value=0, delete_callback=delete_fn)
button = gr.Button("Increment")
button.click(test_fn, [state], [n1, state], api_name="increment")
demo.unload(unload_fn)
demo.load(lambda: log_file.write_text(""))
if __name__ == "__main__":
demo.launch()

View File

@ -128,6 +128,10 @@ class Block:
if render:
self.render()
@property
def stateful(self):
return False
@property
def skip_api(self):
return False
@ -506,7 +510,7 @@ def convert_component_dict_to_list(
return predictions
@document("launch", "queue", "integrate", "load")
@document("launch", "queue", "integrate", "load", "unload")
class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
"""
Blocks is Gradio's low-level API that allows you to create more custom web
@ -878,6 +882,43 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
for block in self.blocks.values()
)
def unload(self, fn: Callable):
"""This listener is triggered when the user closes or refreshes the tab, ending the user session.
It is useful for cleaning up resources when the app is closed.
Parameters:
fn: Callable function to run to clear resources. The function should not take any arguments and the output is not used.
Example:
import gradio as gr
with gr.Blocks() as demo:
gr.Markdown("# When you close the tab, hello will be printed to the console")
demo.unload(lambda: print("hello"))
demo.launch()
"""
self.set_event_trigger(
targets=[EventListenerMethod(None, "unload")],
fn=fn,
inputs=None,
outputs=None,
preprocess=False,
postprocess=False,
show_progress="hidden",
api_name=None,
js=None,
no_target=True,
queue=None,
batch=False,
max_batch_size=4,
cancels=None,
every=None,
collects_event_data=None,
trigger_after=None,
trigger_only_on_success=False,
trigger_mode="once",
concurrency_limit="default",
concurrency_id=None,
show_api=False,
)
def set_event_trigger(
self,
targets: Sequence[EventListenerMethod],
@ -1498,7 +1539,7 @@ Received outputs:
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e
if getattr(block, "stateful", False):
if block.stateful:
if not utils.is_update(predictions[i]):
state[output_id] = predictions[i]
output.append(None)
@ -2392,9 +2433,10 @@ Received outputs:
self._queue._cancel_asyncio_tasks()
self.server_app._cancel_asyncio_tasks()
self._queue.close()
# set this before closing server to shut down heartbeats
self.is_running = False
if self.server:
self.server.close()
self.is_running = False
# So that the startup events (starting the queue)
# happen the next time the app is launched
self.app.startup_events_triggered = False
@ -2448,6 +2490,7 @@ Received outputs:
self._queue.start()
# So that processing can resume in case the queue was stopped
self._queue.stopped = False
self.is_running = True
self.create_limiter()
def queue_enabled_for_fn(self, fn_index: int):

View File

@ -2,8 +2,9 @@
from __future__ import annotations
import math
from copy import deepcopy
from typing import Any
from typing import Any, Callable
from gradio_client.documentation import document
@ -16,8 +17,7 @@ class State(Component):
"""
Special hidden component that stores session state across runs of the demo by the
same user. The value of the State variable is cleared when the user refreshes the page.
Demos: interface_state, blocks_simple_squares
Demos: interface_state, blocks_simple_squares, state_cleanup
Guides: real-time-speech-recognition
"""
@ -27,13 +27,21 @@ class State(Component):
self,
value: Any = None,
render: bool = True,
*,
time_to_live: int | float | None = None,
delete_callback: Callable[[Any], None] | None = None,
):
"""
Parameters:
value: the initial value (of arbitrary type) of the state. The provided argument is deepcopied. If a callable is provided, the function will be called whenever the app loads to set the initial value of the state.
render: has no effect, but is included for consistency with other components.
time_to_live: The number of seconds the state should be stored for after it is created or updated. If None, the state will be stored indefinitely. Gradio automatically deletes state variables after a user closes the browser tab or refreshes the page, so this is useful for clearing state for potentially long running sessions.
delete_callback: A function that is called when the state is deleted. The function should take the state value as an argument.
"""
self.stateful = True
self.time_to_live = self.time_to_live = (
math.inf if time_to_live is None else time_to_live
)
self.delete_callback = delete_callback or (lambda a: None) # noqa: ARG005
try:
self.value = deepcopy(value)
except TypeError as err:
@ -42,6 +50,10 @@ class State(Component):
) from err
super().__init__(value=self.value, render=render)
@property
def stateful(self):
return True
def preprocess(self, payload: Any) -> Any:
"""
Parameters:

View File

@ -65,7 +65,7 @@ class Server(uvicorn.Server):
if self.reloader:
self.reloader.stop()
self.watch_thread.join()
self.thread.join()
self.thread.join(timeout=5)
def start_server(

View File

@ -9,7 +9,7 @@ import os
import re
import shutil
from collections import deque
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass as python_dataclass
from datetime import datetime
from pathlib import Path
@ -734,7 +734,6 @@ class CustomCORSMiddleware:
def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
"""Delete files that are older than age. If age is None, delete all files."""
dont_delete = set()
for component in blocks.blocks.values():
dont_delete.update(getattr(component, "keep_in_cache", set()))
@ -770,27 +769,40 @@ async def _lifespan_handler(
app: App, frequency: int = 1, age: int = 1
) -> AsyncGenerator:
"""A context manager that triggers the startup and shutdown events of the app."""
app.get_blocks().startup_events()
app.startup_events_triggered = True
asyncio.create_task(delete_files_on_schedule(app, frequency, age))
yield
delete_files_created_by_app(app.get_blocks(), age=None)
async def _delete_state(app: App):
"""Delete all expired state every second."""
while True:
app.state_holder.delete_all_expired_state()
await asyncio.sleep(1)
@asynccontextmanager
async def _delete_state_handler(app: App):
"""When the server launches, regularly delete expired state."""
asyncio.create_task(_delete_state(app))
yield
def create_lifespan_handler(
user_lifespan: Callable[[App], AsyncContextManager] | None,
frequency: int = 1,
age: int = 1,
frequency: int | None = 1,
age: int | None = 1,
) -> Callable[[App], AsyncContextManager]:
"""Return a context manager that applies _lifespan_handler and user_lifespan if it exists."""
@asynccontextmanager
async def _handler(app: App):
async with _lifespan_handler(app, frequency, age):
async with AsyncExitStack() as stack:
await stack.enter_async_context(_delete_state_handler(app))
if frequency and age:
await stack.enter_async_context(_lifespan_handler(app, frequency, age))
if user_lifespan is not None:
async with user_lifespan(app):
yield
else:
yield
await stack.enter_async_context(user_lifespan(app))
yield
return _handler

View File

@ -222,10 +222,10 @@ class App(FastAPI):
) -> App:
app_kwargs = app_kwargs or {}
app_kwargs.setdefault("default_response_class", ORJSONResponse)
if blocks.delete_cache is not None:
app_kwargs["lifespan"] = create_lifespan_handler(
app_kwargs.get("lifespan", None), *blocks.delete_cache
)
delete_cache = blocks.delete_cache or (None, None)
app_kwargs["lifespan"] = create_lifespan_handler(
app_kwargs.get("lifespan", None), *delete_cache
)
app = App(auth_dependency=auth_dependency, **app_kwargs)
app.configure_app(blocks)
@ -589,6 +589,75 @@ class App(FastAPI):
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
return {"success": True}
@app.get("/heartbeat/{session_hash}")
def heartbeat(
session_hash: str,
request: fastapi.Request,
background_tasks: BackgroundTasks,
username: str = Depends(get_current_user),
):
"""Clients make a persistent connection to this endpoint to keep the session alive.
When the client disconnects, the session state is deleted.
"""
heartbeat_rate = 0.25 if os.getenv("GRADIO_IS_E2E_TEST", None) else 15
async def wait():
await asyncio.sleep(heartbeat_rate)
return "wait"
async def stop_stream():
while app.get_blocks().is_running:
await asyncio.sleep(0.25)
return "stop"
async def iterator():
while True:
try:
yield "data: ALIVE\n\n"
# We need to close the heartbeat connections as soon as the server stops
# otherwise the server can take forever to close
wait_task = asyncio.create_task(wait())
stop_stream_task = asyncio.create_task(stop_stream())
done, _ = await asyncio.wait(
[wait_task, stop_stream_task],
return_when=asyncio.FIRST_COMPLETED,
)
done = [d.result() for d in done]
if "stop" in done:
raise asyncio.CancelledError()
except asyncio.CancelledError:
req = Request(request, username)
root_path = route_utils.get_root_url(
request=request,
route_path=f"/hearbeat/{session_hash}",
root_path=app.root_path,
)
body = PredictBody(
session_hash=session_hash, data=[], request=request
)
unload_fn_indices = [
i
for i, dep in enumerate(app.get_blocks().dependencies)
if any(t for t in dep["targets"] if t[1] == "unload")
]
for fn_index in unload_fn_indices:
# The task runnning this loop has been cancelled
# so we add tasks in the background
background_tasks.add_task(
route_utils.call_process_api,
app=app,
body=body,
gr_request=req,
fn_index_inferred=fn_index,
root_path=root_path,
)
# This will mark the state to be deleted in an hour
if session_hash in app.state_holder.session_data:
app.state_holder.session_data[session_hash].is_closed = True
return
return StreamingResponse(iterator(), media_type="text/event-stream")
# had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility
@app.post("/run/{api_name}", dependencies=[Depends(login_check)])
@app.post("/run/{api_name}/", dependencies=[Depends(login_check)])
@ -1098,8 +1167,9 @@ def mount_gradio_app(
async with old_lifespan(
app
): # Instert the startup events inside the FastAPI context manager
gradio_app.get_blocks().startup_events()
yield
async with gradio_app.router.lifespan_context(gradio_app):
gradio_app.get_blocks().startup_events()
yield
app.router.lifespan_context = new_lifespan

View File

@ -1,18 +1,22 @@
from __future__ import annotations
import datetime
import os
import threading
from collections import OrderedDict
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterator
if TYPE_CHECKING:
from gradio.blocks import Blocks
from gradio.components import State
class StateHolder:
def __init__(self):
self.capacity = 10000
self.session_data = OrderedDict()
self.session_data: OrderedDict[str, SessionState] = OrderedDict()
self.time_last_used: dict[str, datetime.datetime] = {}
self.lock = threading.Lock()
def set_blocks(self, blocks: Blocks):
@ -29,6 +33,7 @@ class StateHolder:
if session_id not in self.session_data:
self.session_data[session_id] = SessionState(self.blocks)
self.update(session_id)
self.time_last_used[session_id] = datetime.datetime.now()
return self.session_data[session_id]
def __contains__(self, session_id: str):
@ -41,11 +46,36 @@ class StateHolder:
if len(self.session_data) > self.capacity:
self.session_data.popitem(last=False)
def delete_all_expired_state(
self,
):
for session_id in self.session_data:
self.delete_state(session_id, expired_only=True)
def delete_state(self, session_id: str, expired_only: bool = False):
if session_id not in self.session_data:
return
to_delete = []
session_state = self.session_data[session_id]
for component, value, expired in session_state.state_components:
if not expired_only or expired:
component.delete_callback(value)
to_delete.append(component._id)
for component in to_delete:
del session_state._data[component]
class SessionState:
def __init__(self, blocks: Blocks):
self.blocks = blocks
self._data = {}
self._state_ttl = {}
self.is_closed = False
# When a session is closed, the state is stored for an hour to give the user time to reopen the session.
# During testing we set to a lower value to be able to test
self.STATE_TTL_WHEN_CLOSED = (
1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600
)
def __getitem__(self, key: int) -> Any:
if key not in self._data:
@ -57,7 +87,32 @@ class SessionState:
return self._data[key]
def __setitem__(self, key: int, value: Any):
from gradio.components import State
block = self.blocks.blocks[key]
if isinstance(block, State):
self._state_ttl[key] = (
block.time_to_live,
datetime.datetime.now(),
)
self._data[key] = value
def __contains__(self, key: int):
return key in self._data
@property
def state_components(self) -> Iterator[tuple[State, Any, bool]]:
from gradio.components import State
for id in self._data:
block = self.blocks.blocks[id]
if isinstance(block, State) and id in self._state_ttl:
time_to_live, created_at = self._state_ttl[id]
if self.is_closed:
time_to_live = self.STATE_TTL_WHEN_CLOSED
value = self._data[id]
yield (
block,
value,
(datetime.datetime.now() - created_at).seconds > time_to_live,
)

View File

@ -0,0 +1,34 @@
import { test, expect } from "@gradio/tootils";
import { readFileSync } from "fs";
test("when a user closes the page, the unload event should be triggered", async ({
page
}) => {
const increment = await page.locator("button", {
hasText: /Increment/
});
// if you click too fast, the page may close before the event is processed
await increment.click();
await page.waitForTimeout(100);
await increment.click();
await page.waitForTimeout(100);
await increment.click();
await page.waitForTimeout(100);
await increment.click();
await expect(page.getByLabel("Number")).toHaveValue("4");
await page.close();
await new Promise((resolve) => setTimeout(resolve, 5000));
const data = readFileSync(
"../../demo/unload_event_test/output_log.txt",
"utf-8"
);
expect(data).toContain("incremented 0");
expect(data).toContain("incremented 1");
expect(data).toContain("incremented 2");
expect(data).toContain("incremented 3");
expect(data).toContain("deleted 4");
expect(data).toContain("unloading");
});

View File

@ -49,16 +49,11 @@ def connect():
def _connect(demo: gr.Blocks, serialize=True, **kwargs):
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
try:
yield Client(local_url, serialize=serialize)
client = Client(local_url, serialize=serialize)
yield client
finally:
# A more verbose version of .close()
# because we should set a timeout
# the tests that call .cancel() can get stuck
# waiting for the thread to join
demo._queue.close()
demo.is_running = False
demo.server.should_exit = True
demo.server.thread.join(timeout=1)
client.close()
demo.close()
return _connect

View File

@ -1726,3 +1726,44 @@ def test_static_files_multiple_apps(gradio_temp_dir):
# Input/Output got saved to cache
assert len(list(gradio_temp_dir.glob("**/*.*"))) == 0
def test_time_to_live_and_delete_callback_for_state(capsys, monkeypatch):
monkeypatch.setenv("GRADIO_IS_E2E_TEST", 1)
def test_fn(x):
return x + 1, x + 1
def delete_fn(v):
print(f"deleted {v}")
with gr.Blocks() as demo:
n1 = gr.Number(value=0)
state = gr.State(
value=0, time_to_live=1, delete_callback=lambda v: delete_fn(v)
)
button = gr.Button("Increment")
button.click(test_fn, [state], [n1, state], api_name="increment")
app, url, _ = demo.launch(prevent_thread_lock=True)
try:
client_1 = grc.Client(url)
client_2 = grc.Client(url)
client_1.predict(api_name="/increment")
client_1.predict(api_name="/increment")
client_1.predict(api_name="/increment")
client_2.predict(api_name="/increment")
client_2.predict(api_name="/increment")
time.sleep(3)
captured = capsys.readouterr()
assert "deleted 2" in captured.out
assert "deleted 3" in captured.out
for client in [client_1, client_2]:
assert len(app.state_holder.session_data[client.session_hash]._data) == 0
finally:
demo.close()