mirror of
https://github.com/gradio-app/gradio.git
synced 2024-11-21 01:01:05 +08:00
Add delete_cache parameter to gr.Blocks to delete files created by app on shutdown (#7447)
* Add code * add changeset * Add code * trigger ci * Add schedule * Fix implementation * Fix test * Address comments * add changeset * handle examples * Update guides/01_getting-started/03_sharing-your-app.md Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * Fix code * Fix code --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
This commit is contained in:
parent
3645da5f1e
commit
a57e34ef87
5
.changeset/public-hoops-drum.md
Normal file
5
.changeset/public-hoops-drum.md
Normal file
@ -0,0 +1,5 @@
|
||||
---
|
||||
"gradio": minor
|
||||
---
|
||||
|
||||
feat:Add delete_cache parameter to gr.Blocks to delete files created by app on shutdown
|
@ -526,6 +526,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
js: str | None = None,
|
||||
head: str | None = None,
|
||||
fill_height: bool = False,
|
||||
delete_cache: tuple[int, int] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -538,6 +539,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
|
||||
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
|
||||
fill_height: Whether to vertically expand top-level child components to the height of the window. If True, expansion occurs when the scale value of the child components >= 1.
|
||||
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
||||
"""
|
||||
self.limiter = None
|
||||
if theme is None:
|
||||
@ -566,6 +568,7 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
self.show_error = True
|
||||
self.head = head
|
||||
self.fill_height = fill_height
|
||||
self.delete_cache = delete_cache
|
||||
if css is not None and os.path.exists(css):
|
||||
with open(css) as css_file:
|
||||
self.css = css_file.read()
|
||||
@ -608,7 +611,8 @@ class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
|
||||
self.auth = None
|
||||
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
|
||||
self.app_id = random.getrandbits(64)
|
||||
self.temp_file_sets = []
|
||||
self.upload_file_set = set()
|
||||
self.temp_file_sets = [self.upload_file_set]
|
||||
self.title = title
|
||||
self.show_api = not wasm_utils.IS_WASM
|
||||
|
||||
|
@ -77,6 +77,7 @@ class ChatInterface(Blocks):
|
||||
autofocus: bool = True,
|
||||
concurrency_limit: int | None | Literal["default"] = "default",
|
||||
fill_height: bool = True,
|
||||
delete_cache: tuple[int, int] | None = None,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
@ -103,6 +104,7 @@ class ChatInterface(Blocks):
|
||||
autofocus: If True, autofocuses to the textbox when the page loads.
|
||||
concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
|
||||
fill_height: If True, the chat interface will expand to the height of window.
|
||||
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
||||
"""
|
||||
super().__init__(
|
||||
analytics_enabled=analytics_enabled,
|
||||
@ -113,6 +115,7 @@ class ChatInterface(Blocks):
|
||||
js=js,
|
||||
head=head,
|
||||
fill_height=fill_height,
|
||||
delete_cache=delete_cache,
|
||||
)
|
||||
self.concurrency_limit = concurrency_limit
|
||||
self.fn = fn
|
||||
|
@ -14,6 +14,8 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from gradio_client.utils import is_file_obj
|
||||
|
||||
from gradio import utils
|
||||
from gradio.blocks import Block, BlockContext
|
||||
from gradio.component_meta import ComponentMeta
|
||||
@ -189,6 +191,9 @@ class Component(ComponentBase, Block):
|
||||
self.scale = scale
|
||||
self.min_width = min_width
|
||||
self.interactive = interactive
|
||||
# Keep tracks of files that should not be deleted when the delete_cache parmaeter is set
|
||||
# These files are the default value of the component and files that are used in examples
|
||||
self.keep_in_cache = set()
|
||||
|
||||
# load_event is set in the Blocks.attach_load_events method
|
||||
self.load_event: None | dict[str, Any] = None
|
||||
@ -200,6 +205,8 @@ class Component(ComponentBase, Block):
|
||||
self, # type: ignore
|
||||
postprocess=True,
|
||||
)
|
||||
if is_file_obj(self.value):
|
||||
self.keep_in_cache.add(self.value["path"])
|
||||
|
||||
if callable(load_fn):
|
||||
self.attach_load_event(load_fn, every)
|
||||
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
from typing import Any, Literal
|
||||
|
||||
from gradio_client.documentation import document
|
||||
from gradio_client.utils import is_file_obj
|
||||
|
||||
from gradio import processing_utils
|
||||
from gradio.components.base import (
|
||||
@ -98,6 +99,8 @@ class Dataset(Component):
|
||||
example[i],
|
||||
component,
|
||||
)
|
||||
if is_file_obj(example[i]):
|
||||
self.keep_in_cache.add(example[i]["path"])
|
||||
self.type = type
|
||||
self.label = label
|
||||
if headers is not None:
|
||||
|
@ -121,6 +121,7 @@ class Interface(Blocks):
|
||||
submit_btn: str | Button = "Submit",
|
||||
stop_btn: str | Button = "Stop",
|
||||
clear_btn: str | Button = "Clear",
|
||||
delete_cache: tuple[int, int] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -155,6 +156,7 @@ class Interface(Blocks):
|
||||
submit_btn: The button to use for submitting inputs. Defaults to a `gr.Button("Submit", variant="primary")`. This parameter does not apply if the Interface is output-only, in which case the submit button always displays "Generate". Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
|
||||
stop_btn: The button to use for stopping the interface. Defaults to a `gr.Button("Stop", variant="stop", visible=False)`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
|
||||
clear_btn: The button to use for clearing the inputs. Defaults to a `gr.Button("Clear", variant="secondary")`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
|
||||
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
||||
"""
|
||||
super().__init__(
|
||||
analytics_enabled=analytics_enabled,
|
||||
@ -164,6 +166,7 @@ class Interface(Blocks):
|
||||
theme=theme,
|
||||
js=js,
|
||||
head=head,
|
||||
delete_cache=delete_cache,
|
||||
**kwargs,
|
||||
)
|
||||
self.api_name: str | Literal[False] | None = api_name
|
||||
|
@ -1,16 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from collections import deque
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass as python_dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AsyncContextManager,
|
||||
AsyncGenerator,
|
||||
BinaryIO,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
import fastapi
|
||||
import httpx
|
||||
import multipart
|
||||
@ -640,3 +656,67 @@ class CustomCORSMiddleware(BaseHTTPMiddleware):
|
||||
"Access-Control-Allow-Headers"
|
||||
] = "Origin, Content-Type, Accept"
|
||||
return response
|
||||
|
||||
|
||||
def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
|
||||
"""Delete files that are older than age. If age is None, delete all files."""
|
||||
|
||||
dont_delete = set()
|
||||
for component in blocks.blocks.values():
|
||||
dont_delete.update(getattr(component, "keep_in_cache", set()))
|
||||
for temp_set in blocks.temp_file_sets:
|
||||
# We use a copy of the set to avoid modifying the set while iterating over it
|
||||
# otherwise we would get an exception: Set changed size during iteration
|
||||
to_remove = set()
|
||||
for file in temp_set:
|
||||
if file in dont_delete:
|
||||
continue
|
||||
try:
|
||||
file_path = Path(file)
|
||||
modified_time = datetime.fromtimestamp(file_path.lstat().st_ctime)
|
||||
if age is None or (datetime.now() - modified_time).seconds > age:
|
||||
os.remove(file)
|
||||
to_remove.add(file)
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
temp_set -= to_remove
|
||||
|
||||
|
||||
async def delete_files_on_schedule(app: App, frequency: int, age: int) -> None:
|
||||
"""Startup task to delete files created by the app based on time since last modification."""
|
||||
while True:
|
||||
await asyncio.sleep(frequency)
|
||||
await anyio.to_thread.run_sync(
|
||||
delete_files_created_by_app, app.get_blocks(), age
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _lifespan_handler(
|
||||
app: App, frequency: int = 1, age: int = 1
|
||||
) -> AsyncGenerator:
|
||||
"""A context manager that triggers the startup and shutdown events of the app."""
|
||||
app.get_blocks().startup_events()
|
||||
app.startup_events_triggered = True
|
||||
asyncio.create_task(delete_files_on_schedule(app, frequency, age))
|
||||
yield
|
||||
delete_files_created_by_app(app.get_blocks(), age=None)
|
||||
|
||||
|
||||
def create_lifespan_handler(
|
||||
user_lifespan: Callable[[App], AsyncContextManager] | None,
|
||||
frequency: int = 1,
|
||||
age: int = 1,
|
||||
) -> Callable[[App], AsyncContextManager]:
|
||||
"""Return a context manager that applies _lifespan_handler and user_lifespan if it exists."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _handler(app: App):
|
||||
async with _lifespan_handler(app, frequency, age):
|
||||
if user_lifespan is not None:
|
||||
async with user_lifespan(app):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
return _handler
|
||||
|
@ -63,6 +63,7 @@ from gradio.route_utils import ( # noqa: F401
|
||||
MultiPartException,
|
||||
Request,
|
||||
compare_passwords_securely,
|
||||
create_lifespan_handler,
|
||||
move_uploaded_files_to_cache,
|
||||
)
|
||||
from gradio.state_holder import StateHolder
|
||||
@ -192,6 +193,10 @@ class App(FastAPI):
|
||||
) -> App:
|
||||
app_kwargs = app_kwargs or {}
|
||||
app_kwargs.setdefault("default_response_class", ORJSONResponse)
|
||||
if blocks.delete_cache is not None:
|
||||
app_kwargs["lifespan"] = create_lifespan_handler(
|
||||
app_kwargs.get("lifespan", None), *blocks.delete_cache
|
||||
)
|
||||
app = App(**app_kwargs)
|
||||
app.configure_app(blocks)
|
||||
|
||||
@ -827,6 +832,7 @@ class App(FastAPI):
|
||||
files_to_copy.append(temp_file.file.name)
|
||||
locations.append(str(dest))
|
||||
output_files.append(dest)
|
||||
blocks.upload_file_set.add(str(dest))
|
||||
if files_to_copy:
|
||||
bg_tasks.add_task(
|
||||
move_uploaded_files_to_cache, files_to_copy, locations
|
||||
|
@ -315,7 +315,7 @@ Sharing your Gradio app with others (by hosting it on Spaces, on your own server
|
||||
|
||||
In particular, Gradio apps ALLOW users to access to three kinds of files:
|
||||
|
||||
- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`.
|
||||
- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`. You can delete the files created by your app when it shuts down with the `delete_cache` parameter of `gradio.Blocks`, `gradio.Interface`, and `gradio.ChatInterface`. This parameter is a tuple of integers of the form `[frequency, age]` where `frequency` is how often to delete files and `age` is the time in seconds since the file was last modified.
|
||||
|
||||
|
||||
- **Cached examples created by Gradio.** These are files that are created by Gradio as part of caching examples for faster runtimes, if you set `cache_examples=True` in `gr.Interface()` or in `gr.Examples()`. By default, these files are saved in the `gradio_cached_examples/` subdirectory within your app's working directory. You can customize the location of cached example files created by Gradio by setting the environment variable `GRADIO_EXAMPLES_CACHE` to an absolute path or a path relative to your working directory.
|
||||
|
@ -41,8 +41,8 @@ def io_components():
|
||||
@pytest.fixture
|
||||
def connect():
|
||||
@contextmanager
|
||||
def _connect(demo: gr.Blocks, serialize=True):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True)
|
||||
def _connect(demo: gr.Blocks, serialize=True, **kwargs):
|
||||
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
|
||||
try:
|
||||
yield Client(local_url, serialize=serialize)
|
||||
finally:
|
||||
|
@ -1584,8 +1584,12 @@ def test_temp_file_sets_get_extended():
|
||||
with gr.Blocks() as demo3:
|
||||
demo1.render()
|
||||
demo2.render()
|
||||
|
||||
assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets
|
||||
# The upload_set is empty so we remove it from the check
|
||||
demo_3_no_empty = [s for s in demo3.temp_file_sets if len(s)]
|
||||
demo_1_and_2_no_empty = [
|
||||
s for s in demo1.temp_file_sets + demo2.temp_file_sets if len(s)
|
||||
]
|
||||
assert demo_3_no_empty == demo_1_and_2_no_empty
|
||||
|
||||
|
||||
def test_recover_kwargs():
|
||||
|
@ -480,6 +480,49 @@ class TestRoutes:
|
||||
assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
|
||||
io.close()
|
||||
|
||||
def test_delete_cache(self, connect, gradio_temp_dir, capsys):
|
||||
def check_num_files_exist(blocks: Blocks):
|
||||
num_files = 0
|
||||
for temp_file_set in blocks.temp_file_sets:
|
||||
for temp_file in temp_file_set:
|
||||
if os.path.exists(temp_file):
|
||||
num_files += 1
|
||||
return num_files
|
||||
|
||||
demo = gr.Interface(lambda s: s, gr.Textbox(), gr.File(), delete_cache=None)
|
||||
with connect(demo) as client:
|
||||
client.predict("test/test_files/cheetah1.jpg")
|
||||
assert check_num_files_exist(demo) == 1
|
||||
|
||||
demo_delete = gr.Interface(
|
||||
lambda s: s, gr.Textbox(), gr.File(), delete_cache=(60, 30)
|
||||
)
|
||||
with connect(demo_delete) as client:
|
||||
client.predict("test/test_files/alphabet.txt")
|
||||
client.predict("test/test_files/bus.png")
|
||||
assert check_num_files_exist(demo_delete) == 2
|
||||
assert check_num_files_exist(demo_delete) == 0
|
||||
assert check_num_files_exist(demo) == 1
|
||||
|
||||
@asynccontextmanager
|
||||
async def mylifespan(app: FastAPI):
|
||||
print("IN CUSTOM LIFESPAN")
|
||||
yield
|
||||
print("AFTER CUSTOM LIFESPAN")
|
||||
|
||||
demo_custom_lifespan = gr.Interface(
|
||||
lambda s: s, gr.Textbox(), gr.File(), delete_cache=(5, 1)
|
||||
)
|
||||
|
||||
with connect(
|
||||
demo_custom_lifespan, app_kwargs={"lifespan": mylifespan}
|
||||
) as client:
|
||||
client.predict("test/test_files/alphabet.txt")
|
||||
assert check_num_files_exist(demo_custom_lifespan) == 0
|
||||
captured = capsys.readouterr()
|
||||
assert "IN CUSTOM LIFESPAN" in captured.out
|
||||
assert "AFTER CUSTOM LIFESPAN" in captured.out
|
||||
|
||||
|
||||
class TestApp:
|
||||
def test_create_app(self):
|
||||
|
Loading…
Reference in New Issue
Block a user