2022-10-29 13:42:34 +08:00
import base64
import io
2023-05-19 01:16:09 +08:00
import json
2022-10-09 11:57:19 +08:00
import os
2022-09-24 03:49:21 +08:00
import re
2022-11-28 04:04:42 +08:00
2022-09-24 03:49:21 +08:00
import gradio as gr
2023-01-26 01:00:09 +08:00
from modules . paths import data_path
2023-01-13 05:50:09 +08:00
from modules import shared , ui_tempdir , script_callbacks
2022-10-29 15:56:19 +08:00
from PIL import Image
2022-09-24 03:49:21 +08:00
2023-01-28 16:11:47 +08:00
# Matches one "key: value" pair of the infotext parameter line. The value is either a
# double-quoted, backslash-escaped string (which may contain commas) or a bare run of
# non-comma characters; each pair is terminated by a comma or the end of the line.
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
# "<width>x<height>", e.g. "512x768".
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
# Trailing "(<lowercase hex>)" of a hypernetwork key such as "ke-ta-10000(1234abcd)".
# Raw string: "\(" in a plain literal is an invalid escape sequence (DeprecationWarning,
# SyntaxWarning since Python 3.12); the compiled pattern is identical.
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
2022-09-25 14:25:28 +08:00
# Class of the object produced by gr.update(); connect_paste passes values of this
# type straight through instead of converting them.
type_of_gr_update = type(gr.update())

# Per-tab paste configuration keyed by tab name; filled by add_paste_fields()
# and emptied by reset().
paste_fields: dict = {}

# ParamBinding objects queued by register_paste_params_button() and wired to their
# buttons by connect_paste_params_buttons().
registered_param_bindings: list = []
class ParamBinding:
    """Description of one "paste parameters" button and where its data comes from.

    Attributes:
        paste_button: button whose click performs the paste.
        tabname: key into ``paste_fields`` of the tab that receives the values.
        source_text_component: component holding infotext to parse, or None.
        source_image_component: component holding an image to copy, or None.
        source_tabname: tab whose fields are copied across directly, or None.
        override_settings_component: override-settings component to use instead of
            the destination tab's registered one, or None.
        paste_field_names: extra field names copied from ``source_tabname``;
            a missing or empty value is normalised to a new empty list.
    """

    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
        # Destination side.
        self.paste_button = paste_button
        self.tabname = tabname

        # Sources; any of them may be None.
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname

        self.override_settings_component = override_settings_component
        # Always a list so it can be concatenated with the default field names.
        self.paste_field_names = paste_field_names if paste_field_names else []
2022-09-24 03:49:21 +08:00
2022-10-29 13:42:34 +08:00
2022-10-31 22:36:45 +08:00
def reset():
    """Forget every tab registered through add_paste_fields by emptying the module-level paste_fields dict in place."""
    paste_fields.clear()
2022-10-21 21:10:51 +08:00
def quote(text):
    """Format a value for the infotext parameter line.

    A value containing a comma or a newline would break the ``key: value, ...``
    line apart, so it is emitted as a JSON string literal. A value that already
    starts and ends with a double quote is JSON-quoted too:
    parse_generation_parameters JSON-decodes any value of that shape, so leaving
    it bare would strip its quotes when the infotext is read back.
    Every other value is returned unchanged (it is not converted to ``str``).
    """
    s = str(text)
    looks_quoted = len(s) >= 2 and s[0] == '"' and s[-1] == '"'
    if ',' not in s and '\n' not in s and not looks_quoted:
        return text

    return json.dumps(text, ensure_ascii=False)
def unquote(text):
    """Decode ``text`` as a JSON string literal when it is wrapped in double quotes.

    Empty input, input not wrapped in quotes, and input that is not valid JSON
    are all returned unchanged.
    """
    wrapped = len(text) > 0 and text[0] == '"' and text[-1] == '"'
    if wrapped:
        try:
            return json.loads(text)
        except Exception:
            pass

    return text
