mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-21 07:30:02 +08:00
6e2ce4e735
Only computed if the mask weight is not 1.0 to avoid extra memory. Also includes some code cleanup.
833 lines
38 KiB
Python
833 lines
38 KiB
Python
import json
|
|
import math
|
|
import os
|
|
import sys
|
|
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image, ImageFilter, ImageOps
|
|
import random
|
|
import cv2
|
|
from skimage import exposure
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import modules.sd_hijack
|
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
|
|
from modules.sd_hijack import model_hijack
|
|
from modules.shared import opts, cmd_opts, state
|
|
import modules.shared as shared
|
|
import modules.face_restoration
|
|
import modules.images as images
|
|
import modules.styles
|
|
import logging
|
|
|
|
|
|
# Some of those options should not be changed at all because they would break the model,
# so they were removed from the options and are fixed here instead.
# opt_C is used in this file as the leading dimension of the latent noise shape.
opt_C = 4
# opt_f is the divisor applied in this file to pixel height/width to get the latent height/width.
opt_f = 8
|
|
|
|
|
|
def setup_color_correction(image):
    """Build the LAB-space reference that apply_color_correction() matches against.

    A copy of *image* is turned into an array and converted with
    cv2.COLOR_RGB2LAB; the resulting array is returned.
    """
    logging.info("Calibrating color correction.")
    rgb_pixels = np.asarray(image.copy())
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2LAB)
|
|
|
|
|
|
def apply_color_correction(correction, image):
    """Match the histogram of *image* to the LAB reference *correction*.

    The image is converted RGB -> LAB, histogram-matched against *correction*
    along channel axis 2, converted back LAB -> RGB, cast to uint8 and
    returned as a new PIL image.
    """
    logging.info("Applying color correction.")

    lab_pixels = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2LAB)
    matched_lab = exposure.match_histograms(lab_pixels, correction, channel_axis=2)
    rgb_pixels = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB).astype("uint8")

    return Image.fromarray(rgb_pixels)
|
|
|
|
|
|
def apply_overlay(image, paste_loc, index, overlays):
    """Composite ``overlays[index]`` on top of *image* and return an RGB image.

    *image* is returned untouched when *overlays* is None or has no entry at
    *index*. When *paste_loc* ``(x, y, w, h)`` is given, *image* is first
    resized to ``w`` x ``h`` and pasted at ``(x, y)`` onto an empty RGBA canvas
    that has the overlay's size.
    """
    has_overlay = overlays is not None and index < len(overlays)
    if not has_overlay:
        return image

    overlay = overlays[index]

    if paste_loc is not None:
        left, top, target_w, target_h = paste_loc
        canvas = Image.new('RGBA', (overlay.width, overlay.height))
        canvas.paste(images.resize_image(1, image, target_w, target_h), (left, top))
        image = canvas

    composited = image.convert('RGBA')
    composited.alpha_composite(overlay)
    return composited.convert('RGB')
|
|
|
|
def get_correct_sampler(p):
    """Return the sampler list that ``p.sampler_index`` indexes into.

    txt2img and API processing objects use ``sd_samplers.samplers``; img2img
    objects use ``sd_samplers.samplers_for_img2img``. Any other object yields
    None.
    """
    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
        return sd_samplers.samplers

    if isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
        return sd_samplers.samplers_for_img2img

    if isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
        return sd_samplers.samplers

    return None
