2023-12-02 13:33:28 +08:00
from __future__ import annotations
2022-10-29 13:42:34 +08:00
import base64
import io
2023-05-19 01:16:09 +08:00
import json
2022-10-09 11:57:19 +08:00
import os
2022-09-24 03:49:21 +08:00
import re
2024-01-01 18:52:37 +08:00
import sys
2022-11-28 04:04:42 +08:00
2022-09-24 03:49:21 +08:00
import gradio as gr
2023-01-26 01:00:09 +08:00
from modules . paths import data_path
2024-01-01 20:00:39 +08:00
from modules import shared , ui_tempdir , script_callbacks , processing , infotext_versions
2022-10-29 15:56:19 +08:00
from PIL import Image
2022-09-24 03:49:21 +08:00
2024-01-01 18:52:37 +08:00
# Register this module under its former name too, so old `modules.generation_parameters_copypaste` imports keep working.
sys.modules.update({'modules.generation_parameters_copypaste': sys.modules[__name__]})
2023-10-03 00:16:41 +08:00
# One "key: value" pair of the parameter line. The key starts with a word character and may contain
# spaces, '-' and '/'; the value is either a double-quoted string (backslash escapes allowed) or
# everything up to the next comma.
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)

# "<digits>x<digits>", e.g. "512x768"; split into "<key>-1" / "<key>-2" by the parser.
re_imagesize = re.compile(r"^(\d+)x(\d+)$")

# Trailing "(<lowercase hex>)" suffix. Raw string: "\(" / "\)" are invalid escape sequences in a
# normal literal (DeprecationWarning, SyntaxWarning on Python 3.12+); the compiled pattern is unchanged.
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
2022-09-25 14:25:28 +08:00
# Class of the object returned by gr.update(); used to recognise values that are already component updates.
type_of_gr_update = gr.update().__class__
2023-01-30 05:25:30 +08:00
class ParamBinding:
    """Describes a button that sends an image and/or generation parameters to the tab `tabname`.

    Attributes:
        paste_button: the button whose click handlers get wired up.
        tabname: destination tab; used as a key into the registered paste fields.
        source_text_component: component whose infotext is parsed and pasted, or None.
        source_image_component: image/gallery copied into the destination's init image, or None.
        source_tabname: tab whose field values are copied directly, or None.
        override_settings_component: replaces the destination tab's own override-settings component when set.
        paste_field_names: extra field labels copied from `source_tabname` (falsy input becomes []).
    """

    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
        # where the click goes
        self.paste_button = paste_button
        self.tabname = tabname

        # where the data comes from
        self.source_tabname = source_tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component

        # optional extras
        self.override_settings_component = override_settings_component
        self.paste_field_names = paste_field_names or []
2022-09-24 03:49:21 +08:00
2022-10-29 13:42:34 +08:00
2023-12-17 15:22:03 +08:00
class PasteField(tuple):
    """A (component, target) pair naming where a pasted infotext value goes.

    It is still a 2-tuple, so code can unpack it as ``component, target``. ``target`` is either an
    infotext label (exposed as ``label``) or a callable that receives the parsed parameters dict
    (exposed as ``function``). ``api`` is an optional extra tag stored as-is.
    """

    def __new__(cls, component, target, *, api=None):
        # only the pair forms the tuple payload; `api` is kept as an attribute in __init__
        pair = (component, target)
        return super().__new__(cls, pair)

    def __init__(self, component, target, *, api=None):
        super().__init__()

        self.api = api
        self.component = component

        if isinstance(target, str):
            self.label = target
        else:
            self.label = None

        self.function = target if callable(target) else None
2023-12-02 13:33:28 +08:00
# Per-tab paste configuration keyed by tab name; each value holds "init_img", "fields" and
# "override_settings_component" (filled by add_paste_fields).
paste_fields: dict[str, dict] = dict()

# Bindings added by register_paste_params_button and wired by connect_paste_params_buttons.
registered_param_bindings: list[ParamBinding] = list()


def reset():
    """Forget every registered paste configuration and button binding.

    Both containers are emptied in place, so existing references to them stay valid.
    """
    for registry in (paste_fields, registered_param_bindings):
        registry.clear()
2022-10-31 22:36:45 +08:00
2022-10-21 21:10:51 +08:00
def quote(text):
    """JSON-encode `text` when it contains a comma, newline or colon; otherwise return it untouched.

    Values that need no quoting are returned as the original object (not converted to str).
    Quoted values are produced with ensure_ascii=False, so non-ASCII characters stay literal.
    unquote() reverses the encoding.
    """
    as_string = str(text)
    needs_quoting = any(special in as_string for special in (',', '\n', ':'))
    if not needs_quoting:
        return text

    return json.dumps(text, ensure_ascii=False)