2022-10-21 21:10:51 +08:00
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def image_from_url_text(filedata):
    """Turn a gradio image payload into a PIL image.

    Accepted forms:
      * None -> None;
      * a file dict (``{"is_file": True, "name": ...}``), or a list whose first
        element is one -> the image opened from that file, which must pass
        ui_tempdir.check_tmp_file;
      * a list of strings -> its first element is used (an empty list gives None);
      * a base64 string, optionally prefixed by a ``data:<mime>;base64,`` header.

    Raises:
        AssertionError: if a file dict points outside the allowed directories.
    """
    if filedata is None:
        return None

    if isinstance(filedata, list) and len(filedata) > 0 and isinstance(filedata[0], dict) and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if isinstance(filedata, dict) and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        # Explicit raise instead of `assert`: this is an access check and must not
        # vanish under `python -O`. AssertionError is kept for existing callers.
        if not is_in_right_dir:
            raise AssertionError('trying to open image file outside of allowed directories')

        # Drop everything after the last '?' (e.g. a query string appended to the path).
        filename = filename.rsplit('?', 1)[0]
        return Image.open(filename)

    if isinstance(filedata, list):
        if len(filedata) == 0:
            return None

        filedata = filedata[0]

    # Strip a data-URL header of any media type ("data:image/png;base64,",
    # "data:image/jpeg;base64,", ...); previously only PNG was handled and other
    # headers were fed to the base64 decoder as garbage.
    if filedata.startswith("data:") and ";base64," in filedata:
        filedata = filedata.split(";base64,", 1)[1]

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    return image
2022-10-29 13:42:34 +08:00
2023-02-19 14:30:49 +08:00
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    """Register the paste targets of a tab under ``paste_fields[tabname]``.

    Stores the tab's image component, its list of ``(component, key)`` pairs and its
    optional override-settings component. For the txt2img and img2img tabs the field
    list is also published on ``modules.ui`` as ``txt2img_paste_fields`` /
    ``img2img_paste_fields`` for older extensions.
    """
    paste_fields[tabname] = dict(
        init_img=init_img,
        fields=fields,
        override_settings_component=override_settings_component,
    )

    # backwards compatibility for existing extensions
    import modules.ui

    legacy_attribute = {
        'txt2img': 'txt2img_paste_fields',
        'img2img': 'img2img_paste_fields',
    }.get(tabname)
    if legacy_attribute is not None:
        setattr(modules.ui, legacy_attribute, fields)
2022-10-27 13:36:11 +08:00
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def create_buttons(tabs_list):
    """Create one "Send to <tab>" button per tab name and return them as a ``{tab: button}`` dict, in input order."""
    return {
        tab: gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
        for tab in tabs_list
    }
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def bind_buttons(buttons, send_image, send_generate_info):
    """Legacy helper kept for backwards compatibility; new code should use register_paste_params_button.

    ``send_generate_info`` is interpreted by type: a gradio component becomes the
    infotext source, a string becomes the name of the tab to copy fields from.
    """
    is_component = isinstance(send_generate_info, gr.components.Component)
    is_tab_name = isinstance(send_generate_info, str)

    for destination_tab, button in buttons.items():
        register_paste_params_button(ParamBinding(
            paste_button=button,
            tabname=destination_tab,
            source_text_component=send_generate_info if is_component else None,
            source_image_component=send_image,
            source_tabname=send_generate_info if is_tab_name else None,
        ))
def register_paste_params_button(binding: ParamBinding):
    """Queue ``binding`` in registered_param_bindings; its click handlers are attached later by connect_paste_params_buttons()."""
    registered_param_bindings.append(binding)
def _find_field_by_name(fields, name):
    """Return the first component in ``fields`` ((component, key) pairs, or None) whose key equals ``name``, else None."""
    return next((field for field, field_name in fields or [] if field_name == name), None)


