Add delete_cache parameter to gr.Blocks to delete files created by app on shutdown (#7447)

* Add code

* add changeset

* Add code

* trigger ci

* Add schedule

* Fix implementation

* Fix test

* Address comments

* add changeset

* handle examples

* Update guides/01_getting-started/03_sharing-your-app.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Fix code

* Fix code

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
Freddy Boulton 2024-03-04 17:51:02 -08:00 committed by GitHub
parent 3645da5f1e
commit a57e34ef87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 165 additions and 7 deletions

View File

@ -0,0 +1,5 @@
---
"gradio": minor
---
feat:Add delete_cache parameter to gr.Blocks to delete files created by app on shutdown

View File

@ -526,6 +526,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
js: str | None = None,
head: str | None = None,
fill_height: bool = False,
delete_cache: tuple[int, int] | None = None,
**kwargs,
):
"""
@ -538,6 +539,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
fill_height: Whether to vertically expand top-level child components to the height of the window. If True, expansion occurs when the scale value of the child components >= 1.
delete_cache: A tuple corresponding to [frequency, age], both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
self.limiter = None
if theme is None:
@ -566,6 +568,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
self.show_error = True
self.head = head
self.fill_height = fill_height
self.delete_cache = delete_cache
if css is not None and os.path.exists(css):
with open(css) as css_file:
self.css = css_file.read()
@ -608,7 +611,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
self.auth = None
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
self.app_id = random.getrandbits(64)
self.temp_file_sets = []
self.upload_file_set = set()
self.temp_file_sets = [self.upload_file_set]
self.title = title
self.show_api = not wasm_utils.IS_WASM

View File

@ -77,6 +77,7 @@ class ChatInterface(Blocks):
autofocus: bool = True,
concurrency_limit: int | None | Literal["default"] = "default",
fill_height: bool = True,
delete_cache: tuple[int, int] | None = None,
):
"""
Parameters:
@ -103,6 +104,7 @@ class ChatInterface(Blocks):
autofocus: If True, autofocuses to the textbox when the page loads.
concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
fill_height: If True, the chat interface will expand to the height of window.
delete_cache: A tuple corresponding to [frequency, age], both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -113,6 +115,7 @@ class ChatInterface(Blocks):
js=js,
head=head,
fill_height=fill_height,
delete_cache=delete_cache,
)
self.concurrency_limit = concurrency_limit
self.fn = fn

View File

@ -14,6 +14,8 @@ from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from gradio_client.utils import is_file_obj
from gradio import utils
from gradio.blocks import Block, BlockContext
from gradio.component_meta import ComponentMeta
@ -189,6 +191,9 @@ class Component(ComponentBase, Block):
self.scale = scale
self.min_width = min_width
self.interactive = interactive
# Keeps track of files that should not be deleted when the delete_cache parameter is set
# These files are the default value of the component and files that are used in examples
self.keep_in_cache = set()
# load_event is set in the Blocks.attach_load_events method
self.load_event: None | dict[str, Any] = None
@ -200,6 +205,8 @@ class Component(ComponentBase, Block):
self, # type: ignore
postprocess=True,
)
if is_file_obj(self.value):
self.keep_in_cache.add(self.value["path"])
if callable(load_fn):
self.attach_load_event(load_fn, every)

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from typing import Any, Literal
from gradio_client.documentation import document
from gradio_client.utils import is_file_obj
from gradio import processing_utils
from gradio.components.base import (
@ -98,6 +99,8 @@ class Dataset(Component):
example[i],
component,
)
if is_file_obj(example[i]):
self.keep_in_cache.add(example[i]["path"])
self.type = type
self.label = label
if headers is not None:

View File