def unquote(text):
    """Decode a double-quoted JSON string produced by quote(); return anything else unchanged.

    Text that is empty, does not both start and end with '"', or fails to parse as JSON is returned as-is.
    """
    wrapped_in_quotes = len(text) > 0 and text[0] == '"' and text[-1] == '"'
    if not wrapped_in_quotes:
        return text

    try:
        return json.loads(text)
    except Exception:
        # malformed quoting: keep the raw text rather than failing
        return text
2022-10-21 21:10:51 +08:00
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def image_from_url_text(filedata):
    """Turn an image payload coming from the UI into a PIL image.

    Accepted forms:
      * None -> None.
      * A dict with a truthy "is_file" key, or a list whose first item is such a dict: the file
        named by "name" is opened after ui_tempdir.check_tmp_file confirms it is allowed. Any
        '?...' suffix is dropped before opening.
      * A list: its first element is used; an empty list gives None.
      * A base64 string, optionally preceded by a "data:<mime>;base64," header.

    Raises:
        AssertionError: if a file reference points outside the allowed directories.
    """
    if filedata is None:
        return None

    if isinstance(filedata, list) and filedata and isinstance(filedata[0], dict) and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if isinstance(filedata, dict) and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        # An explicit raise instead of `assert`: asserts are removed under `python -O`, which would
        # silently disable this directory check. The exception type is unchanged for callers.
        if not is_in_right_dir:
            raise AssertionError('trying to open image file outside of allowed directories')

        filename = filename.rsplit('?', 1)[0]
        return Image.open(filename)

    if isinstance(filedata, list):
        if len(filedata) == 0:
            return None

        filedata = filedata[0]

    # Strip a "data:<mime>;base64," header. Previously only image/png was recognised; other MIME
    # types were fed to the decoder with the header still attached.
    if filedata.startswith("data:"):
        header, separator, payload = filedata.partition(",")
        if separator and header.endswith(";base64"):
            filedata = payload

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    return image
2022-10-29 13:42:34 +08:00
2023-02-19 14:30:49 +08:00
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    """Register the paste targets of the UI tab `tabname`.

    Entries of `fields` that are not PasteField yet are converted in place (the caller's list is
    modified). The result is stored in paste_fields[tabname]. For the "txt2img" and "img2img" tabs,
    the list is also exposed as modules.ui.<tab>_paste_fields for older code.
    """
    if fields:
        for index, entry in enumerate(fields):
            if not isinstance(entry, PasteField):
                fields[index] = PasteField(*entry)

    paste_fields[tabname] = dict(init_img=init_img, fields=fields, override_settings_component=override_settings_component)

    # backwards compatibility for existing extensions
    import modules.ui
    legacy_attribute = {'txt2img': 'txt2img_paste_fields', 'img2img': 'img2img_paste_fields'}.get(tabname)
    if legacy_attribute is not None:
        setattr(modules.ui, legacy_attribute, fields)
