import math
import os

import numpy as np
from PIL import Image

import torch
import tqdm

from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
import modules.codeformer_model

import piexif
import piexif.helper

import gradio as gr

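# Cache of recently upscaled results, keyed by the upscale settings plus a small pixel
# fingerprint of the input; run_extras() trims it to at most two entries.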
cached_images = {}


def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
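    """Postprocess a single input image or, when extras_mode == 1, every file in image_folder:
    optionally restore faces with GFPGAN and/or CodeFormer, optionally upscale, save each result
    to the extras output directory, and return (result images, HTML info text, '').
    """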
    devices.torch_gc()

    imageArr = []
    # Also keep track of original file names
    imageNameArr = []

    if extras_mode == 1:
        # convert each file to a pillow image
        for img in image_folder:
            image = Image.open(img)
            imageArr.append(image)
            imageNameArr.append(os.path.splitext(img.orig_name)[0])
    else:
        imageArr.append(image)
        imageNameArr.append(None)

    outpath = opts.outdir_samples or opts.outdir_extras_samples

    outputs = []
    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''

        existing_pnginfo = image.info or {}
        image = image.convert("RGB")
        info = ""

        if gfpgan_visibility > 0:
            restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
            res = Image.fromarray(restored_img)

            if gfpgan_visibility < 1.0:
                res = Image.blend(image, res, gfpgan_visibility)

            info += f"GFPGAN visibility: {round(gfpgan_visibility, 2)}\n"
            image = res

        if codeformer_visibility > 0:
            restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
            res = Image.fromarray(restored_img)

            if codeformer_visibility < 1.0:
                res = Image.blend(image, res, codeformer_visibility)

            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility: {round(codeformer_visibility, 2)}\n"
            image = res

        if resize_mode == 1:
            upscaling_resize = max(upscaling_resize_w / image.width, upscaling_resize_h / image.height)
            crop_info = " (crop)" if upscaling_crop else ""
            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"

        if upscaling_resize != 1.0:
            def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
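                # Cheap cache key: the requested settings plus the raw pixels of a 10x10
                # patch at the centre of the image stand in for hashing the whole image.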
                small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
                pixels = tuple(np.array(small).flatten().tolist())
                key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels

                c = cached_images.get(key)
                if c is None:
                    upscaler = shared.sd_upscalers[scaler_index]
                    c = upscaler.scaler.upscale(image, resize, upscaler.data_path)

                    if mode == 1 and crop:
                        cropped = Image.new("RGB", (resize_w, resize_h))
                        cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
                        c = cropped

                    cached_images[key] = c

                return c

            info += f"Upscale: {round(upscaling_resize, 3)}, model: {shared.sd_upscalers[extras_upscaler_1].name}\n"
            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)

            if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model: {shared.sd_upscalers[extras_upscaler_2].name}\n"
                res = Image.blend(res, res2, extras_upscaler_2_visibility)

            image = res

        while len(cached_images) > 2:
            del cached_images[next(iter(cached_images.keys()))]

        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
                          forced_filename=image_name if opts.use_original_name_batch else None)
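
        # save_image() has already embedded the metadata in the file on disk; attach it to the
        # in-memory image too, presumably so the copies returned to the UI keep it as well.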
        if opts.enable_pnginfo:
            image.info = existing_pnginfo
            image.info["extras"] = info

        outputs.append(image)

    devices.torch_gc()

    return outputs, plaintext_to_html(info), ''


def run_pnginfo(image):
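    """Extract embedded generation parameters from an image.

    Prefers the 'parameters' PNG text chunk and falls back to the EXIF UserComment
    (decoded with piexif); returns ('', generation parameters, HTML listing of all metadata).
    """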
    if image is None:
        return '', '', ''

    items = image.info
    geninfo = ''

    if "exif" in image.info:
        exif = piexif.load(image.info["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b'')
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode('utf8', errors="ignore")

        items['exif comment'] = exif_comment
        geninfo = exif_comment

        for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',
                      'loop', 'background', 'timestamp', 'duration']:
            items.pop(field, None)

    geninfo = items.get('parameters', geninfo)

    info = ''
    for key, text in items.items():
        info += f"""
<div>
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip() + "\n"
    if len(info) == 0:
        message = "Nothing found in the image."
        info = f"<div><p>{message}</p></div>"

    return '', geninfo, info


def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name):
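    """Merge two checkpoints (plus an optional third one used by "Add difference") with the
    selected interpolation method and amount, save the merged weights as a new .ckpt, and
    return a status message plus updates for the checkpoint dropdowns.
    """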
    # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
    def weighted_sum(theta0, theta1, theta2, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    def sigmoid(theta0, theta1, theta2, alpha):
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    def inv_sigmoid(theta0, theta1, theta2, alpha):
        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
        return theta0 + ((theta1 - theta0) * alpha)

    def add_difference(theta0, theta1, theta2, alpha):
        return theta0 + (theta1 - theta2) * (1.0 - alpha)
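
    # Quick sanity checks (illustrative values, not part of the original code):
    #   weighted_sum(1.0, 0.0, None, 0.25) -> 0.75
    #   sigmoid(0.0, 1.0, None, 0.5)       -> 0.5
    #   inv_sigmoid(0.0, 1.0, None, 0.5)   -> 0.5
    # The merge loop below calls these with alpha = 1.0 - interp_amount, so for "Weighted Sum"
    # interp_amount ends up being the primary model's share of the merged weights.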

    primary_model_info = sd_models.checkpoints_list[primary_model_name]
    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
    teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)

    print(f"Loading {primary_model_info.filename}...")
    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)

    print(f"Loading {secondary_model_info.filename}...")
    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)

    if teritary_model_info is not None:
        print(f"Loading {teritary_model_info.filename}...")
        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
        theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
    else:
        theta_2 = None

    theta_funcs = {
        "Weighted Sum": weighted_sum,
        "Sigmoid": sigmoid,
        "Inverse Sigmoid": inv_sigmoid,
        "Add difference": add_difference,
    }
    theta_func = theta_funcs[interp_method]

    print("Merging...")
    for key in tqdm.tqdm(theta_0.keys()):
        if 'model' in key and key in theta_1:
            # Reverse interp_amount so it matches the desired mix ratio in the merged checkpoint.
            theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount))

            if save_as_half:
                theta_0[key] = theta_0[key].half()

    # I believe this part should be discarded, but I'll leave it for now until I am sure
    for key in theta_1.keys():
        if 'model' in key and key not in theta_0:
            theta_0[key] = theta_1[key]

            if save_as_half:
                theta_0[key] = theta_0[key].half()

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
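    # Default filename, e.g. "sd-v1-4_0.7-wd-v1-2_0.3-Weighted_Sum-merged.ckpt" for
    # interp_amount=0.7 and "Weighted Sum" (hypothetical model names).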
    filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
    filename = filename if custom_name == '' else (custom_name + '.ckpt')
    output_modelname = os.path.join(ckpt_dir, filename)

    print(f"Saving to {output_modelname}...")
    torch.save(primary_model, output_modelname)

    sd_models.list_models()

    print("Checkpoint saved.")

    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]