|
|
|
|
class StableDiffusionProcessing():
    """
    Base container for the parameters of one generation request, plus shared conditioning helpers.

    The first set of parameters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing

    Subclasses override init() and sample(); the base sample() raises NotImplementedError.
    The image-conditioning helpers read ``self.sampler``, which this base class never assigns.
    """
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
        """Store the request parameters.

        None-valued ``styles``, ``extra_generation_params``, ``negative_prompt`` and
        ``override_settings`` are replaced by empty containers / an empty string.
        When ``seed_enable_extras`` is false the variation-seed and seed-resize
        settings are reset to their "disabled" values.
        """
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
        self.prompt: str = prompt
        self.prompt_for_display: str = None
        self.negative_prompt: str = (negative_prompt or "")
        self.styles: list = styles or []
        self.seed: int = seed
        self.subseed: int = subseed
        self.subseed_strength: float = subseed_strength
        self.seed_resize_from_h: int = seed_resize_from_h
        self.seed_resize_from_w: int = seed_resize_from_w
        self.sampler_index: int = sampler_index
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale
        self.width: int = width
        self.height: int = height
        self.restore_faces: bool = restore_faces
        self.tiling: bool = tiling
        self.do_not_save_samples: bool = do_not_save_samples
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.do_not_reload_embeddings = do_not_reload_embeddings
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = denoising_strength
        self.sampler_noise_scheduler_override = None
        # For the sampler settings below any falsy argument (None, but also an explicit 0 / 0.0)
        # falls back to the corresponding global option.
        self.ddim_discretize = ddim_discretize or opts.ddim_discretize
        self.s_churn = s_churn or opts.s_churn
        self.s_tmin = s_tmin or opts.s_tmin
        self.s_tmax = s_tmax or float('inf')  # not representable as a standard ui option
        self.s_noise = s_noise or opts.s_noise
        # Settings listed in shared.restricted_opts are dropped from the per-request overrides.
        self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}

        if not seed_enable_extras:
            self.subseed = -1
            self.subseed_strength = 0
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

        self.scripts = None
        self.script_args = None
        self.all_prompts = None
        self.all_seeds = None
        self.all_subseeds = None

    def txt2img_image_conditioning(self, x, width=None, height=None):
        """Build the image-conditioning tensor for a pure text-to-image pass.

        Args:
            x: tensor whose batch size (``x.shape[0]``), device and dtype are reused.
            width: pixel width of the dummy image; falls back to ``self.width`` when falsy.
            height: pixel height of the dummy image; falls back to ``self.height`` when falsy.

        Returns:
            A ``(batch, 5, 1, 1)`` zero tensor when the sampler's conditioning key is neither
            'hybrid' nor 'concat'; otherwise the encoded all-zero image with a channel of ones
            prepended, cast to ``x.dtype``.
        """
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            # Still takes up a bit of memory, but no encoder call.
            # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
            return x.new_zeros(x.shape[0], 5, 1, 1)

        height = height or self.height
        width = width or self.width

        # The "masked-image" in this case will just be all zeros since the entire image is masked.
        image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
        image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))

        # Add the fake full 1s mask to the first dimension.
        # (The pad spec adds one entry in front of the third-from-last dimension.)
        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
        image_conditioning = image_conditioning.to(x.dtype)

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
        """Build the image-conditioning tensor (mask + encoded masked image) for an image-based pass.

        Args:
            source_image: image tensor that gets masked and encoded.
            latent_image: latent tensor; its batch size and last two dimensions size the result.
            image_mask: optional mask, either a tensor (used as given) or an object offering
                ``convert("L")`` (presumably a PIL image), which is scaled to 0..1 and rounded.
                When None, an all-ones mask is used.

        Returns:
            A ``(batch, 5, 1, 1)`` zero tensor when the sampler's conditioning key is neither
            'hybrid' nor 'concat'; otherwise the resized mask concatenated with the encoded
            masked image along dim 1, moved to ``shared.device`` with the model's dtype.
        """
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

        # Handle the different mask inputs
        if image_mask is not None:
            if torch.is_tensor(image_mask):
                conditioning_mask = image_mask
            else:
                conditioning_mask = np.array(image_mask.convert("L"))
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # A mask built from numpy lives on the CPU as float32, while source_image may be on
        # another device and/or use another dtype; align them before combining the two.
        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
            getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
        )

        # Encode the new masked image using first stage of network.
        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))

        # Create the concatenated conditioning tensor to be fed to `c_concat`
        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)

        return image_conditioning

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Hook for per-run preparation; the base implementation does nothing."""
        pass

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        """Produce latent samples for one batch; must be implemented by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class Processed:
    """Outcome of a processing run: the produced images plus a snapshot of the settings used."""

    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
        """Copy the relevant settings from *p* and the global state.

        Prompt, negative prompt, seed and subseed are reduced to a single value:
        a plain list contributes its first element, and a None seed/subseed
        becomes -1. The ``all_*`` lists and ``infotexts`` default to one-element
        lists built from those single values / *info*.
        """
        def head_if_list(value):
            # Exact type check on purpose: only plain lists are unwrapped.
            if type(value) == list:
                return value[0]
            return value

        def as_single_int(value):
            if value is None:
                return -1
            return int(head_if_list(value))

        self.images = images_list
        self.prompt = head_if_list(p.prompt)
        self.negative_prompt = head_if_list(p.negative_prompt)
        self.seed = as_single_int(seed)
        self.subseed = as_single_int(subseed)
        self.subseed_strength = p.subseed_strength
        self.info = info
        self.width = p.width
        self.height = p.height
        self.sampler_index = p.sampler_index
        self.sampler = sd_samplers.samplers[p.sampler_index].name
        self.cfg_scale = p.cfg_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces
        if p.restore_faces:
            self.face_restoration_model = opts.face_restoration_model
        else:
            self.face_restoration_model = None
        self.sd_model_hash = shared.sd_model.sd_model_hash
        self.seed_resize_from_w = p.seed_resize_from_w
        self.seed_resize_from_h = p.seed_resize_from_h
        self.denoising_strength = getattr(p, 'denoising_strength', None)
        self.extra_generation_params = p.extra_generation_params
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
        self.s_churn = p.s_churn
        self.s_tmin = p.s_tmin
        self.s_tmax = p.s_tmax
        self.s_noise = p.s_noise
        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override

        # Falsy (None or empty) lists fall back to the single-value defaults.
        self.all_prompts = all_prompts or [self.prompt]
        self.all_seeds = all_seeds or [self.seed]
        self.all_subseeds = all_subseeds or [self.subseed]
        self.infotexts = infotexts or [info]

    def js(self):
        """Return a JSON string with the serialisable subset of the stored settings."""
        # Each JSON key is identical to the attribute it is read from;
        # the tuple order is the key order of the output.
        exported_fields = (
            "prompt",
            "all_prompts",
            "negative_prompt",
            "seed",
            "all_seeds",
            "subseed",
            "all_subseeds",
            "subseed_strength",
            "width",
            "height",
            "sampler_index",
            "sampler",
            "cfg_scale",
            "steps",
            "batch_size",
            "restore_faces",
            "face_restoration_model",
            "sd_model_hash",
            "seed_resize_from_w",
            "seed_resize_from_h",
            "denoising_strength",
            "extra_generation_params",
            "index_of_first_image",
            "infotexts",
            "styles",
            "job_timestamp",
            "clip_skip",
        )
        payload = {field: getattr(self, field) for field in exported_fields}
        return json.dumps(payload)

    def infotext(self, p: StableDiffusionProcessing, index):
        """Return the infotext for image number *index* of the run described by *p*."""
        # index counts images across batches: split it into batch number and slot in that batch.
        batch_number, slot = divmod(index, self.batch_size)
        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=slot, iteration=batch_number)
