2022-09-20 06:13:12 +08:00
import os
2023-01-22 15:17:12 +08:00
import re
2023-01-11 14:10:07 +08:00
import shutil
2022-09-20 06:13:12 +08:00
2022-09-11 16:31:16 +08:00
2022-09-26 07:22:12 +08:00
import torch
2022-09-27 15:44:00 +08:00
import tqdm
2022-09-26 07:22:12 +08:00
2023-01-28 03:43:08 +08:00
from modules import shared , images , sd_models , sd_vae , sd_models_config
2023-01-23 19:50:20 +08:00
from modules . ui_common import plaintext_to_html
2022-09-29 05:59:44 +08:00
import gradio as gr
2022-11-27 20:51:29 +08:00
import safetensors . torch
2022-09-14 00:23:55 +08:00
2022-09-11 16:31:16 +08:00
2022-09-17 14:07:07 +08:00
def run_pnginfo(image):
    """Read the text metadata embedded in ``image`` and render it as HTML.

    Returns a 3-tuple ``('', geninfo, info_html)``:

    * ``geninfo`` is the generation-info value returned by
      ``images.read_info_from_image`` (passed through unchanged);
    * ``info_html`` contains one ``<div>`` per metadata entry whose value is not
      None, or a "Nothing found in the image." message when there are none.

    A ``None`` image short-circuits to ``('', '', '')``.
    """
    if image is None:
        return '', '', ''

    geninfo, items = images.read_info_from_image(image)
    # 'parameters' is inserted first so it is listed first; a 'parameters' entry
    # already present in ``items`` overrides the value but keeps that position.
    items = {**{'parameters': geninfo}, **items}

    # Entries whose value is None (e.g. 'parameters' when no generation info was
    # found) are skipped. Previously they were rendered as the text "None", which
    # also made the "Nothing found" fallback below unreachable.
    info = ''.join(
        f"<div>\n<p><b>{plaintext_to_html(str(key))}</b></p>\n<p>{plaintext_to_html(str(text))}</p>\n</div>\n"
        for key, text in items.items()
        if text is not None
    )

    if not info:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"

    return '', geninfo, info
2022-09-26 07:22:12 +08:00
2023-01-11 14:10:07 +08:00
def create_config(ckpt_result, config_source, a, b, c):
    """Copy a model config next to a freshly merged checkpoint.

    ``a``, ``b`` and ``c`` are the primary/secondary/tertiary checkpoint entries
    (or None) that were merged. ``config_source`` selects where the config comes from:

    * 0 -- the first of A, B, C that has a config differing from the default;
    * 1 -- only B;
    * 2 -- only C;
    * anything else -- no config is copied.

    When a config is found, it is copied to ``<ckpt_result without extension>.yaml``.
    Returns None.
    """
    def _config_for(info):
        # Missing entries never have a config.
        if not info:
            return None
        found = sd_models_config.find_checkpoint_config_near_filename(info)
        # The default config is never copied; only model-specific ones are.
        return None if found == shared.sd_default_config else found

    candidates = {0: (a, b, c), 1: (b,), 2: (c,)}.get(config_source, ())

    # Equivalent to `_config_for(a) or _config_for(b) or _config_for(c)`:
    # stop at the first truthy result, otherwise keep the last one evaluated.
    cfg = None
    for info in candidates:
        cfg = _config_for(info)
        if cfg:
            break

    if cfg is None:
        return

    checkpoint_filename = os.path.splitext(ckpt_result)[0] + ".yaml"
    print("Copying config:")
    print("   from:", cfg)
    print("     to:", checkpoint_filename)
    shutil.copyfile(cfg, checkpoint_filename)
2023-01-19 23:24:17 +08:00
# State-dict keys excluded from interpolation by the model merger: neither the
# B - C pass nor the A/B merge touches them, so A's value is kept.
checkpoint_dict_skip_on_merge = [
    "cond_stage_model.transformer.text_model.embeddings.position_ids",
]
2023-01-19 15:39:51 +08:00
2023-01-19 17:12:09 +08:00
def to_half(tensor, enable):
    """Return ``tensor`` converted to float16 if ``enable`` is truthy and it is float32.

    Tensors of any other dtype, or any tensor when ``enable`` is falsy, are
    returned as the very same object.
    """
    if not enable or tensor.dtype != torch.float:
        return tensor
    return tensor.half()
2023-01-22 15:17:12 +08:00
def _apply_difference(theta_1, theta_2, theta_func1):
    """Replace, in place, every 'model' tensor of B (``theta_1``) with ``theta_func1(B, C)``.

    Keys listed in ``checkpoint_dict_skip_on_merge`` are left alone. A key that C
    (``theta_2``) lacks has its B tensor replaced by zeros. Advances
    ``shared.state.sampling_step`` once per key of B.
    """
    for key in tqdm.tqdm(theta_1.keys()):
        if key not in checkpoint_dict_skip_on_merge and 'model' in key:
            if key in theta_2:
                # Index directly: the previous `.get(key, torch.zeros_like(...))`
                # allocated an unused zero tensor for every key.
                theta_1[key] = theta_func1(theta_1[key], theta_2[key])
            else:
                theta_1[key] = torch.zeros_like(theta_1[key])
        shared.state.sampling_step += 1


def _merge_into_primary(theta_0, theta_1, theta_func2, multiplier, save_as_half):
    """Merge B (``theta_1``) into A (``theta_0``) in place and apply the float16 option.

    Returns ``(result_is_inpainting_model, result_is_instruct_pix2pix_model)``.
    Every tensor of A goes through ``to_half``, not only the merged ones, so
    "save as half" covers keys that B lacks as well. Advances
    ``shared.state.sampling_step`` once per key of A.
    """
    result_is_inpainting_model = False
    result_is_instruct_pix2pix_model = False

    for key in tqdm.tqdm(theta_0.keys()):
        if theta_1 and 'model' in key and key in theta_1 and key not in checkpoint_dict_skip_on_merge:
            a = theta_0[key]
            b = theta_1[key]

            # This enables merging an inpainting model (A) with another one (B):
            # where a normal model has 4 latent-space channels, an inpainting model has
            # another 4 for the unmasked picture's latent space plus one mask channel
            # (9 in total); an instruct-pix2pix model has 8. Only the shared first
            # 4 channels are merged.
            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
                if a.shape[1] == 4 and b.shape[1] == 8:
                    raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.")

                if a.shape[1] == 8 and b.shape[1] == 4:
                    # Instruct-pix2pix A: merge only the channels both models have,
                    # otherwise the shapes mismatch.
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_instruct_pix2pix_model = True
                else:
                    assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)

        theta_0[key] = to_half(theta_0[key], save_as_half)
        shared.state.sampling_step += 1

    return result_is_inpainting_model, result_is_instruct_pix2pix_model


def _bake_in_vae(theta_0, bake_in_vae, save_as_half):
    """Overwrite A's 'first_stage_model.*' tensors with those of the selected VAE, if any.

    ``bake_in_vae`` is looked up in ``sd_vae.vae_dict``; an unknown name means no
    VAE is baked in. Only keys already present in ``theta_0`` are replaced.
    """
    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is None:
        return

    print(f"Baking in VAE from {bake_in_vae_filename}")
    shared.state.textinfo = 'Baking in VAE'
    vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

    for key, value in vae_dict.items():
        theta_0_key = 'first_stage_model.' + key
        if theta_0_key in theta_0:
            theta_0[theta_0_key] = to_half(value, save_as_half)


def _drop_matching_weights(theta_0, discard_weights):
    """Remove from ``theta_0`` every key matched (``re.search``) by the regex ``discard_weights``, if given."""
    if not discard_weights:
        return
    regex = re.compile(discard_weights)
    for key in list(theta_0):
        if regex.search(key):
            del theta_0[key]


def _save_checkpoint(theta_0, output_modelname):
    """Write ``theta_0`` as safetensors when the path ends in .safetensors (any case), else with ``torch.save``."""
    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
    else:
        torch.save(theta_0, output_modelname)


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
    """Merge up to three checkpoints (A, B, C) and save the result into the checkpoint directory.

    Args:
        id_task: accepted for the caller's calling convention; not used here.
        primary_model_name, secondary_model_name, tertiary_model_name: keys into
            ``sd_models.checkpoints_list`` for A, B and C.
        interp_method: "Weighted sum" ((1 - m) * A + m * B),
            "Add difference" (A + m * (B - C)) or "No interpolation" (A only).
        multiplier: the interpolation factor ``m``.
        save_as_half: when truthy, float32 tensors are stored as float16.
        custom_name: output file name without extension; '' means auto-generated.
        checkpoint_format: file extension without the dot; "safetensors"
            (case-insensitive) is written with safetensors, anything else with torch.save.
        config_source: forwarded to ``create_config``.
        bake_in_vae: name looked up in ``sd_vae.vae_dict``; unknown names bake nothing.
        discard_weights: regex; matching keys are removed before saving ('' disables).

    Returns:
        A list of four gradio updates followed by a status message. On validation
        failure the updates are no-ops and the message explains the failure.

    Raises:
        Whatever loading, merging or saving raises (e.g. RuntimeError when models
        are given in the wrong order). The shared job state is ended before the
        exception propagates.
    """
    shared.state.begin()
    shared.state.job = 'model-merge'

    def fail(message):
        shared.state.textinfo = message
        shared.state.end()
        return [*[gr.update() for _ in range(4)], message]

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    def get_difference(theta1, theta2):
        return theta1 - theta2

    def add_difference(theta0, theta1_2_diff, alpha):
        return theta0 + (alpha * theta1_2_diff)

    # The filename generators read the *_model_info locals assigned below;
    # they are only called after those are set.
    def filename_weighted_sum():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        Ma = round(1 - multiplier, 2)
        Mb = round(multiplier, 2)

        return f"{Ma}({a}) + {Mb}({b})"

    def filename_add_difference():
        a = primary_model_info.model_name
        b = secondary_model_info.model_name
        c = tertiary_model_info.model_name
        M = round(multiplier, 2)

        return f"{a} + {M}({b} - {c})"

    def filename_nothing():
        return primary_model_info.model_name

    # interp_method -> (filename generator, B/C pre-pass, A/B merge)
    theta_funcs = {
        "Weighted sum": (filename_weighted_sum, None, weighted_sum),
        "Add difference": (filename_add_difference, get_difference, add_difference),
        "No interpolation": (filename_nothing, None, None),
    }

    try:
        filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
        shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)

        if not primary_model_name:
            return fail("Failed: Merging requires a primary model.")

        primary_model_info = sd_models.checkpoints_list[primary_model_name]

        if theta_func2 and not secondary_model_name:
            return fail("Failed: Merging requires a secondary model.")

        secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None

        if theta_func1 and not tertiary_model_name:
            return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")

        tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None

        if theta_func2:
            shared.state.textinfo = "Loading B"
            print(f"Loading {secondary_model_info.filename}...")
            theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
        else:
            theta_1 = None

        if theta_func1:
            shared.state.textinfo = "Loading C"
            print(f"Loading {tertiary_model_info.filename}...")
            theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

            shared.state.textinfo = 'Merging B and C'
            shared.state.sampling_steps = len(theta_1.keys())
            _apply_difference(theta_1, theta_2, theta_func1)
            del theta_2  # free C before loading A

            shared.state.nextjob()

        shared.state.textinfo = f"Loading {primary_model_info.filename}..."
        print(f"Loading {primary_model_info.filename}...")
        theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')

        print("Merging...")
        shared.state.textinfo = 'Merging A and B'
        shared.state.sampling_steps = len(theta_0.keys())
        result_is_inpainting_model, result_is_instruct_pix2pix_model = _merge_into_primary(
            theta_0, theta_1, theta_func2, multiplier, save_as_half
        )
        del theta_1  # free B before baking / saving

        _bake_in_vae(theta_0, bake_in_vae, save_as_half)
        _drop_matching_weights(theta_0, discard_weights)

        ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

        filename = filename_generator() if custom_name == '' else custom_name
        filename += ".inpainting" if result_is_inpainting_model else ""
        filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else ""
        filename += "." + checkpoint_format

        output_modelname = os.path.join(ckpt_dir, filename)

        shared.state.nextjob()

        shared.state.textinfo = "Saving"
        print(f"Saving to {output_modelname}...")
        _save_checkpoint(theta_0, output_modelname)

        sd_models.list_models()

        create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)

        print(f"Checkpoint saved to {output_modelname}.")
        shared.state.textinfo = "Checkpoint saved"
        result = [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
    except Exception:
        # Don't leave the shared job state stuck in "model-merge" when anything fails.
        shared.state.end()
        raise

    shared.state.end()
    return result