2022-09-03 17:08:45 +08:00
import torch
2023-08-08 23:35:31 +08:00
from modules import prompt_parser , devices , sd_samplers_common
2022-09-03 17:08:45 +08:00
2023-01-30 14:51:06 +08:00
from modules . shared import opts , state
2022-09-03 17:08:45 +08:00
import modules . shared as shared
2022-11-02 08:38:17 +08:00
from modules . script_callbacks import CFGDenoiserParams , cfg_denoiser_callback
2023-02-11 10:18:38 +08:00
from modules . script_callbacks import CFGDenoisedParams , cfg_denoised_callback
2023-05-14 09:49:41 +08:00
from modules . script_callbacks import AfterCFGCallbackParams , cfg_after_cfg_callback
2022-09-03 17:08:45 +08:00
2022-10-23 01:48:13 +08:00
2023-07-12 02:16:43 +08:00
def catenate_conds(conds):
    """Concatenate a list of conditionings along the first (batch) dimension.

    Plain tensors are joined with ``torch.cat``; dict conditionings are joined key by key,
    using the keys (and key order) of the first element.
    """
    first = conds[0]
    if isinstance(first, dict):
        return {key: torch.cat([cond[key] for cond in conds]) for key in first}
    return torch.cat(conds)


def subscript_cond(cond, a, b):
    """Take rows ``a:b`` of a conditioning.

    For a dict conditioning every value is sliced and a new plain dict is returned;
    anything else is sliced directly.
    """
    rows = slice(a, b)
    if isinstance(cond, dict):
        return {name: value[rows] for name, value in cond.items()}
    return cond[rows]