@ -121,6 +121,7 @@ class Interface(Blocks):
submit_btn: str | Button = "Submit",
stop_btn: str | Button = "Stop",
clear_btn: str | Button = "Clear",
delete_cache: tuple[int, int] | None = None,
**kwargs,
):
"""
@ -155,6 +156,7 @@ class Interface(Blocks):
submit_btn: The button to use for submitting inputs. Defaults to a `gr.Button("Submit", variant="primary")`. This parameter does not apply if the Interface is output-only, in which case the submit button always displays "Generate". Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
stop_btn: The button to use for stopping the interface. Defaults to a `gr.Button("Stop", variant="stop", visible=False)`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
clear_btn: The button to use for clearing the inputs. Defaults to a `gr.Button("Clear", variant="secondary")`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
delete_cache: A tuple corresponding to [frequency, age], both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
super().__init__(
analytics_enabled=analytics_enabled,
@ -164,6 +166,7 @@ class Interface(Blocks):
theme=theme,
js=js,
head=head,
delete_cache=delete_cache,
**kwargs,
)
self.api_name: str | Literal[False] | None = api_name

View File

@ -1,16 +1,32 @@
from __future__ import annotations
import asyncio
import hashlib
import hmac
import json
import os
import re
import shutil
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass as python_dataclass
from datetime import datetime
from pathlib import Path
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
AsyncContextManager,
AsyncGenerator,
BinaryIO,
Callable,
List,
Optional,
Tuple,
Union,
)
from urllib.parse import urlparse
import anyio
import fastapi
import httpx
import multipart
@ -640,3 +656,67 @@ class CustomCORSMiddleware(BaseHTTPMiddleware):
"Access-Control-Allow-Headers"
] = "Origin, Content-Type, Accept"
return response
def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
    """Delete temporary files tracked by ``blocks`` that are older than ``age``.

    Files listed in any component's ``keep_in_cache`` set are never deleted.
    Paths that no longer exist on disk are skipped and left in their set.

    Parameters:
        blocks: The Blocks instance whose ``temp_file_sets`` are cleaned up.
        age: Age threshold in seconds; a file is deleted once more than ``age`` seconds have elapsed since its ``st_ctime``. If None, every tracked file is deleted regardless of age.
    """
    dont_delete = set()
    for component in blocks.blocks.values():
        dont_delete.update(getattr(component, "keep_in_cache", set()))

    for temp_set in blocks.temp_file_sets:
        # Deleted paths are collected here and dropped from the set after the
        # loop, since a set must not change size while it is being iterated.
        to_remove = set()
        # Iterate over a snapshot: this function is run in a worker thread by
        # the scheduler, while upload handlers may add new paths to the set.
        for file in tuple(temp_set):
            if file in dont_delete:
                continue
            try:
                changed_time = datetime.fromtimestamp(Path(file).lstat().st_ctime)
                # total_seconds() covers the whole interval; timedelta.seconds
                # is only the 0-86399 remainder and ignores full days, which
                # would make thresholds of a day or more unreachable.
                elapsed = (datetime.now() - changed_time).total_seconds()
                if age is None or elapsed > age:
                    os.remove(file)
                    to_remove.add(file)
            except FileNotFoundError:
                continue
        temp_set -= to_remove
async def delete_files_on_schedule(app: App, frequency: int, age: int) -> None:
    """Background task that repeatedly purges the app's expired temporary files.

    Loops forever: waits `frequency` seconds, then runs
    `delete_files_created_by_app` with the given `age` in a worker thread so
    the file system work does not block the event loop.
    """
    while True:
        await asyncio.sleep(frequency)
        blocks = app.get_blocks()
        await anyio.to_thread.run_sync(delete_files_created_by_app, blocks, age)
@asynccontextmanager
async def _lifespan_handler(
    app: App, frequency: int = 1, age: int = 1
) -> AsyncGenerator:
    """A context manager that triggers the startup and shutdown events of the app.

    On entry it runs the Blocks startup events and starts the periodic cache
    cleanup task. On exit it stops that task and deletes every remaining file
    created by the app.

    Parameters:
        app: The running App whose Blocks own the temporary files.
        frequency: Seconds between two scheduled cleanup passes.
        age: Age threshold in seconds passed to each scheduled cleanup pass.
    """
    app.get_blocks().startup_events()
    app.startup_events_triggered = True
    # Hold a reference to the task: the event loop only keeps weak references,
    # so an unreferenced task may be garbage collected before it finishes.
    cleanup_task = asyncio.create_task(delete_files_on_schedule(app, frequency, age))
    try:
        yield
    finally:
        # Stop the scheduler first, then run the final cleanup even when the
        # application body exits with an exception.
        cleanup_task.cancel()
        delete_files_created_by_app(app.get_blocks(), age=None)
def create_lifespan_handler(
    user_lifespan: Callable[[App], AsyncContextManager] | None,
    frequency: int = 1,
    age: int = 1,
) -> Callable[[App], AsyncContextManager]:
    """Build a lifespan callable that wraps `user_lifespan` in `_lifespan_handler`.

    The returned context manager enters `_lifespan_handler` first and, when a
    user lifespan was supplied, enters it inside; they exit in reverse order.
    """

    @asynccontextmanager
    async def _handler(app: App):
        async with _lifespan_handler(app, frequency, age):
            if user_lifespan is None:
                yield
                return
            async with user_lifespan(app):
                yield

    return _handler

View File

@ -63,6 +63,7 @@ from gradio.route_utils import ( # noqa: F401
MultiPartException,
Request,
compare_passwords_securely,
create_lifespan_handler,
move_uploaded_files_to_cache,
)
from gradio.state_holder import StateHolder
@ -192,6 +193,10 @@ class App(FastAPI):
) -> App:
app_kwargs = app_kwargs or {}
app_kwargs.setdefault("default_response_class", ORJSONResponse)
if blocks.delete_cache is not None:
app_kwargs["lifespan"] = create_lifespan_handler(
app_kwargs.get("lifespan", None), *blocks.delete_cache
)
app = App(**app_kwargs)
app.configure_app(blocks)
@ -827,6 +832,7 @@ class App(FastAPI):
files_to_copy.append(temp_file.file.name)
locations.append(str(dest))
output_files.append(dest)
blocks.upload_file_set.add(str(dest))
if files_to_copy:
bg_tasks.add_task(
move_uploaded_files_to_cache, files_to_copy, locations

View File

@ -315,7 +315,7 @@ Sharing your Gradio app with others (by hosting it on Spaces, on your own server
In particular, Gradio apps ALLOW users to access three kinds of files:
- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`.
- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`. You can delete the files created by your app when it shuts down with the `delete_cache` parameter of `gradio.Blocks`, `gradio.Interface`, and `gradio.ChatInterface`. This parameter is a tuple of integers of the form `[frequency, age]`, where `frequency` is how often (in seconds) to delete files and `age` is the time in seconds since the file was last modified.
- **Cached examples created by Gradio.** These are files that are created by Gradio as part of caching examples for faster runtimes, if you set `cache_examples=True` in `gr.Interface()` or in `gr.Examples()`. By default, these files are saved in the `gradio_cached_examples/` subdirectory within your app's working directory. You can customize the location of cached example files created by Gradio by setting the environment variable `GRADIO_EXAMPLES_CACHE` to an absolute path or a path relative to your working directory.

View File

@ -41,8 +41,8 @@ def io_components():
@pytest.fixture
def connect():
@contextmanager
def _connect(demo: gr.Blocks, serialize=True):
_, local_url, _ = demo.launch(prevent_thread_lock=True)
def _connect(demo: gr.Blocks, serialize=True, **kwargs):
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
try:
yield Client(local_url, serialize=serialize)
finally:

View File

@ -1584,8 +1584,12 @@ def test_temp_file_sets_get_extended():
with gr.Blocks() as demo3:
demo1.render()
demo2.render()
assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets
# The upload_set is empty so we remove it from the check
demo_3_no_empty = [s for s in demo3.temp_file_sets if len(s)]
demo_1_and_2_no_empty = [
s for s in demo1.temp_file_sets + demo2.temp_file_sets if len(s)
]
assert demo_3_no_empty == demo_1_and_2_no_empty
def test_recover_kwargs():

View File

@ -480,6 +480,49 @@ class TestRoutes:
assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
io.close()
def test_delete_cache(self, connect, gradio_temp_dir, capsys):
    """Check that tracked temp files are removed on shutdown only when delete_cache is set."""

    def count_existing_files(blocks: Blocks):
        # Number of tracked temp file paths that still exist on disk.
        return sum(
            1
            for temp_file_set in blocks.temp_file_sets
            for temp_file in temp_file_set
            if os.path.exists(temp_file)
        )

    # Without delete_cache the file survives after the server is closed.
    demo = gr.Interface(lambda s: s, gr.Textbox(), gr.File(), delete_cache=None)
    with connect(demo) as client:
        client.predict("test/test_files/cheetah1.jpg")
        assert count_existing_files(demo) == 1

    # With delete_cache every tracked file is removed once the server shuts down.
    demo_delete = gr.Interface(
        lambda s: s, gr.Textbox(), gr.File(), delete_cache=(60, 30)
    )
    with connect(demo_delete) as client:
        client.predict("test/test_files/alphabet.txt")
        client.predict("test/test_files/bus.png")
        assert count_existing_files(demo_delete) == 2
    assert count_existing_files(demo_delete) == 0
    assert count_existing_files(demo) == 1

    # A user-supplied lifespan must still run alongside the cleanup handler.
    @asynccontextmanager
    async def mylifespan(app: FastAPI):
        print("IN CUSTOM LIFESPAN")
        yield
        print("AFTER CUSTOM LIFESPAN")

    demo_custom_lifespan = gr.Interface(
        lambda s: s, gr.Textbox(), gr.File(), delete_cache=(5, 1)
    )
    with connect(
        demo_custom_lifespan, app_kwargs={"lifespan": mylifespan}
    ) as client:
        client.predict("test/test_files/alphabet.txt")
    assert count_existing_files(demo_custom_lifespan) == 0

    printed = capsys.readouterr().out
    assert "IN CUSTOM LIFESPAN" in printed
    assert "AFTER CUSTOM LIFESPAN" in printed
class TestApp:
def test_create_app(self):