2022-09-04 00:32:45 +08:00
from collections import namedtuple
from copy import copy
2022-10-06 18:55:21 +08:00
from itertools import permutations , chain
2022-09-04 00:32:45 +08:00
import random
2022-10-06 18:55:21 +08:00
import csv
from io import StringIO
2022-09-26 21:46:18 +08:00
from PIL import Image
2022-09-09 22:54:04 +08:00
import numpy as np
2022-09-04 00:32:45 +08:00
import modules . scripts as scripts
import gradio as gr
2022-10-09 19:33:22 +08:00
from modules import images , hypernetwork
2022-10-09 08:13:13 +08:00
from modules . processing import process_images , Processed , get_correct_sampler
2022-09-04 00:32:45 +08:00
from modules . shared import opts , cmd_opts , state
2022-09-17 18:49:36 +08:00
import modules . shared as shared
2022-09-04 00:32:45 +08:00
import modules . sd_samplers
2022-09-17 18:49:36 +08:00
import modules . sd_models
2022-09-06 15:11:25 +08:00
import re
2022-09-04 00:32:45 +08:00
def apply_field(field):
    """Build an axis callback that stores the axis value in attribute ``field`` of ``p``.

    The returned callable has the axis-callback signature ``(p, x, xs)``; ``xs`` (the
    full list of axis values) is ignored.
    """
    def set_field(p, x, xs):
        setattr(p, field, x)

    return set_field
def apply_prompt(p, x, xs):
    """Prompt S/R axis callback: search for ``xs[0]`` and replace it with ``x``.

    The first axis value acts as the search string; every occurrence of it is
    replaced in the positive prompt first, then in the negative prompt.
    """
    search = xs[0]
    for attr in ("prompt", "negative_prompt"):
        setattr(p, attr, getattr(p, attr).replace(search, x))
def apply_order(p, x, xs):
    """Prompt order axis callback: rearrange the tokens of ``x`` inside ``p.prompt``.

    ``x`` is one permutation of the "Prompt order" axis values. The tokens are located
    in the prompt, cut out, and written back into the same slots, but in the order
    given by ``x``. ``xs`` is unused.

    Raises:
        RuntimeError: if a token of ``x`` does not occur in ``p.prompt``. In that
            case ``p.prompt`` is left unchanged.
    """
    token_order = []

    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen
    for token in x:
        position = p.prompt.find(token)
        if position == -1:
            # find() returns -1 for a missing token; slicing with it would silently
            # mangle the prompt, so fail loudly instead.
            raise RuntimeError(f"Prompt order token not found in prompt: {token}")
        token_order.append((position, token))

    token_order.sort(key=lambda t: t[0])

    # Split the prompt up, taking out the tokens
    prompt_parts = []
    remaining = p.prompt
    for _, token in token_order:
        n = remaining.find(token)
        prompt_parts.append(remaining[:n])
        remaining = remaining[n + len(token):]

    # Rebuild the prompt with the tokens in the order we want
    rebuilt = "".join(part + token for part, token in zip(prompt_parts, x))
    p.prompt = rebuilt + remaining
def build_samplers_dict(p):
    """Map every sampler name and alias (lower-cased) to its index.

    The samplers come from ``get_correct_sampler(p)``. The index is the sampler's
    position in that list, which is the value stored in ``p.sampler_index``. If two
    entries share a lower-cased key, the later one wins.
    """
    return {
        key.lower(): index
        for index, sampler in enumerate(get_correct_sampler(p))
        for key in (sampler.name, *sampler.aliases)
    }
def apply_sampler(p, x, xs):
    """Sampler axis callback: select the sampler named ``x`` (name or alias, case-insensitive).

    Raises:
        RuntimeError: if ``x`` matches no known sampler.
    """
    lookup = build_samplers_dict(p)
    key = x.lower()
    if key not in lookup:
        raise RuntimeError(f"Unknown sampler: {x}")
    p.sampler_index = lookup[key]
def apply_checkpoint(p, x, xs):
    """Checkpoint axis callback: load the checkpoint that best matches ``x``.

    ``p`` is not modified. The matched weights are loaded into ``shared.sd_model``,
    so whoever drives the grid must reload the original checkpoint afterwards.

    Raises:
        RuntimeError: if no checkpoint matches ``x``.
    """
    info = modules.sd_models.get_closet_checkpoint_match(x)
    if info is None:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``;
        # RuntimeError matches the other axis-validation errors in this script.
        raise RuntimeError(f"Checkpoint for {x} not found")

    modules.sd_models.reload_model_weights(shared.sd_model, info)
def apply_hypernetwork(p, x, xs):
    """Hypernetwork axis callback: load the hypernetwork named ``x``.

    ``p`` and ``xs`` are not used. The load goes through
    ``hypernetwork.load_hypernetwork``; the script reloads ``opts.sd_hypernetwork``
    once the grid is finished.
    """
    name = x
    hypernetwork.load_hypernetwork(name)
def format_value_add_label(p, opt, x):
    """Format an axis value as ``"<label>: <value>"`` for the grid legend.

    Floats (including float subclasses) are rounded to 8 decimal places first, so
    values produced by float range expansion print without noise such as
    ``0.30000000000000004``.
    """
    if isinstance(x, float):
        x = round(x, 8)

    return f"{opt.label}: {x}"
def format_value(p, opt, x):
    """Return the axis value itself as its legend label.

    Floats (including float subclasses) are rounded to 8 decimal places to hide
    floating-point noise from range expansion; other values are returned unchanged.
    """
    if isinstance(x, float):
        x = round(x, 8)

    return x
def format_value_join_list(p, opt, x):
    """Format a sequence of prompt tokens as one comma-separated legend label."""
    separator = ", "
    return separator.join(x)
def do_nothing(p, x, xs):
    """Axis callback for the "Nothing" option: leaves ``p`` untouched."""
    return None
def format_nothing(p, opt, x):
    """Legend formatter for the "Nothing" option: always an empty label."""
    label = ""
    return label
def str_permutations(x):
    """Marker "type" for the Prompt order axis.

    ``Script.run`` compares ``opt.type`` against this function to know it must expand
    the axis values into all their permutations. Used as a converter, it returns its
    argument unchanged.
    """
    value = x
    return value
# An axis option: UI label, value converter (``type``), callback applying one value to
# a processing object (``apply(p, x, xs)``), and legend formatter
# (``format_value(p, opt, x)``).
AxisOption = namedtuple("AxisOption", "label type apply format_value")
# Same shape, but only offered when the script UI is built for img2img.
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", "label type apply format_value")

axis_options = [
    AxisOption("Nothing", str, do_nothing, format_nothing),
    AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
    AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
    AxisOption("Prompt S/R", str, apply_prompt, format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
    AxisOption("Sampler", str, apply_sampler, format_value),
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
    # Float fields set directly on the processing object, all labelled the same way.
    *(
        AxisOption(label, float, apply_field(field), format_value_add_label)
        for label, field in (
            ("Sigma Churn", "s_churn"),
            ("Sigma min", "s_tmin"),
            ("Sigma max", "s_tmax"),
            ("Sigma noise", "s_noise"),
            ("Eta", "eta"),
        )
    ),
    # The UI hides img2img-only options outside img2img, but ``Script.run`` indexes
    # this full list with the dropdown index, so every AxisOptionImg2Img entry must
    # come after all AxisOption entries.
    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),
]
def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
    """Render every (x, y) cell and assemble the results into one annotated grid.

    Args:
        p: processing object; only ``p.n_iter`` is read here (for progress reporting).
        xs, ys: axis values, passed to ``cell`` one pair at a time (row by row).
        x_labels, y_labels: legend texts for the columns / rows.
        cell: callable ``cell(x, y)`` returning a processed result with ``.images``.
        draw_legend: when true, draw the axis labels around the grid.

    Returns:
        The result of the first cell, with its ``images`` replaced by the single grid
        image. If no cell produced any image, the first cell's result is returned
        unchanged.

    Raises:
        RuntimeError: if either axis has no values.
    """
    if not xs or not ys:
        raise RuntimeError("X/Y plot needs at least one value on each axis")

    res = []
    missing = []  # positions in ``res`` of cells that produced no image

    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]

    first_processed = None

    state.job_count = len(xs) * len(ys) * p.n_iter

    for iy, y in enumerate(ys):
        for ix, x in enumerate(xs):
            state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"

            processed = cell(x, y)
            if first_processed is None:
                first_processed = processed

            if processed.images:
                res.append(processed.images[0])
            else:
                # Keep the slot so rows/columns stay aligned; it is filled with a blank
                # image once the mode/size of a real image is known.
                missing.append(len(res))
                res.append(None)

    template = next((im for im in res if im is not None), None)
    if template is None:
        # No cell produced an image, so there is nothing to build a grid from.
        return first_processed

    for index in missing:
        res[index] = Image.new(template.mode, template.size)

    grid = images.image_grid(res, rows=len(ys))
    if draw_legend:
        grid = images.draw_grid_annotations(grid, template.width, template.height, hor_texts, ver_texts)

    first_processed.images = [grid]

    return first_processed
# Integer "start-end" ranges with an optional signed step: "1-5", "10-1 (-3)".
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
# Float variant: "0.5-1.5 (+0.25)". The decimal point is escaped so only a literal
# "." is accepted (an unescaped "." would let e.g. "1x5-2" parse as a range).
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\(([+-]\d+(?:\.\d*)?)\s*\))?\s*")

