2023-08-09 00:20:11 +08:00
|
|
|
import inspect
|
2023-08-09 13:43:31 +08:00
|
|
|
from collections import namedtuple
|
2022-09-07 04:10:12 +08:00
|
|
|
import numpy as np
|
2022-09-03 17:08:45 +08:00
|
|
|
import torch
|
2022-09-07 04:10:12 +08:00
|
|
|
from PIL import Image
|
2023-08-06 22:01:07 +08:00
|
|
|
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
2023-01-30 14:51:06 +08:00
|
|
|
from modules.shared import opts, state
|
2023-08-09 00:20:11 +08:00
|
|
|
import k_diffusion.sampling
|
2022-09-03 22:21:15 +08:00
|
|
|
|
2023-08-12 17:39:59 +08:00
|
|
|
|
|
|
|
# Plain record describing one registered sampler; SamplerData below adds behaviour on top.
SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])


class SamplerData(SamplerDataTuple):
    """Sampler registry entry: name, constructor, aliases and an options dict."""

    def total_steps(self, steps):
        """Return the step count used for progress accounting.

        The requested `steps` is doubled when the sampler's options have a
        truthy "second_order" entry; otherwise it is returned unchanged.
        """
        is_second_order = self.options.get("second_order", False)
        return steps * 2 if is_second_order else steps
|
2022-09-03 22:21:15 +08:00
|
|
|
|
2022-10-23 01:48:13 +08:00
|
|
|
|
2022-09-19 21:42:56 +08:00
|
|
|
def setup_img2img_steps(p, steps=None):
    """Compute `(steps, t_enc)` for an img2img run from `p` and an optional step override.

    Without the `img2img_fix_steps` option and without an explicit `steps`,
    the sampler keeps `p.steps` and `t_enc` is that count scaled by the
    denoising strength (capped at 0.999). Otherwise the requested count is
    divided by the capped strength and `t_enc` is `requested - 1`.
    """
    if not (opts.img2img_fix_steps or steps is not None):
        total = p.steps
        return total, int(min(p.denoising_strength, 0.999) * total)

    requested = steps or p.steps

    if p.denoising_strength > 0:
        total = int(requested / min(p.denoising_strength, 0.999))
    else:
        # A non-positive strength would divide by zero (or go negative); use 0 steps instead.
        total = 0

    return total, requested - 1
|
|
|
|
|
|
|
|
|
2023-05-17 14:24:01 +08:00
|
|
|
# Maps the decoder/encoder method names used in the options to the integer
# codes that the latent <-> image conversion functions in this module branch on.
approximation_indexes = {
    "Full": 0,
    "Approx NN": 1,
    "Approx cheap": 2,
    "TAESD": 3,
}
|
2022-12-25 03:39:00 +08:00
|
|
|
|
|
|
|
|
2023-08-04 13:38:52 +08:00
|
|
|
def samples_to_images_tensor(sample, approximation=None, model=None):
    """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1].

    `approximation` is one of the integer codes in `approximation_indexes`.
    When it is None, or when generation was interrupted and the
    `live_preview_fast_interrupt` option is on, the live-preview method from
    `opts.show_progress_type` is used instead. `model` is only consulted for
    the full decode and defaults to `shared.sd_model`.
    """
    use_preview_setting = approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt)

    if use_preview_setting:
        approximation = approximation_indexes.get(opts.show_progress_type, 0)

        from modules import lowvram

        # With lowvram active, a full decode for previews is replaced by the
        # NN approximation unless the user explicitly allowed it.
        if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
            approximation = 1

    if approximation == 1:
        return sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()

    if approximation == 2:
        return sd_vae_approx.cheap_approximation(sample)

    if approximation == 3:
        decoded = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
        # Rescaled with x * 2 - 1, presumably from [0, 1] to the documented [-1, 1] range — confirm against TAESD.
        return decoded * 2 - 1

    # Any other code (including 0, "Full") goes through the model's own first stage.
    if model is None:
        model = shared.sd_model

    with devices.without_autocast():  # fixes an issue with unstable VAEs that are flaky even in fp32
        return model.decode_first_stage(sample.to(model.first_stage_model.dtype))