def pad_cond(tensor, repeats, empty):
    """Append ``repeats`` copies of ``empty`` to a conditioning along dimension 1.

    ``empty`` is tiled once per batch row. For a dict conditioning only the
    ``'crossattn'`` entry is padded; the dict is updated in place and returned.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    padding = empty.repeat((tensor.shape[0], repeats, 1))
    return torch.cat([tensor, padding], dim=1)


2022-09-03 17:08:45 +08:00
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.

    Subclasses must implement the ``inner_model`` property; ``get_pred_x0`` may be overridden
    to turn a model output into the latent used for live previews.
    """

    def __init__(self, sampler):
        super().__init__()
        self.model_wrap = None

        # inpainting support: when ``mask`` is set, forward() blends as
        # init_latent * mask + nmask * <latent>
        self.mask = None
        self.nmask = None
        self.init_latent = None

        # number of steps as specified by user in UI
        self.steps = None
        # expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler
        self.total_steps = None

        # number of forward() calls so far; selects the scheduled prompt for the current step
        self.step = 0
        # InstructPix2Pix image guidance scale; None or 1.0 disables the edit-model path
        self.image_cfg_scale = None
        # set by forward() when cond and uncond had to be padded to the same token length
        self.padded_cond_uncond = False
        self.sampler = sampler
        # processing object; update_inner_model() reloads conds from it via get_conds()
        self.p = None
        # apply the inpainting mask to the input latent (True) or to the denoised output (False)
        self.mask_before_denoising = False

    @property
    def inner_model(self):
        """The wrapped model, called as ``inner_model(x, sigma, cond=...)``; provided by subclasses."""
        raise NotImplementedError()

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """
        Apply classifier-free guidance to a batch of model outputs.

        The last ``uncond.shape[0]`` rows of ``x_out`` are the unconditional predictions, one per
        image. For image ``i``, every ``(cond_index, weight)`` pair in ``conds_list[i]`` adds
        ``weight * cond_scale`` times the difference between row ``cond_index`` of ``x_out`` and
        that image's unconditional prediction.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """
        Guidance for the InstructPix2Pix edit model.

        ``x_out`` is split into three equal chunks, in the order forward() builds them:
        text+image conditioned, image-only conditioned (text uncond) and fully unconditioned
        predictions. The text direction is scaled by ``cond_scale``, the image direction by
        ``self.image_cfg_scale``.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return the latent used as the predicted clean image; the base class returns ``x_out`` as-is."""
        return x_out

    def update_inner_model(self):
        """Drop the cached model wrapper and reload cond/uncond from ``self.p`` into the sampler's extra args."""
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """
        Run one classifier-free-guided denoising step.

        Builds one batch holding, for every image, one copy of its row of ``x`` per sub-prompt of
        its positive prompt (``AND`` composition), followed by the unconditional rows. For the
        InstructPix2Pix edit model, a second unconditional block with empty image conditioning
        follows. The batch is run through ``inner_model`` and the predictions are combined with
        guidance.

        Args:
            x: batch of noisy latents, one row per image.
            sigma: noise level per image, indexed like ``x``.
            uncond: scheduled negative-prompt conditioning, resolved for ``self.step``.
            cond: scheduled positive-prompt multi-conditioning, resolved for ``self.step``.
            cond_scale: classifier-free guidance scale.
            s_min_uncond: when > 0, on odd steps with ``sigma[0]`` below this value the
                unconditional pass is skipped and the cond prediction is used in its place.
            image_cond: per-image conditioning passed to the model as ``c_concat``
                (or ``c_adm`` for "crossattn-adm" models).

        Returns:
            The guided prediction, one row per image.

        Raises:
            sd_samplers_common.InterruptedException: if the job was interrupted or skipped.
        """
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        if sd_samplers_common.apply_refiner(self):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
        is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0

        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        if self.mask_before_denoising and self.mask is not None:
            x = self.init_latent * self.mask + self.nmask * x

        batch_size = len(conds_list)
        # number of positive sub-prompts (AND-parts) for each image
        repeats = [len(conds) for conds in conds_list]

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
        else:
            image_uncond = image_cond
            if isinstance(uncond, dict):
                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
            else:
                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

        # layout: one copy of each image per sub-prompt, then the unconditional rows
        # (the edit model appends a second unconditional block with empty image conditioning)
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
        else:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])

        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond)
        cfg_denoiser_callback(denoiser_params)
        x_in = denoiser_params.x
        image_cond_in = denoiser_params.image_cond
        sigma_in = denoiser_params.sigma
        tensor = denoiser_params.text_cond
        uncond = denoiser_params.text_uncond
        skip_uncond = False

        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]
            # keep the image conditioning row-aligned with x_in; otherwise the single batched
            # model call below receives more c_concat/c_adm rows than latent rows
            image_cond_in = image_cond_in[:-batch_size]

        self.padded_cond_uncond = False
        if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
            empty = shared.sd_model.cond_stage_model_empty_prompt
            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]

            if num_repeats < 0:
                tensor = pad_cond(tensor, -num_repeats, empty)
                self.padded_cond_uncond = True
            elif num_repeats > 0:
                uncond = pad_cond(uncond, num_repeats, empty)
                self.padded_cond_uncond = True

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            if is_edit_model:
                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = catenate_conds([tensor, uncond])

            if shared.batch_cond_uncond:
                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            # cond and uncond have different token lengths, so they cannot share one model call.
            # skip_uncond is always False here (it routes to the branch above).
            x_out = torch.zeros_like(x_in)
            chunk_size = batch_size * 2 if shared.batch_cond_uncond else batch_size
            for a in range(0, tensor.shape[0], chunk_size):
                b = min(a + chunk_size, tensor.shape[0])
                c_crossattn = subscript_cond(tensor, a, b)
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))

            # the edit model has two trailing unconditional blocks (with and without the image
            # conditioning), both using the text uncond; other models have a single block
            uncond_in = catenate_conds([uncond, uncond]) if is_edit_model else uncond
            n_uncond = uncond.shape[0] * (2 if is_edit_model else 1)
            x_out[-n_uncond:] = self.inner_model(x_in[-n_uncond:], sigma_in[-n_uncond:], cond=make_condition_dict(uncond_in, image_cond_in[-n_uncond:]))

        # row of x_out holding the first sub-prompt's prediction for each image
        denoised_image_indexes = [conds[0][0] for conds in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i + 1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
        cfg_denoised_callback(denoised_params)

        devices.test_for_nans(x_out, "unet")

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        if not self.mask_before_denoising and self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

        if opts.live_preview_content == "Prompt":
            preview = self.sampler.last_latent
        elif opts.live_preview_content == "Negative prompt":
            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
        else:
            preview = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([denoised[i:i + 1] for i in denoised_image_indexes]), sigma)

        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x

        self.step += 1
        return denoised