2022-10-27 13:36:11 +08:00
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def create_buttons(tabs_list):
    """Create one "Send to <tab>" button per tab name and return them keyed by tab name."""
    return {tab: gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") for tab in tabs_list}
2022-10-29 13:42:34 +08:00
2022-10-27 13:36:11 +08:00
def bind_buttons(buttons, send_image, send_generate_info):
    """Old function kept for backwards compatibility; use register_paste_params_button instead.

    `buttons` maps destination tab names to buttons. `send_generate_info` may be a gradio component
    (its text is parsed as infotext) or a source tab name (its fields are copied). `send_image` is
    the image source for every binding.
    """
    for destination_tab, paste_button in buttons.items():
        is_component = isinstance(send_generate_info, gr.components.Component)
        is_tab_name = isinstance(send_generate_info, str)

        binding = ParamBinding(
            paste_button=paste_button,
            tabname=destination_tab,
            source_text_component=send_generate_info if is_component else None,
            source_image_component=send_image,
            source_tabname=send_generate_info if is_tab_name else None,
        )
        register_paste_params_button(binding)
def register_paste_params_button(binding: ParamBinding):
    """Record `binding`; connect_paste_params_buttons() attaches its click handlers later."""
    registered_param_bindings.extend([binding])
def _first_component_for(fields, label):
    """Return the first component in `fields` whose paste target equals `label`; None if absent or `fields` is empty/None."""
    return next((component for component, target in fields or [] if target == label), None)


def _pair_components_by_label(source_fields, destination_fields, labels):
    """Pair components of two tabs that carry the same string label from `labels`.

    Returns (inputs, outputs): equally long lists in destination-field order, where inputs[i] is
    the first source component labelled like outputs[i]. Labels found on only one side are left out.
    """
    source_by_label = {}
    for component, target in source_fields or []:
        if isinstance(target, str) and target in labels:
            source_by_label.setdefault(target, component)

    inputs, outputs = [], []
    for component, target in destination_fields or []:
        if isinstance(target, str) and target in source_by_label:
            inputs.append(source_by_label[target])
            outputs.append(component)

    return inputs, outputs


def connect_paste_params_buttons():
    """Attach click handlers for every registered ParamBinding.

    For each binding, in this order:
      1. if there is a source image and the destination tab has an init image, copy the image. When
         the destination has a "Size-1" field, width and height are sent too.
      2. if there is a source text component, paste its parsed infotext (see connect_paste).
      3. if there is a source tab, copy the values of the commonly shared fields
         (plus "Seed" when the send_seed option is on, plus binding.paste_field_names).
      4. switch the UI to the destination tab (JS function "switch_to_<tabname>").
    """
    for binding in registered_param_bindings:
        destination = paste_fields[binding.tabname]
        destination_image_component = destination["init_img"]
        fields = destination["fields"]
        override_settings_component = binding.override_settings_component or destination["override_settings_component"]

        destination_width_component = _first_component_for(fields, "Size-1")
        destination_height_component = _first_component_for(fields, "Size-2")

        if binding.source_image_component and destination_image_component:
            # send_image_and_dimensions returns (image, width, height), so it is only used when
            # the destination has a width field to receive the extra values.
            if isinstance(binding.source_image_component, gr.Gallery):
                func = send_image_and_dimensions if destination_width_component else image_from_url_text
                jsfunc = "extract_image_from_gallery"
            else:
                func = send_image_and_dimensions if destination_width_component else lambda x: x
                jsfunc = None

            if destination_width_component:
                image_outputs = [destination_image_component, destination_width_component, destination_height_component]
            else:
                image_outputs = [destination_image_component]

            binding.paste_button.click(
                fn=func,
                _js=jsfunc,
                inputs=[binding.source_image_component],
                outputs=image_outputs,
                show_progress=False,
            )

        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
            # Pair components by label. The old code filtered each tab's fields separately and relied
            # on position alone. A label present in only one tab (e.g. from paste_field_names) then
            # gave inputs/outputs of different lengths, and a different declaration order sent values
            # into the wrong fields.
            source_components, destination_components = _pair_components_by_label(
                paste_fields[binding.source_tabname]["fields"], fields, paste_field_names
            )
            binding.paste_button.click(
                fn=lambda *x: x,
                inputs=source_components,
                outputs=destination_components,
                show_progress=False,
            )

        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=None,
            outputs=None,
            show_progress=False,
        )
2022-10-27 13:36:11 +08:00
2022-10-29 13:42:34 +08:00
2023-01-03 03:44:46 +08:00
def send_image_and_dimensions(x):
    """Resolve `x` to an image and, if the send_size option is on, also return its size.

    `x` is used directly when it is already a PIL image; otherwise it goes through
    image_from_url_text. Returns (img, width, height). Width and height are no-op gr.update()
    values when sizes are not sent or `img` is not a PIL image.
    """
    img = x if isinstance(x, Image.Image) else image_from_url_text(x)

    if not (shared.opts.send_size and isinstance(img, Image.Image)):
        return img, gr.update(), gr.update()

    return img, img.width, img.height
2023-01-03 00:42:10 +08:00
def restore_old_hires_fix_params(res):
    """Rewrite parsed infotext `res` in place for the old hires-fix parameter layout.

    * When the use_old_hires_fix_width_height option is on and both "Hires resize-1" and
      "Hires resize-2" are non-zero, they become "Size-1"/"Size-2" and nothing else changes.
    * Otherwise, if "First pass size-1"/"-2" are present, the first-pass size becomes
      "Size-1"/"Size-2". The previous "Size-1"/"Size-2" (default 512) move to
      "Hires resize-1"/"-2". A first-pass size with a zero component is recomputed with
      processing.old_hires_fix_first_pass_dimensions.
    """
    if shared.opts.use_old_hires_fix_width_height:
        target_w, target_h = int(res.get("Hires resize-1", 0)), int(res.get("Hires resize-2", 0))
        if target_w and target_h:
            res['Size-1'], res['Size-2'] = target_w, target_h
            return

    first_w = res.get('First pass size-1', None)
    first_h = res.get('First pass size-2', None)
    if first_w is None or first_h is None:
        return

    first_w = int(first_w)
    first_h = int(first_h)
    final_w = int(res.get("Size-1", 512))
    final_h = int(res.get("Size-2", 512))

    if 0 in (first_w, first_h):
        first_w, first_h = processing.old_hires_fix_first_pass_dimensions(final_w, final_h)

    res['Size-1'], res['Size-2'] = first_w, first_h
    res['Hires resize-1'], res['Hires resize-2'] = final_w, final_h
2023-01-03 00:42:10 +08:00
2024-01-16 19:16:07 +08:00
_NEGATIVE_PROMPT_PREFIX = "Negative prompt:"

# Values assumed for keys missing from an infotext, applied (in this order) before the old
# hires-fix conversion.
_INFOTEXT_DEFAULTS_BEFORE_HIRES_RESTORE = (
    ("Hires sampler", "Use same sampler"),
    ("Hires checkpoint", "Use same checkpoint"),
    ("Hires prompt", ""),
    ("Hires negative prompt", ""),
    ("Mask mode", "Inpaint masked"),
    ("Masked content", 'original'),
    ("Inpaint area", "Whole picture"),
    ("Masked area padding", 32),
)

# Values assumed for keys missing from an infotext, applied (in this order) after the old
# hires-fix conversion.
_INFOTEXT_DEFAULTS_AFTER_HIRES_RESTORE = (
    ("RNG", "GPU"),  # missing RNG means the default was set, which is GPU RNG
    ("Schedule type", "Automatic"),
    ("Schedule max sigma", 0),
    ("Schedule min sigma", 0),
    ("Schedule rho", 0),
    ("VAE Encoder", "Full"),
    ("VAE Decoder", "Full"),
    ("FP8 weight", "Disable"),
)


def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
    """Parse the generation-parameters text shown under a picture in the UI.

    Example input::

        girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
        Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
        Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b

    The last line counts as the parameter line only if it holds at least three "key: value" pairs.
    Earlier lines form the prompt until a line starting with "Negative prompt:"; that line and
    everything after it form the negative prompt. "WxH" values are split into "<key>-1"/"<key>-2".
    Defaults are filled in for several missing keys.

    Args:
        x: the infotext.
        skip_fields: keys removed from the result; defaults to shared.opts.infotext_skip_pasting.

    Returns:
        dict mapping parameter names to values (mostly strings as parsed; defaults may be ints/bools).
    """
    if skip_fields is None:
        skip_fields = shared.opts.infotext_skip_pasting

    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    if len(re_param.findall(lastline)) < 3:
        # not a parameter line: it is still part of the prompt text
        lines.append(lastline)
        lastline = ''

    for line in lines:
        line = line.strip()
        if line.startswith(_NEGATIVE_PROMPT_PREFIX):
            done_with_prompt = True
            line = line[len(_NEGATIVE_PROMPT_PREFIX):].strip()
        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    if shared.opts.infotext_styles != "Ignore":
        found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)

        if shared.opts.infotext_styles == "Apply":
            res["Styles array"] = found_styles
        elif shared.opts.infotext_styles == "Apply if any" and found_styles:
            res["Styles array"] = found_styles

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        if not v:
            # "Key: ," has an empty value. It never produced an entry; previously v[0] raised
            # IndexError and printed a misleading parse error instead.
            continue
        try:
            if v[0] == '"' and v[-1] == '"':
                v = unquote(v)

            m = re_imagesize.match(v)
            if m is not None:
                res[f"{k}-1"] = m.group(1)
                res[f"{k}-2"] = m.group(2)
            else:
                res[k] = v
        except Exception:
            print(f"Error parsing \"{k}: {v}\"")

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
        res["Clip skip"] = "1"

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    # both halves are defaulted together, only when the first one is missing
    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    for key, default in _INFOTEXT_DEFAULTS_BEFORE_HIRES_RESTORE:
        res.setdefault(key, default)

    restore_old_hires_fix_params(res)

    for key, default in _INFOTEXT_DEFAULTS_AFTER_HIRES_RESTORE:
        res.setdefault(key, default)

    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
        res["Cache FP16 weight for LoRA"] = False

    infotext_versions.backcompat(res)

    for key in skip_fields:
        res.pop(key, None)

    return res