|
|
|
|
|
|
|
|
|
|
|
|
def single_sample_to_image(sample, approximation=None):
    """Decode one latent `sample` (no batch dimension) into a PIL image.

    `approximation` is forwarded unchanged to `samples_to_images_tensor`.
    """
    # Add a batch dimension for the decoder, take the single result back out,
    # then map the decoded values from [-1, 1] to [0, 1].
    decoded = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5
    decoded = torch.clamp(decoded, min=0.0, max=1.0)

    # Move axis 0 to the end (channels last) and scale to 8-bit values.
    pixels = 255. * np.moveaxis(decoded.cpu().numpy(), 0, 2)
    pixels = pixels.astype(np.uint8)

    return Image.fromarray(pixels)
|
|
|
|
|
2022-10-23 01:48:13 +08:00
|
|
|
|
2023-08-04 14:09:09 +08:00
|
|
|
def decode_first_stage(model, x):
    """Decode latents `x` with the method selected by `opts.sd_vae_decode_method`.

    The latents are first cast to `devices.dtype_vae`; an unrecognised
    method name falls back to code 0 (the full decode by `model`).
    """
    latents = x.to(devices.dtype_vae)
    method = approximation_indexes.get(opts.sd_vae_decode_method, 0)

    return samples_to_images_tensor(latents, method, model)
|
2023-08-04 14:09:09 +08:00
|
|
|
|
|
|
|
|
2022-12-25 03:39:00 +08:00
|
|
|
def sample_to_image(samples, index=0, approximation=None):
    """Convert the latent at position `index` of `samples` into a PIL image."""
    chosen = samples[index]
    return single_sample_to_image(chosen, approximation)
|
2022-10-23 01:48:13 +08:00
|
|
|
|
2022-11-02 17:45:03 +08:00
|
|
|
|
2022-12-25 03:39:00 +08:00
|
|
|
def samples_to_image_grid(samples, approximation=None):
    """Convert every latent in `samples` to an image and combine them with `images.image_grid`."""
    decoded_images = [single_sample_to_image(latent, approximation) for latent in samples]
    return images.image_grid(decoded_images)
|
2022-10-23 01:48:13 +08:00
|
|
|
|
2022-09-07 04:10:12 +08:00
|
|
|
|
2023-08-04 13:38:52 +08:00
|
|
|
def images_tensor_to_samples(image, approximation=None, model=None):
    '''image[0, 1] -> latent

    `approximation` is a code from `approximation_indexes`; None selects the
    method named by `opts.sd_vae_encode_method`. Code 3 encodes with TAESD,
    every other code uses the first stage of `model` (default: `shared.sd_model`).
    '''
    if approximation is None:
        approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)

    if approximation == 3:
        image = image.to(devices.device, devices.dtype)
        return sd_vae_taesd.encoder_model()(image)

    if model is None:
        model = shared.sd_model

    # Make sure the VAE itself is in the VAE dtype before feeding it.
    model.first_stage_model.to(devices.dtype_vae)

    # Move to the working device/dtype and rescale [0, 1] -> [-1, 1].
    image = image.to(shared.device, dtype=devices.dtype_vae)
    image = image * 2 - 1

    def encode(batch):
        return model.get_first_stage_encoding(model.encode_first_stage(batch))

    if len(image) > 1:
        # Batches are encoded one image at a time and re-stacked.
        return torch.stack([encode(torch.unsqueeze(single, 0))[0] for single in image])

    return encode(image)
|
|
|
|
|
|
|
|
|
2022-09-07 04:10:12 +08:00
|
|
|
def store_latent(decoded):
    """Record `decoded` as the current latent and, when due, refresh the live preview image.

    The preview is only rendered here when live previews are enabled, the
    current sampling step falls on the configured interval, and parallel
    processing is not allowed.
    """
    state.current_latent = decoded

    if not opts.live_previews_enable:
        return

    interval = opts.show_progress_every_n_steps
    if not interval > 0:
        return

    if shared.state.sampling_step % interval != 0:
        return

    if shared.parallel_processing_allowed:
        return

    shared.state.assign_current_image(sample_to_image(decoded))