# "start-end [count]" ranges: ``count`` evenly spaced values from start to end
# (expanded with numpy.linspace, which includes both endpoints).
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:\.\d*)?)\s*-\s*([+-]?\s*\d+(?:\.\d*)?)(?:\s*\[(\d+(?:\.\d*)?)\s*\])?\s*")
class Script(scripts.Script):
    """X/Y plot script: render a grid of images while varying two parameters.

    Each axis is one entry of ``axis_options``. Every (x, y) combination is rendered
    from its own shallow copy of the processing object, and the results are stitched
    into a single annotated grid.
    """

    def title(self):
        return "X/Y plot"

    def ui(self, is_img2img):
        # img2img-only options (AxisOptionImg2Img) are hidden outside img2img. ``run``
        # indexes the unfiltered ``axis_options`` with the dropdown index, which only
        # lines up because those options sit at the end of the list.
        current_axis_options = [
            x for x in axis_options
            if isinstance(x, AxisOption) or (is_img2img and isinstance(x, AxisOptionImg2Img))
        ]

        with gr.Row():
            x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="x_type")
            x_values = gr.Textbox(label="X values", visible=False, lines=1)

        with gr.Row():
            y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type")
            y_values = gr.Textbox(label="Y values", visible=False, lines=1)

        draw_legend = gr.Checkbox(label='Draw legend', value=True)
        no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)

        return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds]

    def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds):
        """Render the grid for the chosen axes and return the combined result.

        Args:
            p: processing object built from the UI; each cell gets a shallow copy.
            x_type, y_type: indices into ``axis_options`` chosen in the dropdowns.
            x_values, y_values: comma-separated (CSV-quoted) axis values. Int and float
                axes also accept "a-b", "a-b (+step)" and "a-b [count]" ranges; the
                end of a range is inclusive.
            draw_legend: whether to draw axis labels around the grid.
            no_fixed_seeds: when true, -1 seeds are kept as-is instead of being
                replaced by one concrete random seed before rendering.

        Raises:
            RuntimeError: for an unknown sampler or checkpoint name on either axis.
        """
        if not no_fixed_seeds:
            modules.processing.fix_seed(p)

        p.batch_size = 1

        def process_axis(opt, vals):
            """Parse the raw textbox value ``vals`` into the list of values for ``opt``."""
            if opt.label == 'Nothing':
                return [0]

            valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]

            if opt.type is int:
                valslist_ext = []

                for val in valslist:
                    m = re_range.fullmatch(val)
                    mc = re_range_count.fullmatch(val)
                    if m is not None:
                        start = int(m.group(1))
                        step = int(m.group(3)) if m.group(3) is not None else 1
                        # The written end is inclusive: move the exclusive ``range`` stop
                        # one unit past it in the direction of travel, so descending
                        # ranges such as "10-1 (-3)" keep their last value (1).
                        end = int(m.group(2)) + (1 if step > 0 else -1)
                        valslist_ext.extend(range(start, end, step))
                    elif mc is not None:
                        start = int(mc.group(1))
                        end = int(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1
                        valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type is float:
                valslist_ext = []

                for val in valslist:
                    m = re_range_float.fullmatch(val)
                    mc = re_range_count_float.fullmatch(val)
                    if m is not None:
                        start = float(m.group(1))
                        end = float(m.group(2))
                        step = float(m.group(3)) if m.group(3) is not None else 1
                        # ``end + step`` makes the written end inclusive (for either sign of step).
                        valslist_ext += np.arange(start, end + step, step).tolist()
                    elif mc is not None:
                        start = float(mc.group(1))
                        end = float(mc.group(2))
                        num = int(mc.group(3)) if mc.group(3) is not None else 1
                        valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
                    else:
                        valslist_ext.append(val)

                valslist = valslist_ext
            elif opt.type is str_permutations:
                valslist = list(permutations(valslist))

            valslist = [opt.type(x) for x in valslist]

            # Confirm options are valid before starting
            if opt.label == "Sampler":
                samplers_dict = build_samplers_dict(p)
                for sampler_val in valslist:
                    if sampler_val.lower() not in samplers_dict:
                        raise RuntimeError(f"Unknown sampler: {sampler_val}")
            elif opt.label == "Checkpoint name":
                for ckpt_val in valslist:
                    if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
                        raise RuntimeError(f"Checkpoint for {ckpt_val} not found")

            return valslist

        x_opt = axis_options[x_type]
        xs = process_axis(x_opt, x_values)

        y_opt = axis_options[y_type]
        ys = process_axis(y_opt, y_values)

        def fix_axis_seeds(axis_opt, axis_list):
            """Replace -1 (random) seeds with one concrete random seed each, so every cell sharing a value matches."""
            if axis_opt.label in ['Seed', 'Var. seed']:
                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
            else:
                return axis_list

        if not no_fixed_seeds:
            xs = fix_axis_seeds(x_opt, xs)
            ys = fix_axis_seeds(y_opt, ys)

        if x_opt.label == 'Steps':
            total_steps = sum(xs) * len(ys)
        elif y_opt.label == 'Steps':
            total_steps = sum(ys) * len(xs)
        else:
            total_steps = p.steps * len(xs) * len(ys)

        print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
        shared.total_tqdm.updateTotal(total_steps * p.n_iter)

        def cell(x, y):
            pc = copy(p)
            x_opt.apply(pc, x, xs)
            y_opt.apply(pc, y, ys)

            return process_images(pc)

        try:
            processed = draw_xy_grid(
                p,
                xs=xs,
                ys=ys,
                x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                cell=cell,
                draw_legend=draw_legend
            )

            # Guard against a result without images (e.g. every cell failed to render).
            if opts.grid_save and processed.images:
                images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
        finally:
            # restore checkpoint and hypernetwork in case they were changed by axes,
            # even when rendering raised part-way through the grid
            modules.sd_models.reload_model_weights(shared.sd_model)
            hypernetwork.load_hypernetwork(opts.sd_hypernetwork)

        return processed