2022-11-27 16:52:53 +08:00
|
|
|
import os
|
|
|
|
import tempfile
|
|
|
|
from collections import namedtuple
|
2023-01-03 19:18:48 +08:00
|
|
|
from pathlib import Path
|
2022-11-27 16:52:53 +08:00
|
|
|
|
2023-05-28 00:06:49 +08:00
|
|
|
import gradio.components
|
2022-11-27 16:52:53 +08:00
|
|
|
|
|
|
|
from PIL import PngImagePlugin
|
|
|
|
|
|
|
|
from modules import shared
|
|
|
|
|
|
|
|
|
|
|
|
# Record type with a single `name` field.
# NOTE(review): not referenced by any function visible in this file — presumably
# kept for callers elsewhere; confirm before removing.
Savedfile = namedtuple("Savedfile", ["name"])
|
|
|
|
|
|
|
|
|
2023-01-03 19:18:48 +08:00
|
|
|
def register_tmp_file(gradio, filename):
    """Record *filename* in the temp-file bookkeeping of the given gradio object.

    Two bookkeeping attributes are supported; each is updated only if the
    object actually has it:

    * ``temp_file_sets`` (gradio 3.15): the absolute path of the file is added
      to the first set in the list.
    * ``temp_dirs`` (gradio 3.9): the absolute path of the file's parent
      directory is added.

    In both cases a new set is built and assigned back rather than mutating
    the existing set in place.
    """
    # gradio 3.15
    if hasattr(gradio, 'temp_file_sets'):
        first_set = gradio.temp_file_sets[0]
        gradio.temp_file_sets[0] = first_set | {os.path.abspath(filename)}

    # gradio 3.9
    if hasattr(gradio, 'temp_dirs'):
        known_dirs = gradio.temp_dirs
        parent_dir = os.path.dirname(filename)
        gradio.temp_dirs = known_dirs | {os.path.abspath(parent_dir)}
|
|
|
|
|
|
|
|
|
|
|
|
def check_tmp_file(gradio, filename):
    """Report whether *filename* is covered by the gradio object's temp-file bookkeeping.

    If the object has ``temp_file_sets``, the answer is whether *filename*
    (compared exactly as given) is a member of any of those sets; ``temp_dirs``
    is not consulted in that case. Otherwise, if it has ``temp_dirs``, the
    answer is whether any of those directories, once resolved, is an ancestor
    of the resolved *filename*. With neither attribute the result is False.
    """
    if hasattr(gradio, 'temp_file_sets'):
        for fileset in gradio.temp_file_sets:
            if filename in fileset:
                return True
        return False

    if hasattr(gradio, 'temp_dirs'):
        for temp_dir in gradio.temp_dirs:
            # Resolved per iteration, so nothing is evaluated when temp_dirs is empty.
            ancestors = Path(filename).resolve().parents
            if Path(temp_dir).resolve() in ancestors:
                return True
        return False

    return False
|
|
|
|
|
|
|
|
|
2023-06-04 14:20:23 +08:00
|
|
|
def save_pil_to_file(self, pil_image, dir=None, format="png"):
    """Write a PIL image to a temporary PNG file and return the path to it.

    install_ui_tempdir_override() assigns this function onto a gradio
    component class, which is why it takes ``self``.

    Args:
        self: Unused; present only so the function can act as a method.
        pil_image: The image to save. If it carries an ``already_saved_as``
            attribute naming an existing file, that file is reused and nothing
            new is written.
        dir: Directory for the temporary file. Ignored when
            ``shared.opts.temp_dir`` is non-empty. When left as None (and no
            temp dir is configured) the ``tempfile`` default location is used.
        format: Accepted for signature compatibility but not used; the file is
            always created with a ``.png`` suffix.

    Returns:
        For an already-saved image, ``"<path>?<mtime>"``; otherwise the name
        of the newly created temporary file.
    """
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as and os.path.isfile(already_saved_as):
        register_tmp_file(shared.demo, already_saved_as)
        # NOTE(review): the mtime suffix presumably makes the returned name
        # change whenever the file changes (cache busting) — confirm.
        filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
        # Both the plain and the suffixed names are registered.
        register_tmp_file(shared.demo, filename_with_mtime)
        return filename_with_mtime

    if shared.opts.temp_dir != "":
        dir = shared.opts.temp_dir
    elif dir is not None:
        # Guard against the default dir=None: os.makedirs(None) raises
        # TypeError, while NamedTemporaryFile(dir=None) is perfectly valid.
        os.makedirs(dir, exist_ok=True)

    # Carry over only str -> str entries of the image's info dict as PNG text chunks.
    use_metadata = False
    metadata = PngImagePlugin.PngInfo()
    for key, value in pil_image.info.items():
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            use_metadata = True

    # Close the handle before returning so the data is flushed to disk and the
    # descriptor is not leaked; delete=False keeps the file after closing.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) as file_obj:
        pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))

    return file_obj.name
|
2022-11-27 16:52:53 +08:00
|
|
|
|
|
|
|
|
2023-08-09 23:11:13 +08:00
|
|
|
def install_ui_tempdir_override():
    """Replace gradio's PIL-to-temp-file saver with save_pil_to_file.

    The replacement also writes the image's text metadata as PNG info.
    """
    io_component_cls = gradio.components.IOComponent
    io_component_cls.pil_to_temp_file = save_pil_to_file
|
2022-11-27 16:52:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
def on_tmpdir_changed():
    """Create the configured temp directory and register it with the running demo.

    Does nothing when no temp directory is configured (empty string) or when
    ``shared.demo`` is None.
    """
    configured_dir = shared.opts.temp_dir
    if configured_dir == "" or shared.demo is None:
        return

    os.makedirs(configured_dir, exist_ok=True)

    # Registration is file-based, so a placeholder file name inside the
    # directory is used; the file "x" itself is never created here.
    placeholder = os.path.join(configured_dir, "x")
    register_tmp_file(shared.demo, placeholder)
|
2022-11-27 16:52:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
def cleanup_tmpdr():
    """Delete every ``.png`` file under the configured temp directory.

    The whole tree below ``shared.opts.temp_dir`` is walked. Files with any
    other extension and the directories themselves are left in place. Nothing
    happens when the option is empty or does not name an existing directory.
    """
    temp_dir = shared.opts.temp_dir
    if temp_dir == "":
        return
    if not os.path.isdir(temp_dir):
        return

    for root, _dirs, files in os.walk(temp_dir, topdown=False):
        for name in files:
            # Exact, case-sensitive match on the extension.
            if os.path.splitext(name)[1] == ".png":
                os.remove(os.path.join(root, name))
|
2024-02-17 23:38:05 +08:00
|
|
|
|
|
|
|
|
|
|
|
def is_gradio_temp_path(path):
    """
    Tell whether *path* lies inside one of the temp directories used by gradio.

    Checked in order: the configured ``shared.opts.temp_dir`` (if non-empty),
    the ``GRADIO_TEMP_DIR`` environment variable (if set and non-empty), and
    the ``gradio`` folder inside the system temp directory. The comparison is
    purely lexical; paths are not resolved.
    """
    path = Path(path)

    configured_dir = shared.opts.temp_dir
    if configured_dir and path.is_relative_to(configured_dir):
        return True

    env_dir = os.environ.get("GRADIO_TEMP_DIR")
    if env_dir and path.is_relative_to(env_dir):
        return True

    default_dir = Path(tempfile.gettempdir()) / "gradio"
    return path.is_relative_to(default_dir)
|