mirror of
https://github.com/gradio-app/gradio.git
synced 2025-03-31 12:20:26 +08:00
Delete user state when they close the tab. Add an unload event for the demo and a delete_callback on gr.State to let developers control how resources are cleaned up (#7829)
* Delete state
* add changeset
* Delete state
* WIP
* Add load event
* Working ttl
* unload e2e test
* Clean up
* add changeset
* Fix notebook
* add changeset
* Connect to heartbeat in python client
* 15 second heartbeat
* Demo for unload
* Add notebook
* add changeset
* Fix docs
* revert demo changes
* Add docstrings
* lint 🙄
* Edit
* handle shutdown issue
* state comments
* client test
* Fix:
* Fix e2e test
* 3.11 incompatibility
* delete after one hour
* lint + highlight
* Update .changeset/better-tires-shave.md
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
* Update .changeset/better-tires-shave.md
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
---------
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
83010a290a
commit
6a4bf7abe2
42
.changeset/better-tires-shave.md
Normal file
42
.changeset/better-tires-shave.md
Normal file
@ -0,0 +1,42 @@
|
||||
---
|
||||
"@gradio/client": minor
|
||||
"gradio": minor
|
||||
"gradio_client": minor
|
||||
---
|
||||
|
||||
highlight:
|
||||
|
||||
#### Automatically delete state after user has disconnected from the webpage
|
||||
|
||||
Gradio now automatically deletes `gr.State` variables stored in the server's RAM when users close their browser tab.
|
||||
The deletion will happen 60 minutes after the server detected a disconnect from the user's browser.
|
||||
If the user connects again in that timeframe, their state will not be deleted.
|
||||
|
||||
Additionally, Gradio now includes a `Blocks.unload()` event, allowing you to run arbitrary cleanup functions when users disconnect (this does not have a 60 minute delay).
|
||||
You can think of the `unload` event as the opposite of the `load` event.
|
||||
|
||||
|
||||
```python
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(
|
||||
"""# State Cleanup Demo
|
||||
🖼️ Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.
|
||||
""")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
img = gr.Image(label="Generated Image", height=300, width=300)
|
||||
with gr.Row():
|
||||
gen = gr.Button(value="Generate")
|
||||
with gr.Row():
|
||||
history = gr.Gallery(label="Previous Generations", height=500, columns=10)
|
||||
state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))
|
||||
|
||||
demo.load(generate_random_img, [state], [img, state, history])
|
||||
gen.click(generate_random_img, [state], [img, state, history])
|
||||
demo.unload(delete_directory)
|
||||
|
||||
|
||||
demo.launch(auth=lambda user,pwd: True,
|
||||
auth_message="Enter any username and password to continue")
|
||||
```
|
@ -64,7 +64,8 @@ function spawn_gradio_app(app, port, verbose) {
|
||||
...process.env,
|
||||
GRADIO_SERVER_PORT: `7879`,
|
||||
PYTHONUNBUFFERED: "true",
|
||||
GRADIO_ANALYTICS_ENABLED: "False"
|
||||
GRADIO_ANALYTICS_ENABLED: "False",
|
||||
GRADIO_IS_E2E_TEST: "1"
|
||||
}
|
||||
});
|
||||
_process.stdout.setEncoding("utf8");
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -45,6 +45,7 @@ demo/*/config.json
|
||||
demo/annotatedimage_component/*.png
|
||||
demo/fake_diffusion_with_gif/*.gif
|
||||
demo/cancel_events/cancel_events_output_log.txt
|
||||
demo/unload_event_test/output_log.txt
|
||||
|
||||
# Etc
|
||||
.idea/*
|
||||
|
@ -357,6 +357,10 @@ export function api_factory(
|
||||
);
|
||||
|
||||
const _config = await config_success(config);
|
||||
// connect to the heartbeat endpoint via GET request
|
||||
const heartbeat = new EventSource(
|
||||
`${config.root}/heartbeat/${session_hash}`
|
||||
);
|
||||
res(_config);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
|
@ -159,6 +159,7 @@ class Client:
|
||||
self.sse_url = urllib.parse.urljoin(
|
||||
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
|
||||
)
|
||||
self.heartbeat_url = urllib.parse.urljoin(self.src, utils.HEARTBEAT_URL)
|
||||
self.sse_data_url = urllib.parse.urljoin(
|
||||
self.src,
|
||||
utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL,
|
||||
@ -184,13 +185,43 @@ class Client:
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
|
||||
# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
|
||||
threading.Thread(target=self._telemetry_thread).start()
|
||||
threading.Thread(target=self._telemetry_thread, daemon=True).start()
|
||||
self._refresh_heartbeat = threading.Event()
|
||||
self._kill_heartbeat = threading.Event()
|
||||
|
||||
self.heartbeat = threading.Thread(target=self._stream_heartbeat, daemon=True)
|
||||
self.heartbeat.start()
|
||||
|
||||
self.stream_open = False
|
||||
self.streaming_future: Future | None = None
|
||||
self.pending_messages_per_event: dict[str, list[Message | None]] = {}
|
||||
self.pending_event_ids: set[str] = set()
|
||||
|
||||
def close(self):
|
||||
self._kill_heartbeat.set()
|
||||
self.heartbeat.join(timeout=1)
|
||||
|
||||
def _stream_heartbeat(self):
|
||||
while True:
|
||||
url = self.heartbeat_url.format(session_hash=self.session_hash)
|
||||
try:
|
||||
with httpx.stream(
|
||||
"GET",
|
||||
url,
|
||||
headers=self.headers,
|
||||
cookies=self.cookies,
|
||||
verify=self.ssl_verify,
|
||||
timeout=20,
|
||||
) as response:
|
||||
for _ in response.iter_lines():
|
||||
if self._refresh_heartbeat.is_set():
|
||||
self._refresh_heartbeat.clear()
|
||||
break
|
||||
if self._kill_heartbeat.is_set():
|
||||
return
|
||||
except httpx.TransportError:
|
||||
return
|
||||
|
||||
async def stream_messages(
|
||||
self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
|
||||
) -> None:
|
||||
@ -640,6 +671,7 @@ class Client:
|
||||
|
||||
def reset_session(self) -> None:
|
||||
self.session_hash = str(uuid.uuid4())
|
||||
self._refresh_heartbeat.set()
|
||||
|
||||
def _render_endpoints_info(
|
||||
self,
|
||||
|
@ -42,6 +42,7 @@ RAW_API_INFO_URL = "info?serialize=False"
|
||||
SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api"
|
||||
RESET_URL = "reset"
|
||||
SPACE_URL = "https://hf.space/{}"
|
||||
HEARTBEAT_URL = "heartbeat/{session_hash}"
|
||||
|
||||
STATE_COMPONENT = "state"
|
||||
INVALID_RUNTIME = [
|
||||
|
@ -75,10 +75,11 @@ def calculator_demo_with_defaults():
|
||||
|
||||
@pytest.fixture
|
||||
def state_demo():
|
||||
state = gr.State(delete_callback=lambda x: print("STATE DELETED"))
|
||||
demo = gr.Interface(
|
||||
lambda x, y: (x, y),
|
||||
["textbox", "state"],
|
||||
["textbox", "state"],
|
||||
["textbox", state],
|
||||
["textbox", state],
|
||||
)
|
||||
return demo
|
||||
|
||||
|
@ -46,11 +46,7 @@ def connect(
|
||||
# because we should set a timeout
|
||||
# the tests that call .cancel() can get stuck
|
||||
# waiting for the thread to join
|
||||
if demo.enable_queue:
|
||||
demo._queue.close()
|
||||
demo.is_running = False
|
||||
demo.server.should_exit = True
|
||||
demo.server.thread.join(timeout=1)
|
||||
demo.close()
|
||||
|
||||
|
||||
class TestClientInitialization:
|
||||
@ -608,6 +604,15 @@ class TestClientPredictions:
|
||||
pred = client.predict(api_name="/predict")
|
||||
assert pred[0] == data[0]
|
||||
|
||||
def test_state_reset_when_session_changes(self, capsys, state_demo, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_IS_E2E_TEST", "1")
|
||||
with connect(state_demo) as client:
|
||||
client.predict("Hello", api_name="/predict")
|
||||
client.reset_session()
|
||||
time.sleep(5)
|
||||
out = capsys.readouterr().out
|
||||
assert "STATE DELETED" in out
|
||||
|
||||
|
||||
class TestClientPredictionsWithKwargs:
|
||||
def test_no_default_params(self, calculator_demo):
|
||||
|
1
demo/state_cleanup/run.ipynb
Normal file
1
demo/state_cleanup/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_cleanup"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from __future__ import annotations\n", "import gradio as gr\n", "import numpy as np\n", "from PIL import Image\n", "from pathlib import Path\n", "import secrets\n", "import shutil\n", "\n", "current_dir = Path(__file__).parent\n", "\n", "\n", "def generate_random_img(history: list[Image.Image], request: gr.Request):\n", " \"\"\"Generate a random red, green, blue, orange, yellor or purple image.\"\"\"\n", " colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)]\n", " color = colors[np.random.randint(0, len(colors))]\n", " img = Image.new('RGB', (100, 100), color)\n", " \n", " user_dir: Path = current_dir / request.username # type: ignore\n", " user_dir.mkdir(exist_ok=True)\n", " path = user_dir / f\"{secrets.token_urlsafe(8)}.webp\"\n", "\n", " img.save(path)\n", " history.append(img)\n", "\n", " return img, history, history\n", "\n", "def delete_directory(req: gr.Request):\n", " if not req.username:\n", " return\n", " user_dir: Path = current_dir / req.username\n", " shutil.rmtree(str(user_dir))\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"# State Cleanup Demo\n", " \ud83d\uddbc\ufe0f Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column(scale=1):\n", " with gr.Row():\n", " img = gr.Image(label=\"Generated Image\", height=300, width=300)\n", " with gr.Row():\n", " gen = gr.Button(value=\"Generate\")\n", " with gr.Row():\n", " history = gr.Gallery(label=\"Previous Generations\", height=500, columns=10)\n", " state = gr.State(value=[], delete_callback=lambda v: print(\"STATE DELETED\"))\n", "\n", " demo.load(generate_random_img, [state], [img, state, history]) \n", " gen.click(generate_random_img, [state], [img, state, history])\n", " demo.unload(delete_directory)\n", "\n", "\n", "demo.launch(auth=lambda user,pwd: True,\n", " auth_message=\"Enter any username and password to continue\")"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
53
demo/state_cleanup/run.py
Normal file
53
demo/state_cleanup/run.py
Normal file
@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
import secrets
|
||||
import shutil
|
||||
|
||||
current_dir = Path(__file__).parent
|
||||
|
||||
|
||||
def generate_random_img(history: list[Image.Image], request: gr.Request):
|
||||
"""Generate a random red, green, blue, orange, yellor or purple image."""
|
||||
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)]
|
||||
color = colors[np.random.randint(0, len(colors))]
|
||||
img = Image.new('RGB', (100, 100), color)
|
||||
|
||||
user_dir: Path = current_dir / request.username # type: ignore
|
||||
user_dir.mkdir(exist_ok=True)
|
||||
path = user_dir / f"{secrets.token_urlsafe(8)}.webp"
|
||||
|
||||
img.save(path)
|
||||
history.append(img)
|
||||
|
||||
return img, history, history
|
||||
|
||||
def delete_directory(req: gr.Request):
|
||||
if not req.username:
|
||||
return
|
||||
user_dir: Path = current_dir / req.username
|
||||
shutil.rmtree(str(user_dir))
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("""# State Cleanup Demo
|
||||
🖼️ Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.
|
||||
""")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
img = gr.Image(label="Generated Image", height=300, width=300)
|
||||
with gr.Row():
|
||||
gen = gr.Button(value="Generate")
|
||||
with gr.Row():
|
||||
history = gr.Gallery(label="Previous Generations", height=500, columns=10)
|
||||
state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))
|
||||
|
||||
demo.load(generate_random_img, [state], [img, state, history])
|
||||
gen.click(generate_random_img, [state], [img, state, history])
|
||||
demo.unload(delete_directory)
|
||||
|
||||
|
||||
demo.launch(auth=lambda user,pwd: True,
|
||||
auth_message="Enter any username and password to continue")
|
1
demo/unload_event_test/run.ipynb
Normal file
1
demo/unload_event_test/run.ipynb
Normal file
@ -0,0 +1 @@
|
||||
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: unload_event_test"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["\"\"\"This demo is only meant to test the unload event.\n", "It will write to a file when the unload event is triggered.\n", "May not work as expected if multiple people are using it.\n", "\"\"\"\n", "import gradio as gr\n", "from pathlib import Path\n", "\n", "log_file = (Path(__file__).parent / \"output_log.txt\").resolve()\n", "\n", "\n", "def test_fn(x):\n", " with open(log_file, \"a\") as f:\n", " f.write(f\"incremented {x}\\n\")\n", " return x + 1, x + 1\n", "\n", "def delete_fn(v):\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"deleted {v}\\n\")\n", "\n", "def unload_fn():\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"unloading\\n\")\n", "\n", "with gr.Blocks() as demo:\n", " n1 = gr.Number(value=0, label=\"Number\")\n", " state = gr.State(value=0, delete_callback=delete_fn)\n", " button = gr.Button(\"Increment\")\n", " button.click(test_fn, [state], [n1, state], api_name=\"increment\")\n", " demo.unload(unload_fn)\n", " demo.load(lambda: log_file.write_text(\"\"))\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
|
33
demo/unload_event_test/run.py
Normal file
33
demo/unload_event_test/run.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""This demo is only meant to test the unload event.
|
||||
It will write to a file when the unload event is triggered.
|
||||
May not work as expected if multiple people are using it.
|
||||
"""
|
||||
import gradio as gr
|
||||
from pathlib import Path
|
||||
|
||||
log_file = (Path(__file__).parent / "output_log.txt").resolve()
|
||||
|
||||
|
||||
def test_fn(x):
|
||||
with open(log_file, "a") as f:
|
||||
f.write(f"incremented {x}\n")
|
||||
return x + 1, x + 1
|
||||
|
||||
def delete_fn(v):
|
||||
with log_file.open("a") as f:
|
||||
f.write(f"deleted {v}\n")
|
||||
|
||||
def unload_fn():
|
||||
with log_file.open("a") as f:
|
||||
f.write(f"unloading\n")
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
n1 = gr.Number(value=0, label="Number")
|
||||
state = gr.State(value=0, delete_callback=delete_fn)
|
||||
button = gr.Button("Increment")
|
||||
button.click(test_fn, [state], [n1, state], api_name="increment")
|
||||
demo.unload(unload_fn)
|
||||
demo.load(lambda: log_file.write_text(""))
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo.launch()
|
@ -128,6 +128,10 @@ class Block:
|
||||
if render:
|
||||
self.render()
|
||||
|
||||
@property
|
||||
def stateful(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def skip_api(self):
|
||||
return False
|
||||
@ -506,7 +510,7 @@ def convert_component_dict_to_list(
|
||||
return predictions
|
||||
|
||||
|
||||
@document("launch", "queue", "integrate", "load")
|
||||
@document("launch", "queue", "integrate", "load", "unload")
|
||||
class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
"""
|
||||
Blocks is Gradio's low-level API that allows you to create more custom web
|
||||
@ -878,6 +882,43 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
for block in self.blocks.values()
|
||||
)
|
||||
|
||||
def unload(self, fn: Callable):
|
||||
"""This listener is triggered when the user closes or refreshes the tab, ending the user session.
|
||||
It is useful for cleaning up resources when the app is closed.
|
||||
Parameters:
|
||||
fn: Callable function to run to clear resources. The function should not take any arguments and the output is not used.
|
||||
Example:
|
||||
import gradio as gr
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("# When you close the tab, hello will be printed to the console")
|
||||
demo.unload(lambda: print("hello"))
|
||||
demo.launch()
|
||||
"""
|
||||
self.set_event_trigger(
|
||||
targets=[EventListenerMethod(None, "unload")],
|
||||
fn=fn,
|
||||
inputs=None,
|
||||
outputs=None,
|
||||
preprocess=False,
|
||||
postprocess=False,
|
||||
show_progress="hidden",
|
||||
api_name=None,
|
||||
js=None,
|
||||
no_target=True,
|
||||
queue=None,
|
||||
batch=False,
|
||||
max_batch_size=4,
|
||||
cancels=None,
|
||||
every=None,
|
||||
collects_event_data=None,
|
||||
trigger_after=None,
|
||||
trigger_only_on_success=False,
|
||||
trigger_mode="once",
|
||||
concurrency_limit="default",
|
||||
concurrency_id=None,
|
||||
show_api=False,
|
||||
)
|
||||
|
||||
def set_event_trigger(
|
||||
self,
|
||||
targets: Sequence[EventListenerMethod],
|
||||
@ -1498,7 +1539,7 @@ Received outputs:
|
||||
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
|
||||
) from e
|
||||
|
||||
if getattr(block, "stateful", False):
|
||||
if block.stateful:
|
||||
if not utils.is_update(predictions[i]):
|
||||
state[output_id] = predictions[i]
|
||||
output.append(None)
|
||||
@ -2392,9 +2433,10 @@ Received outputs:
|
||||
self._queue._cancel_asyncio_tasks()
|
||||
self.server_app._cancel_asyncio_tasks()
|
||||
self._queue.close()
|
||||
# set this before closing server to shut down heartbeats
|
||||
self.is_running = False
|
||||
if self.server:
|
||||
self.server.close()
|
||||
self.is_running = False
|
||||
# So that the startup events (starting the queue)
|
||||
# happen the next time the app is launched
|
||||
self.app.startup_events_triggered = False
|
||||
@ -2448,6 +2490,7 @@ Received outputs:
|
||||
self._queue.start()
|
||||
# So that processing can resume in case the queue was stopped
|
||||
self._queue.stopped = False
|
||||
self.is_running = True
|
||||
self.create_limiter()
|
||||
|
||||
def queue_enabled_for_fn(self, fn_index: int):
|
||||
|
@ -2,8 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from gradio_client.documentation import document
|
||||
|
||||
@ -16,8 +17,7 @@ class State(Component):
|
||||
"""
|
||||
Special hidden component that stores session state across runs of the demo by the
|
||||
same user. The value of the State variable is cleared when the user refreshes the page.
|
||||
|
||||
Demos: interface_state, blocks_simple_squares
|
||||
Demos: interface_state, blocks_simple_squares, state_cleanup
|
||||
Guides: real-time-speech-recognition
|
||||
"""
|
||||
|
||||
@ -27,13 +27,21 @@ class State(Component):
|
||||
self,
|
||||
value: Any = None,
|
||||
render: bool = True,
|
||||
*,
|
||||
time_to_live: int | float | None = None,
|
||||
delete_callback: Callable[[Any], None] | None = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
value: the initial value (of arbitrary type) of the state. The provided argument is deepcopied. If a callable is provided, the function will be called whenever the app loads to set the initial value of the state.
|
||||
render: has no effect, but is included for consistency with other components.
|
||||
time_to_live: The number of seconds the state should be stored for after it is created or updated. If None, the state will be stored indefinitely. Gradio automatically deletes state variables after a user closes the browser tab or refreshes the page, so this is useful for clearing state for potentially long running sessions.
|
||||
delete_callback: A function that is called when the state is deleted. The function should take the state value as an argument.
|
||||
"""
|
||||
self.stateful = True
|
||||
self.time_to_live = self.time_to_live = (
|
||||
math.inf if time_to_live is None else time_to_live
|
||||
)
|
||||
self.delete_callback = delete_callback or (lambda a: None) # noqa: ARG005
|
||||
try:
|
||||
self.value = deepcopy(value)
|
||||
except TypeError as err:
|
||||
@ -42,6 +50,10 @@ class State(Component):
|
||||
) from err
|
||||
super().__init__(value=self.value, render=render)
|
||||
|
||||
@property
|
||||
def stateful(self):
|
||||
return True
|
||||
|
||||
def preprocess(self, payload: Any) -> Any:
|
||||
"""
|
||||
Parameters:
|
||||
|
@ -65,7 +65,7 @@ class Server(uvicorn.Server):
|
||||
if self.reloader:
|
||||
self.reloader.stop()
|
||||
self.watch_thread.join()
|
||||
self.thread.join()
|
||||
self.thread.join(timeout=5)
|
||||
|
||||
|
||||
def start_server(
|
||||
|
@ -9,7 +9,7 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from dataclasses import dataclass as python_dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@ -734,7 +734,6 @@ class CustomCORSMiddleware:
|
||||
|
||||
def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
|
||||
"""Delete files that are older than age. If age is None, delete all files."""
|
||||
|
||||
dont_delete = set()
|
||||
for component in blocks.blocks.values():
|
||||
dont_delete.update(getattr(component, "keep_in_cache", set()))
|
||||
@ -770,27 +769,40 @@ async def _lifespan_handler(
|
||||
app: App, frequency: int = 1, age: int = 1
|
||||
) -> AsyncGenerator:
|
||||
"""A context manager that triggers the startup and shutdown events of the app."""
|
||||
app.get_blocks().startup_events()
|
||||
app.startup_events_triggered = True
|
||||
asyncio.create_task(delete_files_on_schedule(app, frequency, age))
|
||||
yield
|
||||
delete_files_created_by_app(app.get_blocks(), age=None)
|
||||
|
||||
|
||||
async def _delete_state(app: App):
|
||||
"""Delete all expired state every second."""
|
||||
while True:
|
||||
app.state_holder.delete_all_expired_state()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _delete_state_handler(app: App):
|
||||
"""When the server launches, regularly delete expired state."""
|
||||
asyncio.create_task(_delete_state(app))
|
||||
yield
|
||||
|
||||
|
||||
def create_lifespan_handler(
|
||||
user_lifespan: Callable[[App], AsyncContextManager] | None,
|
||||
frequency: int = 1,
|
||||
age: int = 1,
|
||||
frequency: int | None = 1,
|
||||
age: int | None = 1,
|
||||
) -> Callable[[App], AsyncContextManager]:
|
||||
"""Return a context manager that applies _lifespan_handler and user_lifespan if it exists."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _handler(app: App):
|
||||
async with _lifespan_handler(app, frequency, age):
|
||||
async with AsyncExitStack() as stack:
|
||||
await stack.enter_async_context(_delete_state_handler(app))
|
||||
if frequency and age:
|
||||
await stack.enter_async_context(_lifespan_handler(app, frequency, age))
|
||||
if user_lifespan is not None:
|
||||
async with user_lifespan(app):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
await stack.enter_async_context(user_lifespan(app))
|
||||
yield
|
||||
|
||||
return _handler
|
||||
|
@ -222,10 +222,10 @@ class App(FastAPI):
|
||||
) -> App:
|
||||
app_kwargs = app_kwargs or {}
|
||||
app_kwargs.setdefault("default_response_class", ORJSONResponse)
|
||||
if blocks.delete_cache is not None:
|
||||
app_kwargs["lifespan"] = create_lifespan_handler(
|
||||
app_kwargs.get("lifespan", None), *blocks.delete_cache
|
||||
)
|
||||
delete_cache = blocks.delete_cache or (None, None)
|
||||
app_kwargs["lifespan"] = create_lifespan_handler(
|
||||
app_kwargs.get("lifespan", None), *delete_cache
|
||||
)
|
||||
app = App(auth_dependency=auth_dependency, **app_kwargs)
|
||||
app.configure_app(blocks)
|
||||
|
||||
@ -589,6 +589,75 @@ class App(FastAPI):
|
||||
await app.get_blocks()._queue.clean_events(event_id=body.event_id)
|
||||
return {"success": True}
|
||||
|
||||
@app.get("/heartbeat/{session_hash}")
|
||||
def heartbeat(
|
||||
session_hash: str,
|
||||
request: fastapi.Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
username: str = Depends(get_current_user),
|
||||
):
|
||||
"""Clients make a persistent connection to this endpoint to keep the session alive.
|
||||
When the client disconnects, the session state is deleted.
|
||||
"""
|
||||
heartbeat_rate = 0.25 if os.getenv("GRADIO_IS_E2E_TEST", None) else 15
|
||||
|
||||
async def wait():
|
||||
await asyncio.sleep(heartbeat_rate)
|
||||
return "wait"
|
||||
|
||||
async def stop_stream():
|
||||
while app.get_blocks().is_running:
|
||||
await asyncio.sleep(0.25)
|
||||
return "stop"
|
||||
|
||||
async def iterator():
|
||||
while True:
|
||||
try:
|
||||
yield "data: ALIVE\n\n"
|
||||
# We need to close the heartbeat connections as soon as the server stops
|
||||
# otherwise the server can take forever to close
|
||||
wait_task = asyncio.create_task(wait())
|
||||
stop_stream_task = asyncio.create_task(stop_stream())
|
||||
done, _ = await asyncio.wait(
|
||||
[wait_task, stop_stream_task],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
done = [d.result() for d in done]
|
||||
if "stop" in done:
|
||||
raise asyncio.CancelledError()
|
||||
except asyncio.CancelledError:
|
||||
req = Request(request, username)
|
||||
root_path = route_utils.get_root_url(
|
||||
request=request,
|
||||
route_path=f"/hearbeat/{session_hash}",
|
||||
root_path=app.root_path,
|
||||
)
|
||||
body = PredictBody(
|
||||
session_hash=session_hash, data=[], request=request
|
||||
)
|
||||
unload_fn_indices = [
|
||||
i
|
||||
for i, dep in enumerate(app.get_blocks().dependencies)
|
||||
if any(t for t in dep["targets"] if t[1] == "unload")
|
||||
]
|
||||
for fn_index in unload_fn_indices:
|
||||
# The task runnning this loop has been cancelled
|
||||
# so we add tasks in the background
|
||||
background_tasks.add_task(
|
||||
route_utils.call_process_api,
|
||||
app=app,
|
||||
body=body,
|
||||
gr_request=req,
|
||||
fn_index_inferred=fn_index,
|
||||
root_path=root_path,
|
||||
)
|
||||
# This will mark the state to be deleted in an hour
|
||||
if session_hash in app.state_holder.session_data:
|
||||
app.state_holder.session_data[session_hash].is_closed = True
|
||||
return
|
||||
|
||||
return StreamingResponse(iterator(), media_type="text/event-stream")
|
||||
|
||||
# had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility
|
||||
@app.post("/run/{api_name}", dependencies=[Depends(login_check)])
|
||||
@app.post("/run/{api_name}/", dependencies=[Depends(login_check)])
|
||||
@ -1098,8 +1167,9 @@ def mount_gradio_app(
|
||||
async with old_lifespan(
|
||||
app
|
||||
): # Instert the startup events inside the FastAPI context manager
|
||||
gradio_app.get_blocks().startup_events()
|
||||
yield
|
||||
async with gradio_app.router.lifespan_context(gradio_app):
|
||||
gradio_app.get_blocks().startup_events()
|
||||
yield
|
||||
|
||||
app.router.lifespan_context = new_lifespan
|
||||
|
||||
|
@ -1,18 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Iterator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Blocks
|
||||
from gradio.components import State
|
||||
|
||||
|
||||
class StateHolder:
|
||||
def __init__(self):
|
||||
self.capacity = 10000
|
||||
self.session_data = OrderedDict()
|
||||
self.session_data: OrderedDict[str, SessionState] = OrderedDict()
|
||||
self.time_last_used: dict[str, datetime.datetime] = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def set_blocks(self, blocks: Blocks):
|
||||
@ -29,6 +33,7 @@ class StateHolder:
|
||||
if session_id not in self.session_data:
|
||||
self.session_data[session_id] = SessionState(self.blocks)
|
||||
self.update(session_id)
|
||||
self.time_last_used[session_id] = datetime.datetime.now()
|
||||
return self.session_data[session_id]
|
||||
|
||||
def __contains__(self, session_id: str):
|
||||
@ -41,11 +46,36 @@ class StateHolder:
|
||||
if len(self.session_data) > self.capacity:
|
||||
self.session_data.popitem(last=False)
|
||||
|
||||
def delete_all_expired_state(
|
||||
self,
|
||||
):
|
||||
for session_id in self.session_data:
|
||||
self.delete_state(session_id, expired_only=True)
|
||||
|
||||
def delete_state(self, session_id: str, expired_only: bool = False):
|
||||
if session_id not in self.session_data:
|
||||
return
|
||||
to_delete = []
|
||||
session_state = self.session_data[session_id]
|
||||
for component, value, expired in session_state.state_components:
|
||||
if not expired_only or expired:
|
||||
component.delete_callback(value)
|
||||
to_delete.append(component._id)
|
||||
for component in to_delete:
|
||||
del session_state._data[component]
|
||||
|
||||
|
||||
class SessionState:
|
||||
def __init__(self, blocks: Blocks):
|
||||
self.blocks = blocks
|
||||
self._data = {}
|
||||
self._state_ttl = {}
|
||||
self.is_closed = False
|
||||
# When a session is closed, the state is stored for an hour to give the user time to reopen the session.
|
||||
# During testing we set to a lower value to be able to test
|
||||
self.STATE_TTL_WHEN_CLOSED = (
|
||||
1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600
|
||||
)
|
||||
|
||||
def __getitem__(self, key: int) -> Any:
|
||||
if key not in self._data:
|
||||
@ -57,7 +87,32 @@ class SessionState:
|
||||
return self._data[key]
|
||||
|
||||
def __setitem__(self, key: int, value: Any):
|
||||
from gradio.components import State
|
||||
|
||||
block = self.blocks.blocks[key]
|
||||
if isinstance(block, State):
|
||||
self._state_ttl[key] = (
|
||||
block.time_to_live,
|
||||
datetime.datetime.now(),
|
||||
)
|
||||
self._data[key] = value
|
||||
|
||||
def __contains__(self, key: int):
|
||||
return key in self._data
|
||||
|
||||
@property
|
||||
def state_components(self) -> Iterator[tuple[State, Any, bool]]:
|
||||
from gradio.components import State
|
||||
|
||||
for id in self._data:
|
||||
block = self.blocks.blocks[id]
|
||||
if isinstance(block, State) and id in self._state_ttl:
|
||||
time_to_live, created_at = self._state_ttl[id]
|
||||
if self.is_closed:
|
||||
time_to_live = self.STATE_TTL_WHEN_CLOSED
|
||||
value = self._data[id]
|
||||
yield (
|
||||
block,
|
||||
value,
|
||||
(datetime.datetime.now() - created_at).seconds > time_to_live,
|
||||
)
|
||||
|
34
js/app/test/unload_event_test.spec.ts
Normal file
34
js/app/test/unload_event_test.spec.ts
Normal file
@ -0,0 +1,34 @@
|
||||
import { test, expect } from "@gradio/tootils";
|
||||
import { readFileSync } from "fs";
|
||||
|
||||
test("when a user closes the page, the unload event should be triggered", async ({
|
||||
page
|
||||
}) => {
|
||||
const increment = await page.locator("button", {
|
||||
hasText: /Increment/
|
||||
});
|
||||
|
||||
// if you click too fast, the page may close before the event is processed
|
||||
await increment.click();
|
||||
await page.waitForTimeout(100);
|
||||
await increment.click();
|
||||
await page.waitForTimeout(100);
|
||||
await increment.click();
|
||||
await page.waitForTimeout(100);
|
||||
await increment.click();
|
||||
await expect(page.getByLabel("Number")).toHaveValue("4");
|
||||
await page.close();
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 5000));
|
||||
|
||||
const data = readFileSync(
|
||||
"../../demo/unload_event_test/output_log.txt",
|
||||
"utf-8"
|
||||
);
|
||||
expect(data).toContain("incremented 0");
|
||||
expect(data).toContain("incremented 1");
|
||||
expect(data).toContain("incremented 2");
|
||||
expect(data).toContain("incremented 3");
|
||||
expect(data).toContain("deleted 4");
|
||||
expect(data).toContain("unloading");
|
||||
});
|
@ -49,16 +49,11 @@ def connect():
|
||||
def _connect(demo: gr.Blocks, serialize=True, **kwargs):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
|
||||
try:
|
||||
yield Client(local_url, serialize=serialize)
|
||||
client = Client(local_url, serialize=serialize)
|
||||
yield client
|
||||
finally:
|
||||
# A more verbose version of .close()
|
||||
# because we should set a timeout
|
||||
# the tests that call .cancel() can get stuck
|
||||
# waiting for the thread to join
|
||||
demo._queue.close()
|
||||
demo.is_running = False
|
||||
demo.server.should_exit = True
|
||||
demo.server.thread.join(timeout=1)
|
||||
client.close()
|
||||
demo.close()
|
||||
|
||||
return _connect
|
||||
|
||||
|
@ -1726,3 +1726,44 @@ def test_static_files_multiple_apps(gradio_temp_dir):
|
||||
|
||||
# Input/Output got saved to cache
|
||||
assert len(list(gradio_temp_dir.glob("**/*.*"))) == 0
|
||||
|
||||
|
||||
def test_time_to_live_and_delete_callback_for_state(capsys, monkeypatch):
|
||||
monkeypatch.setenv("GRADIO_IS_E2E_TEST", 1)
|
||||
|
||||
def test_fn(x):
|
||||
return x + 1, x + 1
|
||||
|
||||
def delete_fn(v):
|
||||
print(f"deleted {v}")
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
n1 = gr.Number(value=0)
|
||||
state = gr.State(
|
||||
value=0, time_to_live=1, delete_callback=lambda v: delete_fn(v)
|
||||
)
|
||||
button = gr.Button("Increment")
|
||||
button.click(test_fn, [state], [n1, state], api_name="increment")
|
||||
|
||||
app, url, _ = demo.launch(prevent_thread_lock=True)
|
||||
|
||||
try:
|
||||
client_1 = grc.Client(url)
|
||||
client_2 = grc.Client(url)
|
||||
|
||||
client_1.predict(api_name="/increment")
|
||||
client_1.predict(api_name="/increment")
|
||||
client_1.predict(api_name="/increment")
|
||||
|
||||
client_2.predict(api_name="/increment")
|
||||
client_2.predict(api_name="/increment")
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "deleted 2" in captured.out
|
||||
assert "deleted 3" in captured.out
|
||||
for client in [client_1, client_2]:
|
||||
assert len(app.state_holder.session_data[client.session_hash]._data) == 0
|
||||
finally:
|
||||
demo.close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user