|
|
|
|
|
|
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
|
def slerp(val, low, high):
    """Spherically interpolate between two tensors.

    Args:
        val: interpolation weight; 0 yields ``low`` and 1 yields ``high``.
        low: start tensor; directions are compared along dimension 1.
        high: end tensor with the same shape as ``low``.

    Returns:
        The interpolated tensor. When the inputs are (on average) almost
        parallel, plain linear interpolation is used instead, because the
        spherical formula divides by ``sin(omega)``, which approaches zero.
    """
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm*high_norm).sum(1)

    if dot.mean() > 0.9995:
        # Weight the fallback the same way as the spherical formula below
        # (val=0 -> low, val=1 -> high); the previous form had the weights swapped.
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    return res
|
|
|
|
|
|
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
    """Create one seeded noise tensor per seed and return them stacked on shared.device.

    Args:
        shape: shape of a single noise tensor (indexed as [0], [1], [2] below).
        seeds: one seed per tensor to generate.
        subseeds: optional variation seeds; a missing entry is treated as 0.
        subseed_strength: weight handed to slerp() when mixing seed noise with subseed noise.
        seed_resize_from_h, seed_resize_from_w: when both are positive, noise is generated for
            that size (divided by 8) and centred/cropped into a tensor of `shape`.
        p: optional processing object; when it has a sampler and batch seeds or the eta noise
            seed delta are enabled, extra per-step noises are stored in p.sampler.sampler_noises.

    Returns:
        The stacked noise tensors, moved to shared.device.
    """
    xs = []

    # if we have multiple seeds, this means we are working with batch size>1; this then
    # enables the generation of additional tensors with noise that the sampler will use during its processing.
    # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
    # produce the same images as with two batches [100], [101].
    if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
        sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
    else:
        sampler_noises = None

    for i, seed in enumerate(seeds):
        # Without seed resizing noise_shape is the very same object as shape.
        noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

        subnoise = None
        if subseeds is not None:
            subseed = 0 if i >= len(subseeds) else subseeds[i]

            subnoise = devices.randn(subseed, noise_shape)

        # randn results depend on device; gpu and cpu get different results for same seed;
        # the way I see it, it's better to do this on CPU, so that everyone gets same result;
        # but the original script had it like this, so I do not dare change it for now because
        # it will break everyone's seeds.
        noise = devices.randn(seed, noise_shape)

        if subnoise is not None:
            noise = slerp(subseed_strength, noise, subnoise)

        if noise_shape != shape:
            # Seed resizing: start from noise of the requested shape and overwrite its centre
            # with the (possibly cropped) noise generated for the "resize from" size.
            x = devices.randn(seed, shape)
            dx = (shape[2] - noise_shape[2]) // 2
            dy = (shape[1] - noise_shape[1]) // 2
            # Size of the copied region: the smaller of the two extents per axis.
            w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
            h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
            # Offset inside the target tensor.
            tx = 0 if dx < 0 else dx
            ty = 0 if dy < 0 else dy
            # From here on dx/dy are the offsets inside the source noise.
            dx = max(-dx, 0)
            dy = max(-dy, 0)

            x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
            noise = x

        if sampler_noises is not None:
            cnt = p.sampler.number_of_needed_noises(p)

            if opts.eta_noise_seed_delta > 0:
                torch.manual_seed(seed + opts.eta_noise_seed_delta)

            # These draws continue from the generator state left by the calls above,
            # so the order of the random calls in this loop must not change.
            for j in range(cnt):
                sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

        xs.append(noise)

    if sampler_noises is not None:
        p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]

    x = torch.stack(xs).to(shared.device)
    return x
