2022-09-03 17:08:45 +08:00
import torch
2022-10-03 05:31:19 +08:00
from torch . nn . functional import silu
2023-01-12 22:03:46 +08:00
from types import MethodType
2022-09-03 17:08:45 +08:00
2022-10-02 20:03:39 +08:00
import modules . textual_inversion . textual_inversion
2023-05-10 13:43:42 +08:00
from modules import devices , sd_hijack_optimizations , shared
2022-11-26 21:45:57 +08:00
from modules . hypernetworks import hypernetwork
2022-12-10 14:17:39 +08:00
from modules . shared import cmd_opts
2022-12-31 23:06:35 +08:00
from modules import sd_hijack_clip , sd_hijack_open_clip , sd_hijack_unet , sd_hijack_xlmr , xlmr
2022-11-26 21:10:46 +08:00
2022-09-05 06:41:20 +08:00
import ldm . modules . attention
2022-09-13 19:29:56 +08:00
import ldm . modules . diffusionmodules . model
2022-12-02 20:47:02 +08:00
import ldm . modules . diffusionmodules . openaimodel
2022-11-11 23:20:18 +08:00
import ldm . models . diffusion . ddim
import ldm . models . diffusion . plms
2022-11-26 21:10:46 +08:00
import ldm . modules . encoders . modules
2022-09-13 19:29:56 +08:00
2022-10-02 20:03:39 +08:00
# Original ldm implementations captured at import time, before apply_optimizations()
# replaces them. undo_optimizations() restores nonlinearity and AttnBlock.forward from
# these; for CrossAttention.forward it uses hypernetwork.attention_CrossAttention_forward
# instead, so attention_CrossAttention_forward is not referenced elsewhere in this file.
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
2022-09-13 19:29:56 +08:00
2022-11-26 21:10:46 +08:00
# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm . modules . attention . MemoryEfficientCrossAttention = ldm . modules . attention . CrossAttention
ldm . modules . attention . BasicTransformerBlock . ATTENTION_MODES [ " softmax-xformers " ] = ldm . modules . attention . CrossAttention
# silence new console spam from SD2
ldm . modules . attention . print = lambda * args : None
ldm . modules . diffusionmodules . model . print = lambda * args : None
2022-10-15 21:59:37 +08:00
2022-12-10 14:14:30 +08:00
2022-10-02 20:03:39 +08:00
def apply_optimizations():
    """Select and install a cross-attention optimization based on ``cmd_opts``.

    Any previously applied optimization is undone first. Then ldm's diffusion
    model ``nonlinearity`` is replaced with ``silu``, ``openaimodel.th`` is
    pointed at ``sd_hijack_unet.th``, and the first matching branch below
    patches ``CrossAttention.forward`` (and, for some branches,
    ``AttnBlock.forward``).

    Returns:
        The name of the applied optimization ('xformers', 'sdp-no-mem', 'sdp',
        'sub-quadratic', 'V1', 'InvokeAI' or 'Doggettx'), or None when no
        branch matched.
    """
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    optimization_method = None

    # not everyone has torch 2.x to use sdp
    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    if (cmd_opts.opt_sdp_attention or cmd_opts.opt_sdp_no_mem_attention) and not can_use_sdp:
        # Previously an explicit SDP request was ignored silently and another branch was used.
        print("Scaled dot product attention was requested, but this version of torch does not provide it; ignoring the option.")

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and _xformers_device_supported()):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        optimization_method = 'xformers'
    elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
        optimization_method = 'sdp-no-mem'
    elif cmd_opts.opt_sdp_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
        optimization_method = 'sdp'
    elif cmd_opts.opt_sub_quad_attention:
        print("Applying sub-quadratic cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
        optimization_method = 'sub-quadratic'
    elif cmd_opts.opt_split_attention_v1:
        # V1 only replaces CrossAttention.forward; AttnBlock keeps its original forward.
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        optimization_method = 'V1'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
        # Default when CUDA is unavailable (unless Doggettx was explicitly requested).
        print("Applying cross attention optimization (InvokeAI).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
        optimization_method = 'InvokeAI'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        optimization_method = 'Doggettx'

    return optimization_method


def _xformers_device_supported():
    """Return True if torch is a CUDA build and ``shared.device`` has compute capability 6.0-9.0 inclusive."""
    if not torch.version.cuda:
        return False
    try:
        capability = torch.cuda.get_device_capability(shared.device)
    except (AssertionError, RuntimeError, ValueError):
        # A CUDA build of torch may still be running on a non-CUDA device or without a
        # usable GPU; get_device_capability raises then instead of returning a value.
        return False
    return (6, 0) <= capability <= (9, 0)
2022-09-13 19:29:56 +08:00
2022-10-02 20:03:39 +08:00
def undo_optimizations():
    """Put back the ldm implementations that apply_optimizations() may have replaced.

    CrossAttention.forward is restored from the copy saved by the hypernetwork
    module; nonlinearity and AttnBlock.forward come from this module's own
    import-time snapshots.
    """
    attention_module = ldm.modules.attention
    model_module = ldm.modules.diffusionmodules.model

    attention_module.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    model_module.nonlinearity = diffusionmodules_model_nonlinearity
    model_module.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
2022-09-13 19:29:56 +08:00
2022-09-03 17:08:45 +08:00
2023-01-20 01:39:03 +08:00
def fix_checkpoint():
    """Do nothing; gradient checkpoints are no longer managed here.

    Checkpoints are now added and removed by the embedding/hypernetwork training
    code, because torch warns when checkpoints are present outside of training.
    The function presumably remains so existing callers keep working.
    """
    return None
2023-01-12 22:03:46 +08:00
def weighted_loss(sd_model, pred, target, mean=True):
    """Compute ``sd_model``'s loss, scaled by optional per-element weights.

    Delegates to the model's original ``get_loss`` (saved as ``_old_get_loss`` by
    weighted_forward) with ``mean=False`` so the weights are applied before any
    reduction.

    Args:
        sd_model: object carrying ``_old_get_loss`` and, optionally, ``_custom_loss_weight``.
        pred, target: forwarded unchanged to ``_old_get_loss``.
        mean: if True return the mean of the weighted loss, otherwise the unreduced loss.
    """
    loss = sd_model._old_get_loss(pred, target, mean=False)

    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        # Out-of-place on purpose: an in-place `*=` bumps the version counter of a tensor
        # autograd may have saved for backward (e.g. the output of exp), which makes
        # backward() fail, and it cannot broadcast the loss up to a larger weight shape.
        loss = loss * weight

    return loss.mean() if mean else loss


def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    """Run ``sd_model.forward(x, c, ...)`` with its loss weighted by ``w``.

    ``get_loss`` is temporarily replaced with weighted_loss so ``forward`` does not
    need to be reimplemented. All temporary state is removed afterwards, even if
    ``forward`` raises.

    Args:
        sd_model: model whose ``forward`` calls ``self.get_loss``.
        x, c, *args, **kwargs: forwarded unchanged to ``sd_model.forward``.
        w: weights multiplied into the unreduced loss; None disables weighting.

    Returns:
        Whatever ``sd_model.forward`` returns.
    """
    # Only the outermost call saves get_loss; also remember whether get_loss was already
    # an instance attribute so it can be restored exactly rather than approximately.
    installs_patch = not hasattr(sd_model, '_old_get_loss')
    had_instance_get_loss = 'get_loss' in vars(sd_model)
    try:
        # Temporarily stash the weights where weighted_loss can read them.
        sd_model._custom_loss_weight = w
        if installs_patch:
            sd_model._old_get_loss = sd_model.get_loss
        sd_model.get_loss = MethodType(weighted_loss, sd_model)
        return sd_model.forward(x, c, *args, **kwargs)
    finally:
        try:
            del sd_model._custom_loss_weight
        except AttributeError:
            pass
        if hasattr(sd_model, '_old_get_loss'):
            old_get_loss = sd_model._old_get_loss
            del sd_model._old_get_loss
            if installs_patch and not had_instance_get_loss:
                # get_loss was the class method: drop the instance override so the class
                # method shows through again, instead of storing a bound method (which
                # references sd_model) on sd_model itself.
                try:
                    del sd_model.get_loss
                except AttributeError:
                    pass
            else:
                sd_model.get_loss = old_get_loss


def apply_weighted_forward(sd_model):
    """Attach weighted_forward to ``sd_model`` as a bound method named ``weighted_forward``."""
    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)


def undo_weighted_forward(sd_model):
    """Remove the ``weighted_forward`` attribute added by apply_weighted_forward, if present."""
    try:
        del sd_model.weighted_forward
    except AttributeError:
        pass
2022-09-03 17:08:45 +08:00
class StableDiffusionModelHijack:
    """Installs and removes this module's patches on a loaded Stable Diffusion model.

    ``hijack`` wraps the model's text encoder (``cond_stage_model``) and its token
    embedding, attaches weighted_forward and applies the attention optimization;
    ``undo_hijack`` reverses those changes.
    """

    # Per-batch embedding fixes; EmbeddingsWithFixes.forward reads this and resets it to None.
    fixes = None
    # Reset to an empty list by clear_comments().
    comments = []
    # Flattened list of every submodule of the hijacked model (set by hijack()).
    layers = None
    # Whether the Conv2d layers in `layers` currently use circular padding.
    circular_enabled = False
    # The wrapped cond_stage_model of the hijacked model.
    clip = None
    # Name returned by apply_optimizations(), or None.
    optimization_method = None

    # Textual-inversion embedding database, created once when the class is defined.
    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def __init__(self):
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def hijack(self, m):
        """Wrap ``m``'s text encoder and install this module's patches on ``m``.

        For a recognised ``cond_stage_model`` type, its token embedding is wrapped in
        EmbeddingsWithFixes and the encoder in the matching *WithCustomWords wrapper.
        Then weighted_forward is attached, the ddpm-edit hijack is applied when
        ``m.cond_stage_key == "edit"``, the attention optimization is applied, and
        every submodule of ``m`` is recorded in ``self.layers``.
        """
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_weighted_forward(m)
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        self.optimization_method = apply_optimizations()

        self.clip = m.cond_stage_model

        def flatten(el):
            # Depth-first list of el and all of its descendants.
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        """Reverse hijack(): unwrap the text encoder and remove the other patches."""
        # After hijack(), an XLMR encoder is the FrozenXLMREmbedderWithCustomWords wrapper,
        # so test for that type. (Testing for the unwrapped BertSeriesModelWithTransformation
        # never matched after hijack and left the XLMR encoder wrapped.) The wrapper
        # presumably keeps the original encoder in `.wrapped`, like the CLIP wrapper below.
        if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        undo_optimizations()
        undo_weighted_forward(m)

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        """Switch every plain Conv2d in ``self.layers`` to circular (or back to zero) padding."""
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []

    def get_prompt_lengths(self, text):
        """Return ``(token_count, target_token_count)`` for ``text`` as computed by ``self.clip``."""
        _, token_count = self.clip.process_texts([text])

        return token_count, self.clip.get_target_prompt_token_count(token_count)
2022-09-03 17:08:45 +08:00
class EmbeddingsWithFixes(torch.nn.Module):
    """Wrap a token-embedding module and splice custom embedding vectors into its output.

    ``embeddings`` is the object whose ``fixes`` attribute holds, per batch item, a
    list of ``(offset, embedding)`` pairs. The fixes are consumed (reset to None) on
    every forward call.
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        # Fixes apply to exactly one forward call.
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        # Nothing to splice: no fixes at all, or every batch item's list is empty.
        if batch_fixes is None or not any(batch_fixes):
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = devices.cond_cast_unet(embedding.vec)
                # Overwrite the rows after `offset` with the embedding's vectors, truncated
                # at the end of the sequence. Clamp at 0: a negative length would slice from
                # the end of `emb` and change the row count instead of being a no-op.
                emb_len = max(0, min(tensor.shape[0] - offset - 1, emb.shape[0]))
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
2022-09-03 17:08:45 +08:00
2022-09-05 07:16:36 +08:00
def add_circular_option_to_conv_2d():
    """Monkeypatch torch.nn.Conv2d so every newly constructed Conv2d uses circular padding.

    The replacement constructor forces ``padding_mode='circular'``, overriding a
    ``padding_mode`` keyword given by the caller. Applying the patch repeatedly is safe.
    """
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        # Set the keyword rather than passing it next to **kwargs: the latter raised
        # "got multiple values for keyword argument 'padding_mode'" whenever a caller,
        # or an earlier application of this same patch, supplied padding_mode.
        kwargs['padding_mode'] = 'circular'
        return conv2d_constructor(self, *args, **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
2022-09-05 06:41:20 +08:00
2022-09-03 17:08:45 +08:00
# Module-level hijack instance; constructing it registers cmd_opts.embeddings_dir with
# the class-level embedding database (see StableDiffusionModelHijack.__init__).
model_hijack = StableDiffusionModelHijack()
2022-11-11 23:20:18 +08:00
def register_buffer(self, name, attr):
    """Store ``attr`` on the sampler as ``name``, moving plain tensors to ``devices.device``.

    Replacement for the DDIM/PLMS samplers' register_buffer (fixes a macOS bug).
    Only objects whose type is exactly torch.Tensor are moved, and only when they
    live on another device; on an 'mps' device they are also converted to float32
    (presumably because MPS lacks float64 support).
    """
    if type(attr) == torch.Tensor and attr.device != devices.device:
        target_dtype = torch.float32 if devices.device.type == 'mps' else None
        attr = attr.to(device=devices.device, dtype=target_dtype)

    setattr(self, name, attr)
# Route the DDIM and PLMS samplers' register_buffer through the device-moving
# register_buffer defined in this module.
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer