mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-12-15 07:20:31 +08:00
Merge pull request #14979 from drhead/refiner_cumprod_fix
Protect alphas_cumprod during refiner switchover
This commit is contained in:
commit
06b9200e91
@ -915,33 +915,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
|
||||||
|
|
||||||
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
sd_models.apply_alpha_schedule_override(p.sd_model, p)
|
||||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
|
||||||
|
|
||||||
# Store old values.
|
|
||||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
|
||||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
|
||||||
|
|
||||||
# Shift so the last timestep is zero.
|
|
||||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
|
||||||
|
|
||||||
# Scale so the first timestep is back to the old value.
|
|
||||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
|
||||||
|
|
||||||
# Convert alphas_bar_sqrt to betas
|
|
||||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
|
||||||
alphas_bar[-1] = 4.8973451890853435e-08
|
|
||||||
return alphas_bar
|
|
||||||
|
|
||||||
if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
|
|
||||||
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)
|
|
||||||
|
|
||||||
if opts.use_downcasted_alpha_bar:
|
|
||||||
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
|
||||||
p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
|
|
||||||
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
|
||||||
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
|
||||||
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
|
|
||||||
|
|
||||||
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
|
||||||
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
|
||||||
|
@ -15,6 +15,7 @@ from ldm.util import instantiate_from_config
|
|||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
|
from modules.shared import opts
|
||||||
import tomesd
|
import tomesd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -549,6 +550,36 @@ def repair_config(sd_config):
|
|||||||
karlo_path = os.path.join(paths.models_path, 'karlo')
|
karlo_path = os.path.join(paths.models_path, 'karlo')
|
||||||
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)
|
||||||
|
|
||||||
|
def apply_alpha_schedule_override(sd_model, p=None):
|
||||||
|
def rescale_zero_terminal_snr_abar(alphas_cumprod):
|
||||||
|
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||||
|
|
||||||
|
# Store old values.
|
||||||
|
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||||
|
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||||
|
|
||||||
|
# Shift so the last timestep is zero.
|
||||||
|
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Scale so the first timestep is back to the old value.
|
||||||
|
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||||
|
|
||||||
|
# Convert alphas_bar_sqrt to betas
|
||||||
|
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||||
|
alphas_bar[-1] = 4.8973451890853435e-08
|
||||||
|
return alphas_bar
|
||||||
|
|
||||||
|
if hasattr(sd_model, 'alphas_cumprod') and hasattr(sd_model, 'alphas_cumprod_original'):
|
||||||
|
sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)
|
||||||
|
|
||||||
|
if opts.use_downcasted_alpha_bar:
|
||||||
|
if p is not None:
|
||||||
|
p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
|
||||||
|
sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)
|
||||||
|
if opts.sd_noise_schedule == "Zero Terminal SNR":
|
||||||
|
if p is not None:
|
||||||
|
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
|
||||||
|
sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)
|
||||||
|
|
||||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||||
@ -812,6 +843,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
|
|||||||
|
|
||||||
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
|
||||||
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
|
||||||
|
apply_alpha_schedule_override(sd_model)
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
if sd_model is not None:
|
if sd_model is not None:
|
||||||
|
Loading…
Reference in New Issue
Block a user