2023-01-30 14:51:06 +08:00
from collections import deque
2022-09-03 17:08:45 +08:00
import torch
2022-09-28 15:49:07 +08:00
import inspect
2022-09-03 17:08:45 +08:00
import k_diffusion . sampling
2023-01-30 15:47:09 +08:00
from modules import prompt_parser , devices , sd_samplers_common
2022-09-03 17:08:45 +08:00
2023-01-30 14:51:06 +08:00
from modules . shared import opts , state
2022-09-03 17:08:45 +08:00
import modules . shared as shared
2022-11-02 08:38:17 +08:00
from modules . script_callbacks import CFGDenoiserParams , cfg_denoiser_callback
2023-02-11 10:18:38 +08:00
from modules . script_callbacks import CFGDenoisedParams , cfg_denoised_callback
2023-05-14 09:49:41 +08:00
from modules . script_callbacks import AfterCFGCallbackParams , cfg_after_cfg_callback
2022-09-03 17:08:45 +08:00
2022-09-03 22:21:15 +08:00
# Each entry: (label, sampling function name, aliases, options).
# Options read in this file: 'scheduler' ('karras' selects the Karras sigma schedule in get_sigmas),
# 'discard_next_to_last_sigma' and 'brownian_noise' (seeded BrownianTreeNoiseSampler).
# 'uses_ensd' and 'second_order' are consumed outside this file.
# Aliases must be unique across all entries: a shared alias can only resolve to one of the samplers.
# A Karras variant and its non-Karras twin run the same function, so they carry the same 'second_order' flag.
samplers_k_diffusion = [
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {"uses_ensd": True}),
    ('Euler', 'sample_euler', ['k_euler'], {}),
    ('LMS', 'sample_lms', ['k_lms'], {}),
    ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True, "second_order": True}),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}),
    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
    ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {"brownian_noise": True}),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
    ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
    ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}),
    ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}),
    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
    ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
    ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}),
]
2023-07-18 12:32:01 +08:00
@torch.no_grad()
def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023).

    restart_list format: {min_sigma: [restart_steps, restart_times, max_sigma]}.
    Each min_sigma key is snapped to the closest entry of the (possibly rebuilt) schedule. After the step that
    lands on that entry, and provided the entry closest to max_sigma lies earlier in the schedule, noise is added
    to jump back to max_sigma's level. The interval is then re-integrated along a Karras schedule of
    ``restart_steps`` sigmas. This is done ``restart_times`` times.

    If restart_list is None, one is chosen automatically. With fewer than 20 steps no restarts are done.
    Otherwise the main schedule is rebuilt as a shorter Karras schedule, so that main steps plus restart steps
    still add up to the requested step count. A non-None restart_list is used as given.

    Args:
        model: denoiser, called as ``model(x, sigma * s_in, **extra_args)``.
        x: starting latent.
        sigmas: noise schedule; ``len(sigmas) - 1`` is taken as the number of steps.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional callable; receives a dict with 'x', 'i', 'sigma', 'sigma_hat' and 'denoised'
            once per integration step.
        disable: forwarded to tqdm's ``trange`` to hide the progress bar.
        s_noise: scale applied to the noise injected at each restart.
        restart_list: see above.

    Returns:
        The final latent.
    """
    from tqdm.auto import trange
    from k_diffusion.sampling import to_d, get_sigmas_karras

    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    step_id = 0

    def heun_step(x, old_sigma, new_sigma, second_order=True):
        """One step from old_sigma to new_sigma: Heun's method, or Euler when new_sigma == 0."""
        nonlocal step_id
        denoised = model(x, old_sigma * s_in, **extra_args)
        d = to_d(x, old_sigma, denoised)
        if callback is not None:
            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
        dt = new_sigma - old_sigma
        if new_sigma == 0 or not second_order:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
            d_2 = to_d(x_2, new_sigma, denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
        step_id += 1
        return x

    steps = sigmas.shape[0] - 1
    if restart_list is None:
        if steps >= 20:
            restart_steps = 9
            restart_times = 2 if steps >= 36 else 1
            # Shorten the main schedule so that main steps + restart steps == steps.
            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
        else:
            restart_list = {}

    # Re-key by the index of the schedule entry closest to each min_sigma (later keys win on collisions).
    restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}

    for i in trange(len(sigmas) - 1, disable=disable):
        x = heun_step(x, sigmas[i], sigmas[i + 1])
        if i + 1 in restart_list:
            restart_steps, restart_times, restart_max = restart_list[i + 1]
            min_idx = i + 1
            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
            if max_idx < min_idx:
                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]  # remove the zero at the end
                while restart_times > 0:
                    restart_times -= 1
                    # Draw noise through k_diffusion.sampling.torch so that the TorchHijack installed by
                    # KDiffusionSampler.initialize (pre-generated sampler noises, CPU/MPS RNG source) applies
                    # here exactly as it does for the k-diffusion samplers.
                    noise = k_diffusion.sampling.torch.randn_like(x)
                    x = x + noise * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5
                    for old_sigma, new_sigma in zip(sigma_restart[:-1], sigma_restart[1:]):
                        x = heun_step(x, old_sigma, new_sigma)
    return x
2022-09-03 22:21:15 +08:00
# One SamplerData per table entry whose function is available: either it exists in k_diffusion.sampling,
# or it is the restart sampler defined in this module. The constructor lambda binds funcname as a default
# argument so each entry keeps its own function name instead of the loop's last value.
samplers_data_k_diffusion = [
    sd_samplers_common.SamplerData(
        label,
        lambda model, funcname=funcname: KDiffusionSampler(funcname, model),
        aliases,
        options,
    )
    for label, funcname, aliases, options in samplers_k_diffusion
    if funcname == 'restart_sampler' or hasattr(k_diffusion.sampling, funcname)
]
2022-09-26 16:56:47 +08:00
# Per-function list of optional processing attributes forwarded as keyword arguments
# (KDiffusionSampler.initialize passes each one only if p has it and the function accepts it).
sampler_extra_params = {
    sampler_function: ['s_churn', 's_tmin', 's_tmax', 's_noise']
    for sampler_function in ('sample_euler', 'sample_heun', 'sample_dpm_2')
}
2022-09-03 17:08:45 +08:00
2023-05-22 23:26:28 +08:00
# Sampler label -> SamplerData for every available k-diffusion sampler.
k_diffusion_samplers_map = dict((sampler.name, sampler) for sampler in samplers_data_k_diffusion)