|
2022-09-07 04:10:12 +08:00
|
|
|
|
|
|
|
|
2023-05-16 16:54:02 +08:00
|
|
|
def is_sampler_using_eta_noise_seed_delta(p):
    """returns whether sampler from config will use eta noise seed delta for image creation

    The effective eta is taken, in order, from `p.eta`, from `p.sampler.eta`,
    and finally from the sampler config's "default_eta_is_0" option. An eta
    of 0 always yields False. If no config is found for `p.sampler_name`,
    False is returned as well, since there is no "uses_ensd" option to read.
    """
    sampler_config = sd_samplers.find_sampler_config(p.sampler_name)

    eta = p.eta

    if eta is None and p.sampler is not None:
        eta = p.sampler.eta

    if eta is None and sampler_config is not None:
        eta = 0 if sampler_config.options.get("default_eta_is_0", False) else 1.0

    if eta == 0:
        return False

    # The config lookup may come back empty (it is already treated as optional
    # above); without this guard the options access below raised AttributeError.
    if sampler_config is None:
        return False

    return sampler_config.options.get("uses_ensd", False)
|
|
|
|
|
|
|
|
|
2022-10-18 22:23:38 +08:00
|
|
|
class InterruptedException(BaseException):
    """Raised to abort sampling early.

    Derives from BaseException, so generic `except Exception` handlers do not
    catch it.
    """
|
2023-04-19 11:18:58 +08:00
|
|
|
|
2023-04-29 16:29:37 +08:00
|
|
|
|
2023-08-03 12:18:55 +08:00
|
|
|
def replace_torchsde_browinan():
    """Replace torchsde's private `_randn` helper with one backed by `devices.randn_local`."""
    import torchsde._brownian.brownian_interval as brownian_interval

    def torchsde_randn(size, dtype, device, seed):
        # Generate the noise from the seed first, then move it to the requested device/dtype.
        noise = devices.randn_local(seed, size)
        return noise.to(device=device, dtype=dtype)

    brownian_interval._randn = torchsde_randn
|
2023-08-03 12:18:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Install the torchsde noise-function replacement as soon as this module is imported.
replace_torchsde_browinan()
|
2023-08-06 22:01:07 +08:00
|
|
|
|
|
|
|
|
2023-08-12 17:39:59 +08:00
|
|
|
def apply_refiner(cfg_denoiser):
    """Switch to the refiner checkpoint if the denoiser has reached the switch point.

    Returns True when the refiner weights were loaded (conds and the inner
    model are refreshed afterwards), and False when nothing was changed:
    the switch point is not reached yet, no refiner is configured, the
    refiner is already the loaded model, or the hires-fix pass setting
    excludes the current pass.
    """
    progress = cfg_denoiser.step / cfg_denoiser.total_steps
    processing = cfg_denoiser.p
    switch_at = processing.refiner_switch_at
    checkpoint = processing.refiner_checkpoint_info

    if switch_at is not None and progress < switch_at:
        return False

    if checkpoint is None or shared.sd_model.sd_checkpoint_info == checkpoint:
        return False

    if getattr(processing, "enable_hr", False):
        in_second_pass = processing.is_hr_pass
        refiner_pass = opts.hires_fix_refiner_pass

        if refiner_pass == "first pass" and in_second_pass:
            return False

        if refiner_pass == "second pass" and not in_second_pass:
            return False

        # "second pass" is not written to the infotext; any other setting is.
        if refiner_pass != "second pass":
            processing.extra_generation_params['Hires refiner'] = refiner_pass

    processing.extra_generation_params['Refiner'] = checkpoint.short_title
    processing.extra_generation_params['Refiner switch at'] = switch_at

    # Load the refiner weights inside SkipWritingToConfig.
    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=checkpoint)

    devices.torch_gc()
    processing.setup_conds()
    cfg_denoiser.update_inner_model()

    return True
|
2023-08-06 22:01:07 +08:00
|
|
|
|
|
|
|
|
2023-08-09 00:20:11 +08:00
|
|
|
class TorchHijack:
    """This is here to replace torch.randn_like of k-diffusion.

    k-diffusion has random_sampler argument for most samplers, but not for all, so
    this is needed to properly replace every use of torch.randn_like.

    We need to replace to make images generated in batches to be same as images generated individually."""

    def __init__(self, p):
        # Noise source taken from the processing object; randn_like draws from it.
        self.rng = p.rng

    def __getattr__(self, item):
        """Resolve attributes not found on the instance: anything torch has is passed through."""
        if item == 'randn_like':
            return self.randn_like

        missing = object()
        found = getattr(torch, item, missing)
        if found is not missing:
            return found

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        """Return the next noise from the stored generator; `x` itself is not inspected."""
        return self.rng.next()