2022-10-29 14:01:04 +08:00
2023-01-30 05:25:30 +08:00
# Extra (infotext label, setting name) pairs. Only left for backwards compatibility -
# use OptionInfo(..., infotext='...') instead.
# Example content:
# infotext_to_setting_name_mapping = [
#     ('Conditional mask weight', 'inpainting_mask_weight'),
#     ('Model hash', 'sd_model_checkpoint'),
#     ('ENSD', 'eta_noise_seed_delta'),
#     ('Schedule type', 'k_sched_type'),
# ]
infotext_to_setting_name_mapping = list()
2023-01-30 05:25:30 +08:00
def create_override_settings_dict(text_pairs):
    """Build processing's override_settings dict from the "label: value" strings of gradio's multiselect.

    Each pair is split at the first ": ". Labels are mapped to setting names via the options'
    infotext labels and infotext_to_setting_name_mapping. Values are cast with shared.opts.cast_value.

    Example input:
        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
    Example output:
        {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
    """
    params = {}
    for pair in text_pairs:
        label, raw_value = pair.split(": ", maxsplit=1)
        params[label] = raw_value.strip()

    res = {}
    label_to_setting = [(info.infotext, key) for key, info in shared.opts.data_labels.items() if info.infotext]
    for label, setting_name in label_to_setting + infotext_to_setting_name_mapping:
        value = params.get(label)
        if value is not None:
            res[setting_name] = shared.opts.cast_value(setting_name, value)

    return res
2023-12-30 17:11:09 +08:00
def get_override_settings(params, *, skip_fields=None):
    """List the settings overrides implied by a parsed infotext dict.

    For every known (infotext label, setting name) pair, whether from option infotext labels or
    infotext_to_setting_name_mapping, whose label has a value in `params`, the value is cast to the
    setting's type. A pair is left out when:
      - its label is in `skip_fields`;
      - it is the checkpoint setting and disable_weights_auto_swap is on;
      - the cast value equals the current setting value.

    Example input:
        {"Clip skip": "2"}
    Example output:
        [("Clip skip", "CLIP_stop_at_last_layers", 2)]
    """
    skipped = skip_fields or {}
    overrides = []

    label_to_setting = [(info.infotext, key) for key, info in shared.opts.data_labels.items() if info.infotext]
    for label, setting_name in label_to_setting + infotext_to_setting_name_mapping:
        if label in skipped:
            continue

        raw = params.get(label)
        if raw is None:
            continue

        if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
            continue

        cast = shared.opts.cast_value(setting_name, raw)
        if cast == getattr(shared.opts, setting_name, None):
            continue

        overrides.append((label, setting_name, cast))

    return overrides
2023-01-30 06:03:31 +08:00
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
    """Make `button` parse the infotext in `input_comp` and paste it into the components of `paste_fields`.

    `paste_fields` is a list of (component, target) pairs. A target is an infotext label or a
    callable taking the parsed params. If the input is empty and the UI-dir config is not hidden,
    the text is read from <data_path>/params.txt instead. When `override_settings_component` is
    given, it receives the remaining settings overrides as a dropdown. A second click handler runs
    the JS function "recalculate_prompts_<tabname>".
    """
    def coerce_to_update(component, target, params):
        # Value for one component: a callable target computes it, a label looks it up.
        value = target(params) if callable(target) else params.get(target, None)

        if value is None:
            return gr.update()
        if isinstance(value, type_of_gr_update):
            return value

        try:
            kind = type(component.value)
            if kind == bool and value == "False":
                converted = False
            elif kind == int:
                # float() so values such as "7.5" do not fail an int() conversion
                converted = float(value)
            else:
                converted = kind(value)
            return gr.update(value=converted)
        except Exception:
            return gr.update()

    def paste_func(prompt):
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            params_path = os.path.join(data_path, "params.txt")
            try:
                with open(params_path, "r", encoding="utf8") as handle:
                    prompt = handle.read()
            except OSError:
                pass  # best effort: without a readable params.txt the empty prompt is parsed as-is

        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)

        # reads `paste_fields` at call time, i.e. after the override entry below may have been appended
        return [coerce_to_update(component, target, params) for component, target in paste_fields]

    if override_settings_component is not None:
        handled_targets = {target for _, target in paste_fields}

        def paste_settings(params):
            overrides = get_override_settings(params, skip_fields=handled_targets)
            pairs = [f"{infotext_text}: {value}" for infotext_text, _setting_name, value in overrides]
            return gr.Dropdown.update(value=pairs, choices=pairs, visible=bool(pairs))

        paste_fields = paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        inputs=[input_comp],
        outputs=[entry[0] for entry in paste_fields],
        show_progress=False,
    )
    button.click(
        fn=None,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[],
        outputs=[],
        show_progress=False,
    )
2023-12-02 13:33:28 +08:00