stable-diffusion-webui/modules/postprocessing.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

109 lines
4.2 KiB
Python
Raw Normal View History

import os
from PIL import Image
2023-01-23 14:24:43 +08:00
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
from modules.shared import opts
2023-01-23 14:24:43 +08:00
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
    """Run the registered postprocessing scripts over one or more images.

    Args:
        extras_mode: Source selector. ``1`` processes every entry of
            ``image_folder``; ``2`` processes every file listed in
            ``input_dir`` that PIL can open; any other value processes the
            single ``image``.
        image: The image processed when ``extras_mode`` is neither 1 nor 2.
        image_folder: Iterable used when ``extras_mode == 1``; each entry is
            either a PIL image or an uploaded-file object exposing ``.name``
            (path on disk) and ``.orig_name``.
        input_dir: Directory whose files are processed when ``extras_mode == 2``.
        output_dir: Output directory, honoured only when ``extras_mode == 2``.
            When it is empty/falsy, the configured default
            (``opts.outdir_samples or opts.outdir_extras_samples``) is used.
        show_extras_results: In directory mode (2), whether processed images
            are returned to the caller; other modes always return them.
        *args: Positional script arguments forwarded unchanged to
            ``scripts.scripts_postproc.run``.
        save_output: Whether each processed image is written to disk.

    Returns:
        A tuple ``(outputs, infotext_html, '')``. ``outputs`` is the list of
        processed images. ``infotext_html`` is the postprocessing infotext of
        the *last* processed image, passed through
        ``ui_common.plaintext_to_html``.

    The shared job state is always closed (``shared.state.end()``) and each
    opened image is always closed, even if processing raises.
    """
    devices.torch_gc()

    shared.state.begin(job="extras")

    outputs = []

    def get_images(extras_mode, image, image_folder, input_dir):
        """Yield ``(PIL image, name)`` pairs for the selected source mode.

        ``name`` is the upload's original stem (or '' for in-memory PIL
        images) in mode 1, the full file path in mode 2, and None otherwise.
        """
        if extras_mode == 1:
            for img in image_folder:
                if isinstance(img, Image.Image):
                    image = img
                    fn = ''
                else:
                    image = Image.open(os.path.abspath(img.name))
                    fn = os.path.splitext(img.orig_name)[0]
                yield image, fn
        elif extras_mode == 2:
            assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
            assert input_dir, 'input directory not selected'

            image_list = shared.listfiles(input_dir)
            for filename in image_list:
                try:
                    image = Image.open(filename)
                except Exception:
                    # Deliberate best-effort: silently skip files PIL cannot open.
                    continue
                yield image, filename
        else:
            assert image, 'image not selected'
            yield image, None

    try:
        # Truthiness (not `!= ''`) so a None output_dir also falls back to the
        # default directory instead of being handed to save_image as path=None.
        if extras_mode == 2 and output_dir:
            outpath = output_dir
        else:
            outpath = opts.outdir_samples or opts.outdir_extras_samples

        infotext = ''

        for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
            image_data: Image.Image

            shared.state.textinfo = name

            try:
                parameters, existing_pnginfo = images.read_info_from_image(image_data)
                if parameters:
                    existing_pnginfo["parameters"] = parameters

                pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))

                scripts.scripts_postproc.run(pp, args)

                if opts.use_original_name_batch and name is not None:
                    basename = os.path.splitext(os.path.basename(name))[0]
                else:
                    basename = ''

                # Render script info as "key: value" pairs. Entries whose key equals
                # their value are emitted bare, and None values are dropped.
                infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])

                if opts.enable_pnginfo:
                    # pp.image.info becomes the same dict object as existing_pnginfo,
                    # so the "postprocessing" key also reaches save_image's existing_info.
                    pp.image.info = existing_pnginfo
                    pp.image.info["postprocessing"] = infotext

                if save_output:
                    images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)

                if extras_mode != 2 or show_extras_results:
                    outputs.append(pp.image)
            finally:
                # Release the source image's file handle even if processing failed.
                # NOTE(review): this also closes caller-supplied PIL images (mode 0, and
                # in-memory images in mode 1), matching the original behaviour.
                image_data.close()
    finally:
        # Always close the job opened by shared.state.begin(), even on error.
        devices.torch_gc()
        shared.state.end()

    return outputs, ui_common.plaintext_to_html(infotext), ''
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
    """Legacy (API) entry point that maps flat extras arguments onto scripts.

    Groups the upscaler, GFPGAN and CodeFormer settings into the per-script
    dictionaries consumed by ``scripts.scripts_postproc.create_args_for_run``,
    then delegates to ``run_postprocessing``. ``upscale_first`` is accepted for
    signature compatibility but is not used in this function.
    """
    upscale_settings = {
        "upscale_mode": resize_mode,
        "upscale_by": upscaling_resize,
        "upscale_to_width": upscaling_resize_w,
        "upscale_to_height": upscaling_resize_h,
        "upscale_crop": upscaling_crop,
        "upscaler_1_name": extras_upscaler_1,
        "upscaler_2_name": extras_upscaler_2,
        "upscaler_2_visibility": extras_upscaler_2_visibility,
    }
    gfpgan_settings = {"gfpgan_visibility": gfpgan_visibility}
    codeformer_settings = {
        "codeformer_visibility": codeformer_visibility,
        "codeformer_weight": codeformer_weight,
    }

    per_script_settings = {
        "Upscale": upscale_settings,
        "GFPGAN": gfpgan_settings,
        "CodeFormer": codeformer_settings,
    }
    script_args = scripts.scripts_postproc.create_args_for_run(per_script_settings)

    return run_postprocessing(
        extras_mode,
        image,
        image_folder,
        input_dir,
        output_dir,
        show_extras_results,
        *script_args,
        save_output=save_output,
    )