|
|
|
|
|
|
def decode_first_stage(model, x):
    """Decode the latent *x* with ``model.decode_first_stage`` and return the result.

    Autocast is disabled for the call when *x* already has the VAE dtype
    (``devices.dtype_vae``).
    """
    already_vae_dtype = x.dtype == devices.dtype_vae
    with devices.autocast(disable=already_vae_dtype):
        decoded = model.decode_first_stage(x)

    return decoded
|
|
|
|
|
|
def get_fixed_seed(seed):
    """Return *seed* unchanged, or a fresh random seed when it is None, '' or -1."""
    wants_random = seed is None or seed == '' or seed == -1
    if not wants_random:
        return seed

    return int(random.randrange(4294967294))
|
|
|
|
|
|
def fix_seed(p):
    """Replace unset ``p.seed`` / ``p.subseed`` values in place via get_fixed_seed()."""
    # seed is handled before subseed so the random draws happen in that order.
    for attr_name in ("seed", "subseed"):
        setattr(p, attr_name, get_fixed_seed(getattr(p, attr_name)))
|
|
|
|
|
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    """Compose the infotext (prompt, negative prompt and parameter line) for one image.

    The image is identified by *iteration* (batch number) and *position_in_batch*.
    Parameters whose value is None are left out; entries of
    ``p.extra_generation_params`` are appended to (or override) the standard ones.
    """
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
    single_image_batch = p.batch_size < 2
    no_variation = p.subseed_strength == 0

    generation_params = {
        "Steps": p.steps,
        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
        "Batch size": (None if single_image_batch else p.batch_size),
        "Batch pos": (None if single_image_batch else position_in_batch),
        "Variation seed": (None if no_variation else all_subseeds[index]),
        "Variation seed strength": (None if no_variation else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
    }

    generation_params.update(p.extra_generation_params)

    rendered = []
    for name, value in generation_params.items():
        if value is None:
            continue
        if name == value:
            # A parameter whose value equals its own name is written as a bare flag.
            rendered.append(name)
        else:
            rendered.append(f'{name}: {generation_parameters_copypaste.quote(value)}')
    generation_params_text = ", ".join(rendered)

    if p.negative_prompt:
        negative_prompt_text = "\nNegative prompt: " + p.negative_prompt
    else:
        negative_prompt_text = ""

    return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
|
|
|
|
|
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Run process_images_inner() with ``p.override_settings`` temporarily applied to ``opts.data``.

    The previous values of the overridden options are written back afterwards,
    whether the inner call returns or raises.
    """
    overrides = p.override_settings
    stored_opts = {name: opts.data[name] for name in overrides.keys()}

    try:
        # we don't call onchange for simplicity which makes changing model, hypernet impossible
        for name, value in overrides.items():
            opts.data[name] = value

        return process_images_inner(p)

    finally:
        for name, previous in stored_opts.items():
            opts.data[name] = previous
|
|
|
|
|
|
def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch

    Expands prompts and seeds for every image, calls p.init() once, then for each of the
    p.n_iter batches computes the conditionings, calls p.sample(), decodes the latents and
    post-processes / saves each image. Finally an image grid is optionally built and everything
    is returned as a Processed object.
    """

    if type(p.prompt) == list:
        assert(len(p.prompt) > 0)
    else:
        assert p.prompt is not None

    # Record the parameters of this run in params.txt inside the script directory.
    with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
        processed = Processed(p, [], p.seed, "")
        file.write(processed.infotext(p, 0))

    devices.torch_gc()

    seed = get_fixed_seed(p.seed)
    subseed = get_fixed_seed(p.subseed)

    modules.sd_hijack.model_hijack.apply_circular(p.tiling)
    modules.sd_hijack.model_hijack.clear_comments()

    # Used as an ordered set: the keys are the comments collected from model_hijack.
    comments = {}

    shared.prompt_styles.apply_styles(p)

    # A list prompt supplies one prompt per image; a single prompt is repeated for every image.
    if type(p.prompt) == list:
        p.all_prompts = p.prompt
    else:
        p.all_prompts = p.batch_size * p.n_iter * [p.prompt]

    # Consecutive seeds per image, except that the main seed stays constant when a
    # variation strength is set (then only the subseed varies).
    if type(seed) == list:
        p.all_seeds = seed
    else:
        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]

    if type(subseed) == list:
        p.all_subseeds = subseed
    else:
        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

    def infotext(iteration=0, position_in_batch=0):
        # Reads `comments` at call time, so comments gathered during the run are visible here.
        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)

    if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
        model_hijack.embedding_db.load_textual_inversion_embeddings()

    if p.scripts is not None:
        p.scripts.run_alwayson_scripts(p)

    infotexts = []
    output_images = []

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)

        # -1 means no job count has been set yet.
        if state.job_count == -1:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
            # A skip request is only reset here; the batch itself still runs.
            if state.skipped:
                state.skipped = False

            if state.interrupted:
                break

            # Slice out the prompts and seeds that belong to batch n.
            prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
            subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

            if (len(prompts) == 0):
                break

            with devices.autocast():
                uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
                c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)

            if len(model_hijack.comments) > 0:
                for comment in model_hijack.comments:
                    comments[comment] = 1

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)

            samples_ddim = samples_ddim.to(devices.dtype_vae)
            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
            # Map the decoded values from [-1, 1] to [0, 1].
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
                # Move axis 0 to the end and scale to 0..255 so PIL can build an image from it.
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                    # NOTE(review): this checks opts.save, whereas the final save below checks
                    # opts.samples_save - confirm which option is intended here.
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
                    # NOTE(review): same opts.save vs. opts.samples_save question as above.
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)

                if opts.samples_save and not p.do_not_save_samples:
                    images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)

                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None

        index_of_first_image = 0
        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
            grid = images.image_grid(output_images, p.batch_size)

            if opts.return_grid:
                # The grid becomes the first returned image, so the real images start at index 1.
                text = infotext()
                infotexts.insert(0, text)
                if opts.enable_pnginfo:
                    grid.info["parameters"] = text
                output_images.insert(0, grid)
                index_of_first_image = 1

            if opts.grid_save:
                images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    devices.torch_gc()
    # The collected comments are appended to the info text, each after a blank line.
    return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