# Schedule type option -> sigma schedule function; 'Automatic' (None) means "use the sampler's own default".
k_diffusion_scheduler = dict(
    Automatic=None,
    karras=k_diffusion.sampling.get_sigmas_karras,
    exponential=k_diffusion.sampling.get_sigmas_exponential,
    polyexponential=k_diffusion.sampling.get_sigmas_polyexponential,
)
2022-10-23 01:48:13 +08:00
2022-09-03 17:08:45 +08:00
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        self.mask = None
        self.nmask = None
        self.init_latent = None
        self.step = 0  # incremented once per forward(); selects the scheduled conds and the skip-uncond parity
        self.image_cfg_scale = None
        self.padded_cond_uncond = False  # set by forward() when cond/uncond were padded to equal length

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """Apply CFG: start from each image's uncond prediction and add its weighted cond deltas.

        The last ``uncond.shape[0]`` rows of ``x_out`` are the uncond predictions. ``conds_list[i]`` holds
        ``(row index into x_out, weight)`` pairs for image ``i``. ``x_out`` itself is not modified.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """Combine the three InstructPix2Pix groups (cond, image-cond, uncond) using text and image CFG scales."""
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Denoise ``x`` at noise level ``sigma`` with classifier-free guidance.

        Args:
            x: noisy latents, one per image.
            sigma: per-image noise levels.
            uncond: negative-prompt conditioning schedule, reconstructed for ``self.step``.
            cond: positive (multi-)conditioning schedule, reconstructed for ``self.step``.
            cond_scale: CFG scale.
            s_min_uncond: if > 0, on odd steps with ``sigma[0] < s_min_uncond`` the uncond pass is skipped.
            image_cond: image conditioning, passed as ``c_concat`` (or ``c_adm`` for crossattn-adm models).

        Returns:
            The guided denoised latents.

        Raises:
            sd_samplers_common.InterruptedException: if the job was interrupted or skipped.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
        else:
            image_uncond = image_cond
            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}

        # Row layout: each image repeated once per sub-prompt, then the uncond group
        # (and, for the edit model, a further group with zeroed image conditioning).
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
        else:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])

        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma
        tensor = denoiser_params.text_cond
        uncond = denoiser_params.text_uncond
        skip_uncond = False

        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]
            # Trim the image conditioning too, so its batch matches x_in when it is passed whole to the model.
            image_cond_in = image_cond_in[:-batch_size]

        self.padded_cond_uncond = False
        if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
            empty = shared.sd_model.cond_stage_model_empty_prompt
            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]

            if num_repeats < 0:
                tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
                self.padded_cond_uncond = True
            elif num_repeats > 0:
                uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
                self.padded_cond_uncond = True

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            # cond and uncond can share one conditioning tensor
            if is_edit_model:
                cond_in = torch.cat([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = torch.cat([tensor, uncond])

            if shared.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
        else:
            # cond and uncond have different token lengths, so they are run through the model separately
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size * 2 if shared.batch_cond_uncond else batch_size
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])

                # These rows are the cond group only; its rows line up with tensor[a:b] for both model kinds.
                c_crossattn = [tensor[a:b]]

                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))

            if is_edit_model:
                # The edit model's middle group (uncond text + image conditioning) must be denoised as well,
                # otherwise combine_denoised_for_edit_model would read zeros for it.
                n = uncond.shape[0]
                x_out[-2 * n:-n] = self.inner_model(x_in[-2 * n:-n], sigma_in[-2 * n:-n], cond=make_condition_dict([uncond], image_cond_in[-2 * n:-n]))

            if not skip_uncond:
                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))

        denoised_image_indexes = [conds[0][0] for conds in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
        cfg_denoised_callback(denoised_params)

        devices.test_for_nans(x_out, "unet")

        if opts.live_preview_content == "Prompt":
            sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
        elif opts.live_preview_content == "Negative prompt":
            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised
2022-09-16 14:47:03 +08:00
class TorchHijack:
    """Proxy for the ``torch`` module that serves ``randn_like`` from a queue of pre-generated noises.

    Every other attribute is forwarded to ``torch``. When the queue is empty, or the next queued noise does
    not match the requested shape, fresh noise is drawn. The draw happens on CPU when
    ``opts.randn_source == "CPU"`` or for MPS tensors, otherwise on the tensor's own device.
    """

    def __init__(self, sampler_noises):
        # A deque hands out the noises front-to-back, the same order as the earlier index-based version.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        if item == 'randn_like':
            return self.randn_like

        if not hasattr(torch, item):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

        return getattr(torch, item)

    def randn_like(self, x):
        if self.sampler_noises:
            queued = self.sampler_noises.popleft()
            if queued.shape == x.shape:
                return queued

        draw_on_cpu = opts.randn_source == "CPU" or x.device.type == 'mps'
        if not draw_on_cpu:
            return torch.randn_like(x)

        return torch.randn_like(x, device=devices.cpu).to(x.device)
2022-11-26 10:12:23 +08:00
2022-09-14 02:49:58 +08:00
2022-09-03 17:08:45 +08:00
class KDiffusionSampler:
    """Runs a k-diffusion sampling function on a Stable Diffusion model wrapped in a CFGDenoiser.

    ``funcname`` names a function in ``k_diffusion.sampling``, or is ``"restart_sampler"`` for the
    restart sampler defined in this module.
    """

    def __init__(self, funcname, sd_model):
        # models with parameterization "v" get the V-denoiser wrapper, all others the plain one
        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
        self.funcname = funcname
        self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler
        self.extra_params = sampler_extra_params.get(funcname, [])
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None

        self.conditioning_key = sd_model.model.conditioning_key

    def callback_state(self, d):
        """Per-step callback from the sampling function.

        Reads ``d['i']`` (step) and ``d['denoised']`` (latent). Records the latent for previews and
        interruption, and advances progress. Raises InterruptedException once the step exceeds ``stop_at``.
        """
        step = d['i']
        latent = d["denoised"]
        if opts.live_preview_content == "Combined":
            sd_samplers_common.store_latent(latent)
        self.last_latent = latent

        if self.stop_at is not None and step > self.stop_at:
            raise sd_samplers_common.InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        """Run ``func()`` with progress state reset.

        Returns the latest latent instead of raising on interruption or RecursionError.
        """
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except sd_samplers_common.InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        return p.steps

    def initialize(self, p):
        """Prepare the CFG denoiser and the k-diffusion RNG hijack for ``p``.

        Returns the optional keyword arguments accepted by ``self.func``: the entries of ``extra_params``
        that ``p`` defines, and ``eta`` if the function takes one.
        """
        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else opts.eta_ancestral
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        # k_diffusion.sampling's randn_like now consumes self.sampler_noises first
        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        parameters = inspect.signature(self.func).parameters

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        if 'eta' in parameters:
            if self.eta != 1.0:
                p.extra_generation_params["Eta"] = self.eta

            extra_params_kwargs['eta'] = self.eta

        return extra_params_kwargs

    def get_sigmas(self, p, steps):
        """Build the sigma schedule for ``steps`` steps.

        Sources, in priority order: ``p.sampler_noise_scheduler_override``, the user-selected schedule type
        (``opts.k_sched_type``), the config's Karras scheduler, and finally the model's default schedule.
        Schedule settings that differ from the defaults are recorded in ``p.extra_generation_params``.
        """
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
            discard_next_to_last_sigma = True
            p.extra_generation_params["Discard penultimate sigma"] = True

        # one extra sigma is generated so that dropping the penultimate one keeps the step count
        steps += 1 if discard_next_to_last_sigma else 0

        if p.sampler_noise_scheduler_override:
            sigmas = p.sampler_noise_scheduler_override(steps)
        elif opts.k_sched_type != "Automatic":
            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)
            sigmas_kwargs = {
                'sigma_min': sigma_min,
                'sigma_max': sigma_max,
            }

            sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
            p.extra_generation_params["Schedule type"] = opts.k_sched_type

            if opts.sigma_min != m_sigma_min and opts.sigma_min != 0:
                sigmas_kwargs['sigma_min'] = opts.sigma_min
                p.extra_generation_params["Schedule min sigma"] = opts.sigma_min
            if opts.sigma_max != m_sigma_max and opts.sigma_max != 0:
                sigmas_kwargs['sigma_max'] = opts.sigma_max
                p.extra_generation_params["Schedule max sigma"] = opts.sigma_max

            default_rho = 1. if opts.k_sched_type == "polyexponential" else 7.

            if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
                sigmas_kwargs['rho'] = opts.rho
                p.extra_generation_params["Schedule rho"] = opts.rho

            sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
        elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
            sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())

            sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)
        else:
            sigmas = self.model_wrap.get_sigmas(steps)

        if discard_next_to_last_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        return sigmas

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        # seeds of the images in the current batch
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Denoise the init latent ``x`` over the tail of the schedule selected by img2img strength."""
        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)

        sigmas = self.get_sigmas(p, steps)

        sigma_sched = sigmas[steps - t_enc - 1:]
        xi = x + noise * sigma_sched[0]

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

        if 'sigma_min' in parameters:
            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
            extra_params_kwargs['sigma_min'] = sigma_sched[-2]
        if 'sigma_max' in parameters:
            extra_params_kwargs['sigma_max'] = sigma_sched[0]
        if 'n' in parameters:
            extra_params_kwargs['n'] = len(sigma_sched) - 1
        if 'sigma_sched' in parameters:
            extra_params_kwargs['sigma_sched'] = sigma_sched
        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigma_sched

        # self.config may still be None (see __init__); get_sigmas guards the same way
        if self.config is not None and self.config.options.get('brownian_noise', False):
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.model_wrap_cfg.init_latent = x
        self.last_latent = x
        extra_args = {
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }

        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        """Sample from noise ``x`` (scaled here by the first sigma) over the full schedule."""
        steps = steps or p.steps

        sigmas = self.get_sigmas(p, steps)

        x = x * sigmas[0]

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

        if 'sigma_min' in parameters:
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            if 'n' in parameters:
                extra_params_kwargs['n'] = steps
        else:
            extra_params_kwargs['sigmas'] = sigmas

        # self.config may still be None (see __init__); get_sigmas guards the same way
        if self.config is not None and self.config.options.get('brownian_noise', False):
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.last_latent = x
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples
2022-09-03 17:08:45 +08:00