medvram support for SD3

This commit is contained in:
AUTOMATIC1111 2024-06-24 10:15:46 +03:00
parent a65dd315ad
commit a8fba9af35
4 changed files with 35 additions and 8 deletions

View File

@ -1,9 +1,12 @@
from collections import namedtuple
import torch
from modules import devices, shared
module_in_gpu = None
cpu = torch.device("cpu")
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
def send_everything_to_cpu():
global module_in_gpu
@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
(sd_model, 'depth_model'),
(sd_model, 'embedder'),
(sd_model, 'model'),
(sd_model, 'embedder'),
]
is_sdxl = hasattr(sd_model, 'conditioner')
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
if is_sdxl:
if hasattr(sd_model, 'medvram_fields'):
to_remain_in_cpu = sd_model.medvram_fields()
elif is_sdxl:
to_remain_in_cpu.append((sd_model, 'conditioner'))
elif is_sd2:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
setattr(obj, field, module)
# register hooks for those the first three models
if is_sdxl:
if hasattr(sd_model.cond_stage_model, "medvram_modules"):
for module in sd_model.cond_stage_model.medvram_modules():
if isinstance(module, ModuleWithParent):
parent = module.parent
module = module.module
else:
parent = None
if module:
module.register_forward_pre_hook(send_me_to_gpu)
if parent:
parents[module] = parent
elif is_sdxl:
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
elif is_sd2:
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model:
if hasattr(sd_model, 'depth_model'):
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder:
if hasattr(sd_model, 'embedder'):
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
if use_medvram:

View File

@ -492,7 +492,6 @@ class MMDiT(nn.Module):
device = None,
):
super().__init__()
print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
self.dtype = dtype
self.learn_sigma = learn_sigma
self.in_channels = in_channels

View File

@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX
def medvram_modules(self):
return [self.clip_g, self.clip_l, self.t5xxl]
class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas):
@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
return self.cond_stage_model(batch)
def apply_model(self, x, t, cond):
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent)
@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):
def create_denoiser(self):
return SD3Denoiser(self, self.model.model_sampling.sigmas)
def medvram_fields(self):
return [
(self, 'first_stage_model'),
(self, 'cond_stage_model'),
(self, 'model'),
]

View File

@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
else:
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
try:
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
timestep = torch.max(sigma).to(dtype=int)
completed_ratio = (999 - timestep) / 1000