def _pair_fields_by_name(source_fields, destination_fields, names):
    """Match source and destination components registered under the same key from ``names``.

    Returns ``(inputs, outputs)``: equal-length lists where ``inputs[i]`` and
    ``outputs[i]`` share a key. Pairs follow destination order; keys present on only
    one side are skipped, and a key registered several times is paired occurrence by
    occurrence. Non-string keys (callables) never match ``names``.
    """
    available = {}
    for field, name in source_fields:
        if name in names:
            available.setdefault(name, []).append(field)

    inputs, outputs = [], []
    for field, name in destination_fields:
        candidates = available.get(name) if name in names else None
        if candidates:
            inputs.append(candidates.pop(0))
            outputs.append(field)

    return inputs, outputs


def connect_paste_params_buttons():
    """Attach click handlers to every binding queued by register_paste_params_button().

    Per binding, depending on which sources it has:
      * image source -> copy the image (plus width/height when the destination tab
        has "Size-1"/"Size-2" fields) into the destination image component;
      * text source -> parse the infotext into the destination fields (connect_paste);
      * source tab -> copy Prompt, Negative prompt, Steps, Face restoration, Seed
        (when opts.send_seed) and binding.paste_field_names across directly;
    and always a final handler that calls the ``switch_to_<tabname>`` JS function.
    """
    binding: ParamBinding
    for binding in registered_param_bindings:
        destination = paste_fields[binding.tabname]
        destination_image_component = destination["init_img"]
        fields = destination["fields"]
        override_settings_component = binding.override_settings_component or destination["override_settings_component"]

        destination_width_component = _find_field_by_name(fields, "Size-1")
        destination_height_component = _find_field_by_name(fields, "Size-2")

        if binding.source_image_component and destination_image_component:
            if isinstance(binding.source_image_component, gr.Gallery):
                func = send_image_and_dimensions if destination_width_component else image_from_url_text
                jsfunc = "extract_image_from_gallery"
            else:
                func = send_image_and_dimensions if destination_width_component else lambda x: x
                jsfunc = None

            binding.paste_button.click(
                fn=func,
                _js=jsfunc,
                inputs=[binding.source_image_component],
                outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
                show_progress=False,
            )

        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
            # Pair components by key. Filtering the two tabs' lists independently and
            # relying on position sends values into the wrong fields when the tabs list
            # these keys in different orders, and breaks the click outright when a key
            # exists on only one side (input/output counts differ).
            inputs, outputs = _pair_fields_by_name(paste_fields[binding.source_tabname]["fields"], fields, paste_field_names)
            binding.paste_button.click(
                fn=lambda *x: x,
                inputs=inputs,
                outputs=outputs,
                show_progress=False,
            )

        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=None,
            outputs=None,
            show_progress=False,
        )
2022-10-27 13:36:11 +08:00
2022-10-29 13:42:34 +08:00
2023-01-03 03:44:46 +08:00
def send_image_and_dimensions(x):
    """Return ``(image, width, height)`` for an image paste.

    ``x`` is either a PIL image or a payload understood by image_from_url_text.
    Width and height are the image's pixel size when opts.send_size is enabled and
    an image was obtained; otherwise both are empty gr.update() values, so the
    destination size fields are not changed.
    """
    img = x if isinstance(x, Image.Image) else image_from_url_text(x)

    if not (shared.opts.send_size and isinstance(img, Image.Image)):
        return img, gr.update(), gr.update()

    return img, img.width, img.height
2022-10-29 13:42:34 +08:00
2022-12-14 06:25:16 +08:00
def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
    """Find the shared.hypernetworks key an infotext's hypernet refers to.

    With a hash (e.g. "Hypernet: ke-ta" plus "Hypernet hash: 1234abcd"), the first
    key ending in "(<hex hash>)" with that hash is returned, such as
    "ke-ta-10000(1234abcd)"; the name is not consulted and there is no fallback
    when no key carries the hash. Without a hash, the first key whose lower-cased
    form starts with the lower-cased name is returned. Returns None if nothing matches.
    """
    name_prefix = hypernet_name.lower()
    candidates = shared.hypernetworks.keys()

    if hypernet_hash is None:
        return next((key for key in candidates if key.lower().startswith(name_prefix)), None)

    for key in candidates:
        hash_match = re_hypernet_hash.search(key)
        if hash_match is not None and hash_match[1] == hypernet_hash:
            return key

    return None