|
|
|
|
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """Text-to-image processing, optionally with a two-pass mode (`enable_hr`).

    In two-pass mode a first pass is sampled at a smaller "first phase" size, brought to the
    requested width/height and then refined with an img2img pass.
    """
    # Class-level default; sample() assigns the actual sampler on the instance.
    sampler = None

    def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
        """Store the two-pass settings; all other keyword arguments go to the base class.

        A firstphase_width or firstphase_height of 0 makes init() choose the first-pass size.
        """
        super().__init__(**kwargs)
        self.enable_hr = enable_hr
        self.denoising_strength = denoising_strength
        self.firstphase_width = firstphase_width
        self.firstphase_height = firstphase_height
        # Number of latent rows/columns to cut from the first-pass result; set by init().
        self.truncate_x = 0
        self.truncate_y = 0

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare the two-pass mode: job count, first-pass size and truncation amounts.

        Does nothing when enable_hr is false.
        """
        if self.enable_hr:
            # Each batch runs two sampling passes, so the job count doubles.
            if state.job_count == -1:
                state.job_count = self.n_iter * 2
            else:
                state.job_count = state.job_count * 2

            # Recorded before the automatic size below is computed, i.e. with the values as passed in.
            self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"

            if self.firstphase_width == 0 or self.firstphase_height == 0:
                # Automatic size: scale the target to about 512*512 pixels, rounded up to multiples of 64.
                desired_pixel_count = 512 * 512
                actual_pixel_count = self.width * self.height
                scale = math.sqrt(desired_pixel_count / actual_pixel_count)
                self.firstphase_width = math.ceil(scale * self.width / 64) * 64
                self.firstphase_height = math.ceil(scale * self.height / 64) * 64
                firstphase_width_truncated = int(scale * self.width)
                firstphase_height_truncated = int(scale * self.height)

            else:

                width_ratio = self.width / self.firstphase_width
                height_ratio = self.height / self.firstphase_height

                # Keep the side with the larger ratio and shrink the other one so that the
                # truncated first-pass size has the target's aspect ratio.
                if width_ratio > height_ratio:
                    firstphase_width_truncated = self.firstphase_width
                    firstphase_height_truncated = self.firstphase_width * self.height / self.width
                else:
                    firstphase_width_truncated = self.firstphase_height * self.width / self.height
                    firstphase_height_truncated = self.firstphase_height

            # Excess of the first-pass size over its truncated size, converted with opt_f.
            self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
            self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        """Sample latents for one batch and return them.

        Without enable_hr this is a single sampler run at the target size. With enable_hr the
        first pass runs at the first-phase size, its result is truncated and brought to the
        target size, and a second img2img pass refines it.
        """
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))

        # Cut half of the truncation amount from each side of the last two dimensions.
        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

        if opts.use_scale_latent_for_hires_fix:
            # Resize directly in latent space.
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")

            # Avoid making the inpainting conditioning unless necessary as
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)

        else:
            # Resize in image space: decode, resize every image, then encode again.
            decoded_samples = decode_first_stage(self.sd_model, samples)
            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

            batch_images = []
            for i, x_sample in enumerate(lowres_samples):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)
                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)

            decoded_samples = torch.from_numpy(np.array(batch_images))
            decoded_samples = decoded_samples.to(shared.device)
            # Back from [0, 1] to [-1, 1] before encoding.
            decoded_samples = 2. * decoded_samples - 1.

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()

        # A fresh sampler instance is created for the second pass.
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)

        noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        # GC now before running the next img2img to prevent running out of memory
        x = None
        devices.torch_gc()

        samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)

        return samples
|
|
|
|
|
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """Image-to-image processing, including masked (inpainting) runs."""
    # Class-level default; init() assigns the actual sampler on the instance.
    sampler = None

    def __init__(self, init_images: list=None, resize_mode: int=0, denoising_strength: float=0.75, mask: Any=None, mask_blur: int=4, inpainting_fill: int=0, inpaint_full_res: bool=True, inpaint_full_res_padding: int=0, inpainting_mask_invert: int=0, **kwargs):
        """Store the img2img / inpainting settings; other keyword arguments go to the base class.

        `mask` is kept as self.image_mask. self.mask and self.nmask are the latent-space masks
        that init() derives from it.
        """
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None
        self.image_mask = mask
        #self.image_unblurred_mask = None
        self.latent_mask = None
        self.mask_for_overlay = None
        self.mask_blur = mask_blur
        self.inpainting_fill = inpainting_fill
        self.inpaint_full_res = inpaint_full_res
        self.inpaint_full_res_padding = inpaint_full_res_padding
        self.inpainting_mask_invert = inpainting_mask_invert
        self.mask = None
        self.nmask = None
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        """Prepare everything sample() needs: sampler, masks, overlays, init latent, conditioning.

        Raises:
            RuntimeError: when more init images than `batch_size` are supplied.
        """
        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
        crop_region = None

        if self.image_mask is not None:
            self.image_mask = self.image_mask.convert('L')

            if self.inpainting_mask_invert:
                self.image_mask = ImageOps.invert(self.image_mask)

            #self.image_unblurred_mask = self.image_mask

            if self.mask_blur > 0:
                self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))

            if self.inpaint_full_res:
                # Work only on the region around the mask; paste_to remembers where the
                # result belongs in the full image (x, y, width, height).
                self.mask_for_overlay = self.image_mask
                mask = self.image_mask.convert('L')
                crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
                crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
                x1, y1, x2, y2 = crop_region

                mask = mask.crop(crop_region)
                self.image_mask = images.resize_image(2, mask, self.width, self.height)
                self.paste_to = (x1, y1, x2-x1, y2-y1)
            else:
                self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
                # The overlay mask is the resized mask with its values doubled and clipped to 255.
                np_mask = np.array(self.image_mask)
                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
                self.mask_for_overlay = Image.fromarray(np_mask)

            self.overlay_images = []

        latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask

        # Only collect color-correction references when the option is on and none were supplied.
        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
        if add_color_corrections:
            self.color_corrections = []
        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")

            if crop_region is None:
                image = images.resize_image(self.resize_mode, image, self.width, self.height)

            if self.image_mask is not None:
                # Overlay: the source image pasted through the inverted overlay mask; it is
                # composited over the generated image later by apply_overlay().
                image_masked = Image.new('RGBa', (image.width, image.height))
                image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))

                self.overlay_images.append(image_masked.convert('RGBA'))

            if crop_region is not None:
                image = image.crop(crop_region)
                image = images.resize_image(2, image, self.width, self.height)

            if self.image_mask is not None:
                # inpainting_fill == 1 leaves the image as it is; every other value runs masking.fill().
                if self.inpainting_fill != 1:
                    image = masking.fill(image, latent_mask)

            if add_color_corrections:
                self.color_corrections.append(setup_color_correction(image))

            # To float values in [0, 1] with the last axis moved to the front.
            image = np.array(image).astype(np.float32) / 255.0
            image = np.moveaxis(image, 2, 0)

            imgs.append(image)

        if len(imgs) == 1:
            # A single init image is repeated for the whole batch, along with its overlay
            # and color-correction reference.
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

        elif len(imgs) <= self.batch_size:
            # Several init images: the batch size is reduced to the number of images.
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        # From [0, 1] to [-1, 1] before encoding.
        image = 2. * image - 1.
        image = image.to(shared.device)

        self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

        if self.image_mask is not None:
            # Build a binary mask at latent resolution and repeat it 4 times along a new leading axis.
            init_mask = latent_mask
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

            # self.mask is the inverse of the rounded latent mask; self.nmask is the rounded mask itself.
            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)

            # this needs to be fixed to be done in sample() using actual seeds for batches
            if self.inpainting_fill == 2:
                # Replace the nmask part of the latent with seeded noise.
                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
            elif self.inpainting_fill == 3:
                # Zero the nmask part of the latent.
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)


    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
        """Run the img2img sampler on the prepared init latent and return the samples.

        When a mask was prepared, the result is blended with the init latent:
        samples * nmask + init_latent * mask.
        """
        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples