2022-09-03 17:08:45 +08:00
|
|
|
import json
|
|
|
|
import math
|
|
|
|
import os
|
|
|
|
import sys
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
from PIL import Image, ImageFilter, ImageOps
|
|
|
|
import random
|
2022-09-13 17:51:57 +08:00
|
|
|
import cv2
|
|
|
|
from skimage import exposure
|
2022-10-18 03:10:36 +08:00
|
|
|
from typing import Any, Dict, List, Optional
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-05 08:25:37 +08:00
|
|
|
import modules.sd_hijack
|
2022-10-07 00:41:37 +08:00
|
|
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
|
2022-09-03 17:08:45 +08:00
|
|
|
from modules.sd_hijack import model_hijack
|
|
|
|
from modules.shared import opts, cmd_opts, state
|
|
|
|
import modules.shared as shared
|
2022-09-07 17:32:28 +08:00
|
|
|
import modules.face_restoration
|
2022-09-03 17:08:45 +08:00
|
|
|
import modules.images as images
|
2022-09-10 04:16:02 +08:00
|
|
|
import modules.styles
|
2022-09-23 08:57:42 +08:00
|
|
|
import logging
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-13 17:51:57 +08:00
|
|
|
|
2022-09-03 17:08:45 +08:00
|
|
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
|
|
|
opt_C = 4
|
|
|
|
opt_f = 8
|
|
|
|
|
|
|
|
|
2022-09-13 17:51:57 +08:00
|
|
|
def setup_color_correction(image):
|
2022-09-23 08:57:42 +08:00
|
|
|
logging.info("Calibrating color correction.")
|
2022-09-13 17:51:57 +08:00
|
|
|
correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
|
|
|
|
return correction_target
|
|
|
|
|
|
|
|
|
|
|
|
def apply_color_correction(correction, image):
|
2022-09-23 08:57:42 +08:00
|
|
|
logging.info("Applying color correction.")
|
2022-09-13 17:51:57 +08:00
|
|
|
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
|
|
|
|
cv2.cvtColor(
|
|
|
|
np.asarray(image),
|
|
|
|
cv2.COLOR_RGB2LAB
|
|
|
|
),
|
|
|
|
correction,
|
|
|
|
channel_axis=2
|
|
|
|
), cv2.COLOR_LAB2RGB).astype("uint8"))
|
|
|
|
|
|
|
|
return image
|
|
|
|
|
|
|
|
|
2022-10-09 08:13:13 +08:00
|
|
|
def get_correct_sampler(p):
|
|
|
|
if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
|
|
|
|
return sd_samplers.samplers
|
|
|
|
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
|
|
|
|
return sd_samplers.samplers_for_img2img
|
2022-10-18 03:10:36 +08:00
|
|
|
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
|
|
|
|
return sd_samplers.samplers
|
2022-10-09 08:13:13 +08:00
|
|
|
|
2022-10-18 03:10:36 +08:00
|
|
|
class StableDiffusionProcessing():
|
|
|
|
"""
|
|
|
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
|
|
|
|
|
|
|
"""
|
|
|
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
|
2022-09-03 17:08:45 +08:00
|
|
|
self.sd_model = sd_model
|
|
|
|
self.outpath_samples: str = outpath_samples
|
|
|
|
self.outpath_grids: str = outpath_grids
|
|
|
|
self.prompt: str = prompt
|
2022-09-03 22:21:15 +08:00
|
|
|
self.prompt_for_display: str = None
|
2022-09-03 17:08:45 +08:00
|
|
|
self.negative_prompt: str = (negative_prompt or "")
|
2022-10-02 20:03:39 +08:00
|
|
|
self.styles: list = styles or []
|
2022-09-03 17:08:45 +08:00
|
|
|
self.seed: int = seed
|
2022-09-09 22:54:04 +08:00
|
|
|
self.subseed: int = subseed
|
|
|
|
self.subseed_strength: float = subseed_strength
|
|
|
|
self.seed_resize_from_h: int = seed_resize_from_h
|
|
|
|
self.seed_resize_from_w: int = seed_resize_from_w
|
2022-09-03 17:08:45 +08:00
|
|
|
self.sampler_index: int = sampler_index
|
|
|
|
self.batch_size: int = batch_size
|
|
|
|
self.n_iter: int = n_iter
|
|
|
|
self.steps: int = steps
|
|
|
|
self.cfg_scale: float = cfg_scale
|
|
|
|
self.width: int = width
|
|
|
|
self.height: int = height
|
2022-09-07 17:32:28 +08:00
|
|
|
self.restore_faces: bool = restore_faces
|
2022-09-05 08:25:37 +08:00
|
|
|
self.tiling: bool = tiling
|
2022-09-03 17:08:45 +08:00
|
|
|
self.do_not_save_samples: bool = do_not_save_samples
|
|
|
|
self.do_not_save_grid: bool = do_not_save_grid
|
2022-09-21 00:07:09 +08:00
|
|
|
self.extra_generation_params: dict = extra_generation_params or {}
|
2022-09-03 17:08:45 +08:00
|
|
|
self.overlay_images = overlay_images
|
2022-09-28 23:09:06 +08:00
|
|
|
self.eta = eta
|
2022-10-16 13:51:24 +08:00
|
|
|
self.do_not_reload_embeddings = do_not_reload_embeddings
|
2022-09-03 17:08:45 +08:00
|
|
|
self.paste_to = None
|
2022-09-13 17:51:57 +08:00
|
|
|
self.color_corrections = None
|
2022-09-19 21:42:56 +08:00
|
|
|
self.denoising_strength: float = 0
|
2022-09-30 08:44:38 +08:00
|
|
|
self.sampler_noise_scheduler_override = None
|
2022-09-26 22:40:47 +08:00
|
|
|
self.ddim_discretize = opts.ddim_discretize
|
2022-10-18 03:10:36 +08:00
|
|
|
self.s_churn = s_churn or opts.s_churn
|
|
|
|
self.s_tmin = s_tmin or opts.s_tmin
|
|
|
|
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
|
|
|
self.s_noise = s_noise or opts.s_noise
|
2022-10-04 23:49:51 +08:00
|
|
|
|
2022-09-21 18:34:10 +08:00
|
|
|
if not seed_enable_extras:
|
|
|
|
self.subseed = -1
|
|
|
|
self.subseed_strength = 0
|
|
|
|
self.seed_resize_from_h = 0
|
|
|
|
self.seed_resize_from_w = 0
|
|
|
|
|
2022-10-18 03:10:36 +08:00
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
2022-09-03 17:08:45 +08:00
|
|
|
pass
|
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
2022-09-03 17:08:45 +08:00
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
|
|
class Processed:
|
2022-09-28 22:05:23 +08:00
|
|
|
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
|
2022-09-03 17:08:45 +08:00
|
|
|
self.images = images_list
|
|
|
|
self.prompt = p.prompt
|
2022-09-13 00:57:31 +08:00
|
|
|
self.negative_prompt = p.negative_prompt
|
2022-09-03 17:08:45 +08:00
|
|
|
self.seed = seed
|
2022-09-17 03:20:56 +08:00
|
|
|
self.subseed = subseed
|
|
|
|
self.subseed_strength = p.subseed_strength
|
2022-09-03 17:08:45 +08:00
|
|
|
self.info = info
|
|
|
|
self.width = p.width
|
|
|
|
self.height = p.height
|
2022-09-19 14:02:10 +08:00
|
|
|
self.sampler_index = p.sampler_index
|
2022-10-06 17:08:48 +08:00
|
|
|
self.sampler = sd_samplers.samplers[p.sampler_index].name
|
2022-09-03 17:08:45 +08:00
|
|
|
self.cfg_scale = p.cfg_scale
|
|
|
|
self.steps = p.steps
|
2022-09-19 14:02:10 +08:00
|
|
|
self.batch_size = p.batch_size
|
|
|
|
self.restore_faces = p.restore_faces
|
|
|
|
self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None
|
|
|
|
self.sd_model_hash = shared.sd_model.sd_model_hash
|
|
|
|
self.seed_resize_from_w = p.seed_resize_from_w
|
|
|
|
self.seed_resize_from_h = p.seed_resize_from_h
|
|
|
|
self.denoising_strength = getattr(p, 'denoising_strength', None)
|
|
|
|
self.extra_generation_params = p.extra_generation_params
|
|
|
|
self.index_of_first_image = index_of_first_image
|
2022-10-05 01:13:09 +08:00
|
|
|
self.styles = p.styles
|
2022-10-05 01:17:15 +08:00
|
|
|
self.job_timestamp = state.job_timestamp
|
2022-10-09 05:28:42 +08:00
|
|
|
self.clip_skip = opts.CLIP_stop_at_last_layers
|
2022-09-19 14:02:10 +08:00
|
|
|
|
2022-09-28 10:11:03 +08:00
|
|
|
self.eta = p.eta
|
2022-09-26 22:40:47 +08:00
|
|
|
self.ddim_discretize = p.ddim_discretize
|
|
|
|
self.s_churn = p.s_churn
|
|
|
|
self.s_tmin = p.s_tmin
|
|
|
|
self.s_tmax = p.s_tmax
|
|
|
|
self.s_noise = p.s_noise
|
2022-09-30 08:44:38 +08:00
|
|
|
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
2022-09-19 14:02:10 +08:00
|
|
|
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
|
|
|
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
2022-10-14 11:05:07 +08:00
|
|
|
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
|
2022-09-19 14:02:10 +08:00
|
|
|
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
|
|
|
|
|
|
|
|
self.all_prompts = all_prompts or [self.prompt]
|
|
|
|
self.all_seeds = all_seeds or [self.seed]
|
|
|
|
self.all_subseeds = all_subseeds or [self.subseed]
|
2022-09-28 22:05:23 +08:00
|
|
|
self.infotexts = infotexts or [info]
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
def js(self):
|
|
|
|
obj = {
|
2022-09-19 14:02:10 +08:00
|
|
|
"prompt": self.prompt,
|
|
|
|
"all_prompts": self.all_prompts,
|
|
|
|
"negative_prompt": self.negative_prompt,
|
|
|
|
"seed": self.seed,
|
|
|
|
"all_seeds": self.all_seeds,
|
|
|
|
"subseed": self.subseed,
|
|
|
|
"all_subseeds": self.all_subseeds,
|
2022-09-17 03:20:56 +08:00
|
|
|
"subseed_strength": self.subseed_strength,
|
2022-09-03 17:08:45 +08:00
|
|
|
"width": self.width,
|
|
|
|
"height": self.height,
|
2022-09-19 14:02:10 +08:00
|
|
|
"sampler_index": self.sampler_index,
|
2022-09-03 17:08:45 +08:00
|
|
|
"sampler": self.sampler,
|
|
|
|
"cfg_scale": self.cfg_scale,
|
|
|
|
"steps": self.steps,
|
2022-09-19 14:02:10 +08:00
|
|
|
"batch_size": self.batch_size,
|
|
|
|
"restore_faces": self.restore_faces,
|
|
|
|
"face_restoration_model": self.face_restoration_model,
|
|
|
|
"sd_model_hash": self.sd_model_hash,
|
|
|
|
"seed_resize_from_w": self.seed_resize_from_w,
|
|
|
|
"seed_resize_from_h": self.seed_resize_from_h,
|
|
|
|
"denoising_strength": self.denoising_strength,
|
|
|
|
"extra_generation_params": self.extra_generation_params,
|
|
|
|
"index_of_first_image": self.index_of_first_image,
|
2022-09-28 22:05:23 +08:00
|
|
|
"infotexts": self.infotexts,
|
2022-10-05 01:13:09 +08:00
|
|
|
"styles": self.styles,
|
2022-10-05 01:17:15 +08:00
|
|
|
"job_timestamp": self.job_timestamp,
|
2022-10-09 03:21:15 +08:00
|
|
|
"clip_skip": self.clip_skip,
|
2022-09-03 17:08:45 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
return json.dumps(obj)
|
|
|
|
|
2022-09-19 14:02:10 +08:00
|
|
|
def infotext(self, p: StableDiffusionProcessing, index):
|
|
|
|
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
|
|
|
|
|
|
|
|
2022-09-09 22:54:04 +08:00
|
|
|
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
|
|
|
def slerp(val, low, high):
|
|
|
|
low_norm = low/torch.norm(low, dim=1, keepdim=True)
|
|
|
|
high_norm = high/torch.norm(high, dim=1, keepdim=True)
|
2022-09-19 01:44:57 +08:00
|
|
|
dot = (low_norm*high_norm).sum(1)
|
|
|
|
|
|
|
|
if dot.mean() > 0.9995:
|
|
|
|
return low * val + high * (1 - val)
|
|
|
|
|
|
|
|
omega = torch.acos(dot)
|
2022-09-09 22:54:04 +08:00
|
|
|
so = torch.sin(omega)
|
|
|
|
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
|
|
|
|
return res
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-09 22:54:04 +08:00
|
|
|
|
2022-09-14 02:49:58 +08:00
|
|
|
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
|
2022-09-03 17:08:45 +08:00
|
|
|
xs = []
|
2022-09-14 02:49:58 +08:00
|
|
|
|
2022-09-16 15:04:07 +08:00
|
|
|
# if we have multiple seeds, this means we are working with batch size>1; this then
|
|
|
|
# enables the generation of additional tensors with noise that the sampler will use during its processing.
|
2022-09-18 13:09:52 +08:00
|
|
|
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
|
2022-09-16 15:04:07 +08:00
|
|
|
# produce the same images as with two batches [100], [101].
|
2022-10-11 01:32:37 +08:00
|
|
|
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0):
|
2022-09-14 02:49:58 +08:00
|
|
|
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
|
|
|
|
else:
|
|
|
|
sampler_noises = None
|
|
|
|
|
2022-09-09 22:54:04 +08:00
|
|
|
for i, seed in enumerate(seeds):
|
|
|
|
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
|
|
|
|
|
|
|
|
subnoise = None
|
|
|
|
if subseeds is not None:
|
|
|
|
subseed = 0 if i >= len(subseeds) else subseeds[i]
|
2022-09-12 22:32:44 +08:00
|
|
|
|
2022-09-13 01:09:32 +08:00
|
|
|
subnoise = devices.randn(subseed, noise_shape)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
# randn results depend on device; gpu and cpu get different results for same seed;
|
|
|
|
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
|
2022-09-09 22:54:04 +08:00
|
|
|
# but the original script had it like this, so I do not dare change it for now because
|
2022-09-03 17:08:45 +08:00
|
|
|
# it will break everyone's seeds.
|
2022-09-13 01:09:32 +08:00
|
|
|
noise = devices.randn(seed, noise_shape)
|
2022-09-09 22:54:04 +08:00
|
|
|
|
|
|
|
if subnoise is not None:
|
|
|
|
noise = slerp(subseed_strength, noise, subnoise)
|
|
|
|
|
|
|
|
if noise_shape != shape:
|
2022-09-13 01:09:32 +08:00
|
|
|
x = devices.randn(seed, shape)
|
|
|
|
dx = (shape[2] - noise_shape[2]) // 2
|
2022-09-09 22:54:04 +08:00
|
|
|
dy = (shape[1] - noise_shape[1]) // 2
|
|
|
|
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
|
|
|
|
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
|
|
|
|
tx = 0 if dx < 0 else dx
|
|
|
|
ty = 0 if dy < 0 else dy
|
|
|
|
dx = max(-dx, 0)
|
|
|
|
dy = max(-dy, 0)
|
|
|
|
|
|
|
|
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
|
|
|
|
noise = x
|
|
|
|
|
2022-09-14 02:49:58 +08:00
|
|
|
if sampler_noises is not None:
|
|
|
|
cnt = p.sampler.number_of_needed_noises(p)
|
2022-09-09 22:54:04 +08:00
|
|
|
|
2022-10-11 01:32:37 +08:00
|
|
|
if opts.eta_noise_seed_delta > 0:
|
|
|
|
torch.manual_seed(seed + opts.eta_noise_seed_delta)
|
|
|
|
|
2022-09-14 02:49:58 +08:00
|
|
|
for j in range(cnt):
|
|
|
|
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
|
2022-09-09 22:54:04 +08:00
|
|
|
|
|
|
|
xs.append(noise)
|
2022-09-14 02:49:58 +08:00
|
|
|
|
|
|
|
if sampler_noises is not None:
|
|
|
|
p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
|
|
|
|
|
2022-09-09 22:54:04 +08:00
|
|
|
x = torch.stack(xs).to(shared.device)
|
2022-09-03 17:08:45 +08:00
|
|
|
return x
|
|
|
|
|
|
|
|
|
2022-10-10 21:11:14 +08:00
|
|
|
def decode_first_stage(model, x):
|
|
|
|
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
|
|
|
x = model.decode_first_stage(x)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
2022-10-04 22:36:39 +08:00
|
|
|
def get_fixed_seed(seed):
|
|
|
|
if seed is None or seed == '' or seed == -1:
|
|
|
|
return int(random.randrange(4294967294))
|
|
|
|
|
|
|
|
return seed
|
|
|
|
|
|
|
|
|
2022-09-09 22:54:04 +08:00
|
|
|
def fix_seed(p):
|
2022-10-04 22:36:39 +08:00
|
|
|
p.seed = get_fixed_seed(p.seed)
|
|
|
|
p.subseed = get_fixed_seed(p.subseed)
|
2022-09-07 06:44:44 +08:00
|
|
|
|
|
|
|
|
2022-09-19 14:02:10 +08:00
|
|
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
|
|
|
|
index = position_in_batch + iteration * p.batch_size
|
|
|
|
|
2022-10-09 05:28:42 +08:00
|
|
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
2022-10-09 03:21:15 +08:00
|
|
|
|
2022-09-19 14:02:10 +08:00
|
|
|
generation_params = {
|
|
|
|
"Steps": p.steps,
|
2022-10-09 08:13:13 +08:00
|
|
|
"Sampler": get_correct_sampler(p)[p.sampler_index].name,
|
2022-09-19 14:02:10 +08:00
|
|
|
"CFG scale": p.cfg_scale,
|
|
|
|
"Seed": all_seeds[index],
|
|
|
|
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
|
|
|
|
"Size": f"{p.width}x{p.height}",
|
|
|
|
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
|
2022-10-09 19:57:48 +08:00
|
|
|
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
|
2022-10-19 21:17:47 +08:00
|
|
|
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
|
2022-09-19 14:02:10 +08:00
|
|
|
"Batch size": (None if p.batch_size < 2 else p.batch_size),
|
|
|
|
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
|
|
|
|
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
|
|
|
|
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
|
|
|
|
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
|
|
|
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
2022-10-02 20:03:39 +08:00
|
|
|
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
|
2022-10-10 03:30:59 +08:00
|
|
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
2022-10-11 01:32:37 +08:00
|
|
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
2022-09-19 14:02:10 +08:00
|
|
|
}
|
|
|
|
|
2022-09-21 00:07:09 +08:00
|
|
|
generation_params.update(p.extra_generation_params)
|
2022-09-19 14:02:10 +08:00
|
|
|
|
|
|
|
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
|
|
|
|
|
|
|
|
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
|
|
|
|
|
2022-09-28 23:20:30 +08:00
|
|
|
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
2022-09-19 14:02:10 +08:00
|
|
|
|
|
|
|
|
2022-09-03 17:08:45 +08:00
|
|
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
|
|
|
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
|
|
|
|
|
2022-09-17 16:34:33 +08:00
|
|
|
if type(p.prompt) == list:
|
|
|
|
assert(len(p.prompt) > 0)
|
|
|
|
else:
|
|
|
|
assert p.prompt is not None
|
2022-10-04 23:49:51 +08:00
|
|
|
|
2022-10-09 11:57:19 +08:00
|
|
|
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
|
|
|
|
processed = Processed(p, [], p.seed, "")
|
|
|
|
file.write(processed.infotext(p, 0))
|
|
|
|
|
2022-09-12 04:24:24 +08:00
|
|
|
devices.torch_gc()
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-10-04 22:36:39 +08:00
|
|
|
seed = get_fixed_seed(p.seed)
|
|
|
|
subseed = get_fixed_seed(p.subseed)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-05 08:25:37 +08:00
|
|
|
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
|
2022-10-08 05:48:34 +08:00
|
|
|
modules.sd_hijack.model_hijack.clear_comments()
|
2022-09-05 08:25:37 +08:00
|
|
|
|
2022-09-15 13:57:03 +08:00
|
|
|
comments = {}
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-14 22:56:21 +08:00
|
|
|
shared.prompt_styles.apply_styles(p)
|
2022-09-10 04:16:02 +08:00
|
|
|
|
|
|
|
if type(p.prompt) == list:
|
|
|
|
all_prompts = p.prompt
|
2022-09-03 17:08:45 +08:00
|
|
|
else:
|
2022-09-10 04:16:02 +08:00
|
|
|
all_prompts = p.batch_size * p.n_iter * [p.prompt]
|
2022-09-03 22:21:15 +08:00
|
|
|
|
2022-10-04 22:36:39 +08:00
|
|
|
if type(seed) == list:
|
|
|
|
all_seeds = seed
|
2022-09-03 22:21:15 +08:00
|
|
|
else:
|
2022-10-04 22:36:39 +08:00
|
|
|
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
|
2022-09-09 22:54:04 +08:00
|
|
|
|
2022-10-04 22:36:39 +08:00
|
|
|
if type(subseed) == list:
|
|
|
|
all_subseeds = subseed
|
2022-09-09 22:54:04 +08:00
|
|
|
else:
|
2022-10-04 22:36:39 +08:00
|
|
|
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
def infotext(iteration=0, position_in_batch=0):
|
2022-09-19 14:02:10 +08:00
|
|
|
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-10-16 13:51:24 +08:00
|
|
|
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
|
2022-10-02 20:03:39 +08:00
|
|
|
model_hijack.embedding_db.load_textual_inversion_embeddings()
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-28 22:05:23 +08:00
|
|
|
infotexts = []
|
2022-09-03 17:08:45 +08:00
|
|
|
output_images = []
|
2022-10-04 17:32:22 +08:00
|
|
|
|
2022-10-09 04:26:48 +08:00
|
|
|
with torch.no_grad(), p.sd_model.ema_scope():
|
2022-10-04 21:54:31 +08:00
|
|
|
with devices.autocast():
|
|
|
|
p.init(all_prompts, all_seeds, all_subseeds)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-06 15:11:25 +08:00
|
|
|
if state.job_count == -1:
|
|
|
|
state.job_count = p.n_iter
|
2022-09-06 07:09:01 +08:00
|
|
|
|
2022-10-05 09:28:50 +08:00
|
|
|
for n in range(p.n_iter):
|
2022-10-05 11:56:30 +08:00
|
|
|
if state.skipped:
|
|
|
|
state.skipped = False
|
|
|
|
|
2022-09-03 17:08:45 +08:00
|
|
|
if state.interrupted:
|
|
|
|
break
|
|
|
|
|
|
|
|
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
|
|
|
|
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
|
2022-09-10 21:16:18 +08:00
|
|
|
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-17 16:34:33 +08:00
|
|
|
if (len(prompts) == 0):
|
|
|
|
break
|
|
|
|
|
2022-09-15 18:10:16 +08:00
|
|
|
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
|
|
|
|
#c = p.sd_model.get_learned_conditioning(prompts)
|
2022-10-04 17:32:22 +08:00
|
|
|
with devices.autocast():
|
2022-10-04 23:49:51 +08:00
|
|
|
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
|
2022-10-06 04:16:27 +08:00
|
|
|
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
if len(model_hijack.comments) > 0:
|
2022-09-15 13:57:03 +08:00
|
|
|
for comment in model_hijack.comments:
|
|
|
|
comments[comment] = 1
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
if p.n_iter > 1:
|
2022-09-24 13:23:01 +08:00
|
|
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-10-04 17:32:22 +08:00
|
|
|
with devices.autocast():
|
2022-10-04 19:23:22 +08:00
|
|
|
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
|
2022-10-04 17:32:22 +08:00
|
|
|
|
2022-10-10 22:03:45 +08:00
|
|
|
samples_ddim = samples_ddim.to(devices.dtype_vae)
|
2022-10-10 21:11:14 +08:00
|
|
|
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
|
2022-09-03 17:08:45 +08:00
|
|
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
|
|
|
|
2022-09-29 09:14:13 +08:00
|
|
|
del samples_ddim
|
|
|
|
|
|
|
|
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
|
|
|
|
lowvram.send_everything_to_cpu()
|
|
|
|
|
|
|
|
devices.torch_gc()
|
|
|
|
|
2022-09-13 08:15:35 +08:00
|
|
|
if opts.filter_nsfw:
|
2022-09-13 13:34:41 +08:00
|
|
|
import modules.safety as safety
|
|
|
|
x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
|
2022-09-13 08:15:35 +08:00
|
|
|
|
2022-10-05 09:28:50 +08:00
|
|
|
for i, x_sample in enumerate(x_samples_ddim):
|
2022-09-03 17:08:45 +08:00
|
|
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
|
|
|
x_sample = x_sample.astype(np.uint8)
|
|
|
|
|
2022-10-05 09:28:50 +08:00
|
|
|
if p.restore_faces:
|
2022-09-12 22:47:36 +08:00
|
|
|
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
|
2022-09-22 18:54:50 +08:00
|
|
|
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
|
2022-09-12 22:47:36 +08:00
|
|
|
|
2022-10-04 17:32:22 +08:00
|
|
|
devices.torch_gc()
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-10-05 09:28:50 +08:00
|
|
|
x_sample = modules.face_restoration.restore_faces(x_sample)
|
|
|
|
devices.torch_gc()
|
2022-09-29 09:14:13 +08:00
|
|
|
|
2022-09-03 17:08:45 +08:00
|
|
|
image = Image.fromarray(x_sample)
|
|
|
|
|
2022-09-13 17:51:57 +08:00
|
|
|
if p.color_corrections is not None and i < len(p.color_corrections):
|
2022-09-22 18:54:50 +08:00
|
|
|
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
|
2022-09-23 08:57:42 +08:00
|
|
|
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
|
2022-09-13 17:51:57 +08:00
|
|
|
image = apply_color_correction(p.color_corrections[i], image)
|
2022-09-12 22:47:36 +08:00
|
|
|
|
2022-09-03 17:08:45 +08:00
|
|
|
if p.overlay_images is not None and i < len(p.overlay_images):
|
|
|
|
overlay = p.overlay_images[i]
|
|
|
|
|
|
|
|
if p.paste_to is not None:
|
|
|
|
x, y, w, h = p.paste_to
|
|
|
|
base_image = Image.new('RGBA', (overlay.width, overlay.height))
|
|
|
|
image = images.resize_image(1, image, w, h)
|
|
|
|
base_image.paste(image, (x, y))
|
|
|
|
image = base_image
|
|
|
|
|
|
|
|
image = image.convert('RGBA')
|
|
|
|
image.alpha_composite(overlay)
|
|
|
|
image = image.convert('RGB')
|
|
|
|
|
|
|
|
if opts.samples_save and not p.do_not_save_samples:
|
2022-09-12 20:41:30 +08:00
|
|
|
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-10-07 01:27:50 +08:00
|
|
|
text = infotext(n, i)
|
|
|
|
infotexts.append(text)
|
2022-10-09 18:10:15 +08:00
|
|
|
if opts.enable_pnginfo:
|
|
|
|
image.info["parameters"] = text
|
2022-09-03 17:08:45 +08:00
|
|
|
output_images.append(image)
|
|
|
|
|
2022-10-05 09:28:50 +08:00
|
|
|
del x_samples_ddim
|
2022-09-06 07:09:01 +08:00
|
|
|
|
2022-10-05 09:28:50 +08:00
|
|
|
devices.torch_gc()
|
2022-09-29 09:14:13 +08:00
|
|
|
|
2022-10-05 09:28:50 +08:00
|
|
|
state.nextjob()
|
2022-09-29 09:14:13 +08:00
|
|
|
|
2022-09-18 06:18:30 +08:00
|
|
|
p.color_corrections = None
|
|
|
|
|
2022-09-19 14:02:10 +08:00
|
|
|
index_of_first_image = 0
|
2022-09-03 17:08:45 +08:00
|
|
|
unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
|
2022-09-14 15:34:44 +08:00
|
|
|
if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
|
2022-09-03 22:21:15 +08:00
|
|
|
grid = images.image_grid(output_images, p.batch_size)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-14 15:34:44 +08:00
|
|
|
if opts.return_grid:
|
2022-10-07 01:27:50 +08:00
|
|
|
text = infotext()
|
|
|
|
infotexts.insert(0, text)
|
2022-10-09 18:10:15 +08:00
|
|
|
if opts.enable_pnginfo:
|
|
|
|
grid.info["parameters"] = text
|
2022-09-03 17:08:45 +08:00
|
|
|
output_images.insert(0, grid)
|
2022-09-19 14:02:10 +08:00
|
|
|
index_of_first_image = 1
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
if opts.grid_save:
|
2022-09-26 17:30:18 +08:00
|
|
|
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-09-12 04:24:24 +08:00
|
|
|
devices.torch_gc()
|
2022-09-28 23:20:30 +08:00
|
|
|
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|
|
|
sampler = None
|
2022-09-19 21:42:56 +08:00
|
|
|
|
2022-10-18 03:10:36 +08:00
|
|
|
def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
|
2022-09-19 21:42:56 +08:00
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.enable_hr = enable_hr
|
|
|
|
self.denoising_strength = denoising_strength
|
2022-10-14 03:04:22 +08:00
|
|
|
self.firstphase_width = firstphase_width
|
|
|
|
self.firstphase_height = firstphase_height
|
2022-10-15 04:19:05 +08:00
|
|
|
self.truncate_x = 0
|
|
|
|
self.truncate_y = 0
|
2022-09-19 21:42:56 +08:00
|
|
|
|
|
|
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
|
|
|
if self.enable_hr:
|
|
|
|
if state.job_count == -1:
|
|
|
|
state.job_count = self.n_iter * 2
|
|
|
|
else:
|
|
|
|
state.job_count = state.job_count * 2
|
|
|
|
|
2022-10-15 04:19:05 +08:00
|
|
|
if self.firstphase_width == 0 or self.firstphase_height == 0:
|
|
|
|
desired_pixel_count = 512 * 512
|
|
|
|
actual_pixel_count = self.width * self.height
|
|
|
|
scale = math.sqrt(desired_pixel_count / actual_pixel_count)
|
|
|
|
self.firstphase_width = math.ceil(scale * self.width / 64) * 64
|
|
|
|
self.firstphase_height = math.ceil(scale * self.height / 64) * 64
|
|
|
|
firstphase_width_truncated = int(scale * self.width)
|
|
|
|
firstphase_height_truncated = int(scale * self.height)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
width_ratio = self.width / self.firstphase_width
|
|
|
|
height_ratio = self.height / self.firstphase_height
|
|
|
|
|
|
|
|
if width_ratio > height_ratio:
|
|
|
|
firstphase_width_truncated = self.firstphase_width
|
|
|
|
firstphase_height_truncated = self.firstphase_width * self.height / self.width
|
|
|
|
else:
|
|
|
|
firstphase_width_truncated = self.firstphase_height * self.width / self.height
|
|
|
|
firstphase_height_truncated = self.firstphase_height
|
|
|
|
|
2022-10-15 20:47:02 +08:00
|
|
|
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
|
2022-10-15 04:19:05 +08:00
|
|
|
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
|
|
|
|
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
|
|
|
|
|
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
2022-10-06 19:12:52 +08:00
|
|
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
2022-09-19 21:42:56 +08:00
|
|
|
|
|
|
|
if not self.enable_hr:
|
|
|
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
2022-10-20 04:47:45 +08:00
|
|
|
|
|
|
|
# The "masked-image" in this case will just be all zeros since the entire image is masked.
|
|
|
|
image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device)
|
|
|
|
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
|
|
|
|
|
|
|
|
# Add the fake full 1s mask to the first dimension.
|
|
|
|
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
|
|
|
|
image_conditioning = image_conditioning.to(x.dtype)
|
|
|
|
|
|
|
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning)
|
2022-09-19 21:42:56 +08:00
|
|
|
return samples
|
|
|
|
|
|
|
|
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
|
|
|
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
|
|
|
|
|
2022-10-15 04:19:05 +08:00
|
|
|
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
|
2022-09-19 21:42:56 +08:00
|
|
|
|
2022-10-15 18:23:12 +08:00
|
|
|
if opts.use_scale_latent_for_hires_fix:
|
|
|
|
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
|
2022-09-19 21:42:56 +08:00
|
|
|
|
|
|
|
else:
|
2022-10-15 18:23:12 +08:00
|
|
|
decoded_samples = decode_first_stage(self.sd_model, samples)
|
2022-10-14 22:03:03 +08:00
|
|
|
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
2022-09-21 00:32:26 +08:00
|
|
|
|
2022-10-14 22:03:03 +08:00
|
|
|
batch_images = []
|
|
|
|
for i, x_sample in enumerate(lowres_samples):
|
|
|
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
|
|
|
x_sample = x_sample.astype(np.uint8)
|
|
|
|
image = Image.fromarray(x_sample)
|
|
|
|
image = images.resize_image(0, image, self.width, self.height)
|
|
|
|
image = np.array(image).astype(np.float32) / 255.0
|
|
|
|
image = np.moveaxis(image, 2, 0)
|
|
|
|
batch_images.append(image)
|
|
|
|
|
|
|
|
decoded_samples = torch.from_numpy(np.array(batch_images))
|
|
|
|
decoded_samples = decoded_samples.to(shared.device)
|
|
|
|
decoded_samples = 2. * decoded_samples - 1.
|
|
|
|
|
2022-10-15 18:23:12 +08:00
|
|
|
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
|
2022-09-19 21:42:56 +08:00
|
|
|
|
|
|
|
shared.state.nextjob()
|
2022-09-03 17:08:45 +08:00
|
|
|
|
2022-10-06 19:12:52 +08:00
|
|
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
|
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
2022-09-24 09:18:34 +08:00
|
|
|
|
|
|
|
# GC now before running the next img2img to prevent running out of memory
|
|
|
|
x = None
|
|
|
|
devices.torch_gc()
|
2022-10-04 23:49:51 +08:00
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
|
|
|
|
|
|
|
|
return samples
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|
|
|
sampler = None
|
|
|
|
|
2022-09-22 17:11:48 +08:00
|
|
|
def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs):
|
2022-09-03 17:08:45 +08:00
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
|
|
|
self.init_images = init_images
|
|
|
|
self.resize_mode: int = resize_mode
|
|
|
|
self.denoising_strength: float = denoising_strength
|
|
|
|
self.init_latent = None
|
|
|
|
self.image_mask = mask
|
2022-09-04 06:29:43 +08:00
|
|
|
#self.image_unblurred_mask = None
|
|
|
|
self.latent_mask = None
|
2022-09-03 17:08:45 +08:00
|
|
|
self.mask_for_overlay = None
|
|
|
|
self.mask_blur = mask_blur
|
|
|
|
self.inpainting_fill = inpainting_fill
|
|
|
|
self.inpaint_full_res = inpaint_full_res
|
2022-09-22 17:11:48 +08:00
|
|
|
self.inpaint_full_res_padding = inpaint_full_res_padding
|
2022-09-04 02:02:38 +08:00
|
|
|
self.inpainting_mask_invert = inpainting_mask_invert
|
2022-09-03 17:08:45 +08:00
|
|
|
self.mask = None
|
|
|
|
self.nmask = None
|
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
def init(self, all_prompts, all_seeds, all_subseeds):
|
2022-10-06 19:12:52 +08:00
|
|
|
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
|
2022-09-03 17:08:45 +08:00
|
|
|
crop_region = None
|
|
|
|
|
|
|
|
if self.image_mask is not None:
|
2022-09-04 02:02:38 +08:00
|
|
|
self.image_mask = self.image_mask.convert('L')
|
|
|
|
|
|
|
|
if self.inpainting_mask_invert:
|
|
|
|
self.image_mask = ImageOps.invert(self.image_mask)
|
|
|
|
|
2022-09-04 06:29:43 +08:00
|
|
|
#self.image_unblurred_mask = self.image_mask
|
|
|
|
|
2022-09-03 17:08:45 +08:00
|
|
|
if self.mask_blur > 0:
|
2022-09-04 02:02:38 +08:00
|
|
|
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
if self.inpaint_full_res:
|
|
|
|
self.mask_for_overlay = self.image_mask
|
|
|
|
mask = self.image_mask.convert('L')
|
2022-09-22 17:11:48 +08:00
|
|
|
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
|
2022-09-18 15:49:00 +08:00
|
|
|
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
|
2022-09-03 17:08:45 +08:00
|
|
|
x1, y1, x2, y2 = crop_region
|
|
|
|
|
|
|
|
mask = mask.crop(crop_region)
|
|
|
|
self.image_mask = images.resize_image(2, mask, self.width, self.height)
|
|
|
|
self.paste_to = (x1, y1, x2-x1, y2-y1)
|
|
|
|
else:
|
|
|
|
self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
|
2022-09-07 05:58:01 +08:00
|
|
|
np_mask = np.array(self.image_mask)
|
2022-09-13 22:14:40 +08:00
|
|
|
np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
|
2022-09-07 05:58:01 +08:00
|
|
|
self.mask_for_overlay = Image.fromarray(np_mask)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
self.overlay_images = []
|
|
|
|
|
2022-09-07 22:00:51 +08:00
|
|
|
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
|
|
|
|
|
2022-09-16 13:33:47 +08:00
|
|
|
add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
|
|
|
|
if add_color_corrections:
|
|
|
|
self.color_corrections = []
|
2022-09-03 17:08:45 +08:00
|
|
|
imgs = []
|
|
|
|
for img in self.init_images:
|
|
|
|
image = img.convert("RGB")
|
|
|
|
|
|
|
|
if crop_region is None:
|
|
|
|
image = images.resize_image(self.resize_mode, image, self.width, self.height)
|
|
|
|
|
|
|
|
if self.image_mask is not None:
|
|
|
|
image_masked = Image.new('RGBa', (image.width, image.height))
|
|
|
|
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
|
|
|
|
|
|
|
|
self.overlay_images.append(image_masked.convert('RGBA'))
|
|
|
|
|
|
|
|
if crop_region is not None:
|
|
|
|
image = image.crop(crop_region)
|
|
|
|
image = images.resize_image(2, image, self.width, self.height)
|
|
|
|
|
2022-09-08 15:03:21 +08:00
|
|
|
if self.image_mask is not None:
|
|
|
|
if self.inpainting_fill != 1:
|
2022-09-18 15:49:00 +08:00
|
|
|
image = masking.fill(image, latent_mask)
|
2022-09-08 15:03:21 +08:00
|
|
|
|
2022-09-16 13:33:47 +08:00
|
|
|
if add_color_corrections:
|
2022-09-13 17:51:57 +08:00
|
|
|
self.color_corrections.append(setup_color_correction(image))
|
|
|
|
|
2022-09-03 17:08:45 +08:00
|
|
|
image = np.array(image).astype(np.float32) / 255.0
|
|
|
|
image = np.moveaxis(image, 2, 0)
|
|
|
|
|
|
|
|
imgs.append(image)
|
|
|
|
|
|
|
|
if len(imgs) == 1:
|
|
|
|
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
|
|
|
|
if self.overlay_images is not None:
|
|
|
|
self.overlay_images = self.overlay_images * self.batch_size
|
|
|
|
elif len(imgs) <= self.batch_size:
|
|
|
|
self.batch_size = len(imgs)
|
|
|
|
batch_images = np.array(imgs)
|
|
|
|
else:
|
|
|
|
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
|
|
|
|
|
|
|
|
image = torch.from_numpy(batch_images)
|
|
|
|
image = 2. * image - 1.
|
|
|
|
image = image.to(shared.device)
|
|
|
|
|
|
|
|
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|
|
|
|
|
|
|
|
if self.image_mask is not None:
|
2022-09-07 22:00:51 +08:00
|
|
|
init_mask = latent_mask
|
2022-09-04 06:29:43 +08:00
|
|
|
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
|
2022-09-13 01:09:32 +08:00
|
|
|
latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
|
2022-09-03 17:08:45 +08:00
|
|
|
latmask = latmask[0]
|
2022-09-07 22:00:51 +08:00
|
|
|
latmask = np.around(latmask)
|
2022-09-03 17:08:45 +08:00
|
|
|
latmask = np.tile(latmask[None], (4, 1, 1))
|
|
|
|
|
|
|
|
self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
|
|
|
|
self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
|
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
# this needs to be fixed to be done in sample() using actual seeds for batches
|
2022-09-03 17:08:45 +08:00
|
|
|
if self.inpainting_fill == 2:
|
2022-09-19 21:42:56 +08:00
|
|
|
self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
|
2022-09-03 17:08:45 +08:00
|
|
|
elif self.inpainting_fill == 3:
|
|
|
|
self.init_latent = self.init_latent * self.mask
|
|
|
|
|
2022-10-20 04:47:45 +08:00
|
|
|
if self.image_mask is not None:
|
|
|
|
conditioning_mask = np.array(self.image_mask.convert("L"))
|
|
|
|
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
|
|
|
|
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
|
|
|
|
|
|
|
|
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
|
|
|
|
conditioning_mask = torch.round(conditioning_mask)
|
|
|
|
else:
|
|
|
|
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
|
|
|
|
|
|
|
|
# Create another latent image, this time with a masked version of the original input.
|
|
|
|
conditioning_mask = conditioning_mask.to(image.device)
|
|
|
|
conditioning_image = image * (1.0 - conditioning_mask)
|
|
|
|
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
|
|
|
|
|
|
|
|
# Create the concatenated conditioning tensor to be fed to `c_concat`
|
|
|
|
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
|
|
|
|
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
|
|
|
|
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
|
|
|
|
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
|
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
|
|
|
|
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
|
|
|
|
|
2022-10-20 04:47:45 +08:00
|
|
|
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
|
2022-09-03 17:08:45 +08:00
|
|
|
|
|
|
|
if self.mask is not None:
|
|
|
|
samples = samples * self.nmask + self.init_latent * self.mask
|
|
|
|
|
2022-09-29 09:14:13 +08:00
|
|
|
del x
|
|
|
|
devices.torch_gc()
|
|
|
|
|
2022-10-17 14:58:42 +08:00
|
|
|
return samples
|