2023-01-03 00:42:10 +08:00
def restore_old_hires_fix_params(res):
    """Rewrite old-style hires-fix size parameters in ``res`` in place.

    If opts.use_old_hires_fix_width_height is set and both "Hires resize" values
    are non-zero, they become the "Size" values and nothing else happens.
    Otherwise, when "First pass size-1/2" are both present, "Size" is replaced by
    the first-pass size (computed by processing.old_hires_fix_first_pass_dimensions
    when either is 0) and the previous "Size" (default 512) moves to "Hires resize".
    """
    if shared.opts.use_old_hires_fix_width_height:
        hr_w = int(res.get("Hires resize-1", 0))
        hr_h = int(res.get("Hires resize-2", 0))

        if hr_w and hr_h:
            res['Size-1'], res['Size-2'] = hr_w, hr_h
            return

    fp_w = res.get('First pass size-1', None)
    fp_h = res.get('First pass size-2', None)
    if fp_w is None or fp_h is None:
        return

    fp_w, fp_h = int(fp_w), int(fp_h)
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))

    if 0 in (fp_w, fp_h):
        from modules import processing
        fp_w, fp_h = processing.old_hires_fix_first_pass_dimensions(width, height)

    res.update({
        'Size-1': fp_w,
        'Size-2': fp_h,
        'Hires resize-1': width,
        'Hires resize-2': height,
    })
2023-01-03 00:42:10 +08:00
2022-09-24 03:49:21 +08:00
def parse_generation_parameters(x: str):
    """Parse an infotext (the generation-parameters text shown under an image in the UI).

    Example input::

        girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
        Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
        Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b

    The last line is the parameter line if it holds at least three ``key: value``
    pairs; all other lines form the prompt, switching to the negative prompt at the
    first line starting with "Negative prompt:".

    Returns a dict of field values. "WxH" values are split into "<key>-1"/"<key>-2",
    JSON-quoted values are decoded, and defaults are filled in for keys that older
    infotexts omit.
    """
    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    params = re_param.findall(lastline)
    if len(params) < 3:
        # Too few pairs to be a parameter line: it is part of the prompt text.
        lines.append(lastline)
        params = []

    negative_prefix = "Negative prompt:"
    for line in lines:
        line = line.strip()
        if line.startswith(negative_prefix):
            done_with_prompt = True
            line = line[len(negative_prefix):].strip()
        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in params:
        # unquote() only decodes values wrapped in double quotes and tolerates empty
        # strings; indexing v[0] directly raised IndexError on inputs such as
        # "Steps: , Seed: 1" or a trailing "Seed:".
        v = unquote(v)

        m = re_imagesize.match(v)
        if m is not None:
            res[f"{k}-1"] = m.group(1)
            res[f"{k}-2"] = m.group(2)
        else:
            res[k] = v

    # Missing CLIP skip means it was set to 1 (the default)
    res.setdefault("Clip skip", "1")

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    res.setdefault("Hires sampler", "Use same sampler")
    res.setdefault("Hires prompt", "")
    res.setdefault("Hires negative prompt", "")

    restore_old_hires_fix_params(res)

    # Missing RNG means the default was set, which is GPU RNG
    res.setdefault("RNG", "GPU")

    return res
2022-10-29 14:01:04 +08:00
2023-01-30 05:25:30 +08:00
# NOTE(review): not referenced in this part of the file; presumably kept for other
# modules or extensions — confirm before removing.
settings_map = {}