|
2023-08-09 00:20:11 +08:00
|
|
|
|
|
|
|
|
|
|
|
class Sampler:
    """Base class with the state and helpers shared by sampler implementations.

    Subclasses are expected to override `sample` and `sample_img2img`, which
    raise NotImplementedError here.
    """

    def __init__(self, funcname):
        # Sampling function; `func` starts out as the same object as `funcname`.
        self.funcname = funcname
        self.func = funcname

        # Names of attributes on the processing object that may be forwarded to `func`.
        self.extra_params = []
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config: SamplerData = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None

        # Defaults the sigma settings are compared against in initialize().
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        # Where eta is read from in the options, how it is labelled in the
        # infotext, and the value that is not worth recording there.
        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'
        self.eta_default = 1.0

        self.conditioning_key = shared.sd_model.model.conditioning_key

        self.p = None
        self.model_wrap_cfg = None
        self.sampler_extra_args = None
        self.options = {}

    def callback_state(self, d):
        """Per-step callback: enforce `stop_at` and publish the step to the shared state."""
        current = d['i']
        limit = self.stop_at

        if limit is not None and current > limit:
            raise InterruptedException

        state.sampling_step = current
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Run `func()` with step bookkeeping set up.

        Returns whatever `func` returns, or `self.last_latent` if sampling ends
        with a RecursionError or an InterruptedException.
        """
        denoiser = self.model_wrap_cfg
        denoiser.steps = steps
        denoiser.total_steps = self.config.total_steps(steps)

        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        """Return how many noises this sampler needs; here simply `p.steps`."""
        return p.steps

    def initialize(self, p) -> dict:
        """Bind processing object `p` to the sampler and build extra kwargs for `self.func`.

        Also records non-default eta and sigma settings in
        `p.extra_generation_params` and replaces `k_diffusion.sampling.torch`
        with a TorchHijack built from `p`.
        """
        self.p = p

        denoiser = self.model_wrap_cfg
        denoiser.p = p
        denoiser.mask = getattr(p, 'mask', None)
        denoiser.nmask = getattr(p, 'nmask', None)
        denoiser.step = 0
        denoiser.image_cfg_scale = getattr(p, 'image_cfg_scale', None)

        if p.eta is not None:
            self.eta = p.eta
        else:
            self.eta = getattr(opts, self.eta_option_field, 0.0)
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        k_diffusion.sampling.torch = TorchHijack(p)

        # Only parameters the sampling function actually accepts are forwarded.
        accepted = inspect.signature(self.func).parameters

        extra_params_kwargs = {
            param_name: getattr(p, param_name)
            for param_name in self.extra_params
            if hasattr(p, param_name) and param_name in accepted
        }

        if 'eta' in accepted:
            if self.eta != self.eta_default:
                p.extra_generation_params[self.eta_infotext_field] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if self.extra_params:
            # (attribute name, infotext label, value): options win over the
            # value already on `p`, which is used only as a fallback.
            sigma_settings = (
                ('s_churn', 'Sigma churn', getattr(opts, 's_churn', p.s_churn)),
                ('s_tmin', 'Sigma tmin', getattr(opts, 's_tmin', p.s_tmin)),
                ('s_tmax', 'Sigma tmax', getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax),  # 0 = inf
                ('s_noise', 'Sigma noise', getattr(opts, 's_noise', p.s_noise)),
            )

            for attr, infotext_label, value in sigma_settings:
                # Override only parameters already being forwarded, and only
                # when the value differs from this sampler's default.
                if attr in extra_params_kwargs and value != getattr(self, attr):
                    extra_params_kwargs[attr] = value
                    setattr(p, attr, value)
                    p.extra_generation_params[infotext_label] = value

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler

        # The lower bound is the smallest strictly positive sigma.
        lowest_sigma = sigmas[sigmas > 0].min()
        highest_sigma = sigmas.max()

        # Seeds belonging to the batch of the current iteration.
        first = p.iteration * p.batch_size
        batch_seeds = p.all_seeds[first:first + p.batch_size]

        return BrownianTreeNoiseSampler(x, lowest_sigma, highest_sigma, seed=batch_seeds)

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Not implemented in the base class."""
        raise NotImplementedError()