# Ordered (infotext key, shared.opts setting name) pairs. create_override_settings_dict
# and connect_paste's override-settings handling iterate this list to turn infotext
# values into setting overrides, so list order is the order overrides are produced in.
infotext_to_setting_name_mapping = [
    ('Clip skip', 'CLIP_stop_at_last_layers', ),
    ('Conditional mask weight', 'inpainting_mask_weight'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
    ('Noise multiplier', 'initial_noise_multiplier'),
    ('Eta', 'eta_ancestral'),
    ('Eta DDIM', 'eta_ddim'),
    ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
    ('UniPC variant', 'uni_pc_variant'),
    ('UniPC skip type', 'uni_pc_skip_type'),
    ('UniPC order', 'uni_pc_order'),
    ('UniPC lower order final', 'uni_pc_lower_order_final'),
    ('Token merging ratio', 'token_merging_ratio'),
    ('Token merging ratio hr', 'token_merging_ratio_hr'),
    ('RNG', 'randn_source'),
    ('NGMS', 's_min_uncond'),
]
def create_override_settings_dict(text_pairs):
    """Turn the override-settings multiselect value into processing's override_settings dict.

    Each entry of ``text_pairs`` is a "<infotext key>: <value>" string. Keys listed in
    infotext_to_setting_name_mapping become ``{setting name: value}`` with the value
    cast by ``shared.opts.cast_value``; other keys are ignored.

    Example:
        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
        -> {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
    """
    params = {
        name: value.strip()
        for name, value in (pair.split(":", maxsplit=1) for pair in text_pairs)
    }

    return {
        setting_name: shared.opts.cast_value(setting_name, params[param_name])
        for param_name, setting_name in infotext_to_setting_name_mapping
        if param_name in params
    }
2023-01-30 06:03:31 +08:00
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
    """Make clicking ``button`` paste the infotext from ``input_comp`` into a tab's fields.

    ``paste_fields`` is a list of ``(component, key)`` pairs, where ``key`` is either a
    key of the dict returned by parse_generation_parameters or a callable that takes
    that dict and returns the component's value. A None value leaves the component
    unchanged, a gr.update() result is used as-is, and anything else is converted to
    the type of the component's current value.

    If the pasted text is empty and ``shared.cmd_opts.hide_ui_dir_config`` is off,
    the contents of ``params.txt`` under ``data_path`` are used when that file exists.

    If ``override_settings_component`` is given, it is also updated with the
    infotext settings (per infotext_to_setting_name_mapping) whose values differ
    from the current ``shared.opts``. A second click handler then runs the
    ``recalculate_prompts_<tabname>`` JavaScript function.
    """

    def as_update(component, value):
        # Translate one pasted value into the update sent to `component`.
        if value is None:
            return gr.update()
        if isinstance(value, type_of_gr_update):
            return value
        try:
            target_type = type(component.value)
            # bool("False") is True, so the string "False" is special-cased.
            converted = False if target_type == bool and value == "False" else target_type(value)
            return gr.update(value=converted)
        except Exception:
            # Best effort: an unconvertible value leaves the component unchanged.
            return gr.update()

    def paste_func(prompt):
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            params_path = os.path.join(data_path, "params.txt")
            if os.path.exists(params_path):
                with open(params_path, "r", encoding="utf8") as file:
                    prompt = file.read()

        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)

        # `paste_fields` is read at click time, so it includes the override-settings
        # entry appended below.
        return [
            as_update(output, key(params) if callable(key) else params.get(key, None))
            for output, key in paste_fields
        ]

    def paste_settings(params):
        # Dropdown update listing the infotext settings that differ from current opts.
        changed = {}
        for param_name, setting_name in infotext_to_setting_name_mapping:
            raw = params.get(param_name, None)
            if raw is None:
                continue

            if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
                continue

            value = shared.opts.cast_value(setting_name, raw)
            if value == getattr(shared.opts, setting_name, None):
                continue

            changed[param_name] = value

        pairs = [f"{k}: {v}" for k, v in changed.items()]
        return gr.Dropdown.update(value=pairs, choices=pairs, visible=len(pairs) > 0)

    if override_settings_component is not None:
        paste_fields = paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        inputs=[input_comp],
        outputs=[entry[0] for entry in paste_fields],
        show_progress=False,
    )
    button.click(
        fn=None,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[],
        outputs=[],
        show_progress=False,
    )
2022-10-27 13:36:11 +08:00