RAM optimization round 2

AUTOMATIC1111 2023-08-16 09:55:35 +03:00
parent 85fcb7b8df
commit 86221269f9
2 changed files with 48 additions and 8 deletions

View File

@@ -304,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
 
     weights_backup = getattr(self, "network_weights_backup", None)
-    if weights_backup is None:
+    if weights_backup is None and wanted_names != ():
+        if current_names != ():
+            raise RuntimeError("no backup weights found and current weights are not unchanged")
+
         if isinstance(self, torch.nn.MultiheadAttention):
             weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
         else:
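
Note on the hunk above: the original weights are now backed up only when some network is actually about to be applied (wanted_names is non-empty), and an error is raised if the weights were already modified without a backup having been stored. A minimal sketch of that invariant, with hypothetical names (ensure_backup is not part of the extension, and only a Linear weight is shown):

import torch

def ensure_backup(module: torch.nn.Linear, wanted_names: tuple, current_names: tuple):
    # Nothing to do if a backup already exists or no network is being applied.
    backup = getattr(module, "network_weights_backup", None)
    if backup is not None or wanted_names == ():
        return
    if current_names != ():
        # A network already touched these weights but no backup was stored,
        # so the original weights can no longer be recovered.
        raise RuntimeError("no backup weights found and current weights are not unchanged")
    # Weights are still pristine: keep one CPU copy for later restoration.
    module.network_weights_backup = module.weight.to("cpu", copy=True)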

View File

@@ -168,22 +168,59 @@ class LoadStateDictOnMeta(ReplaceHelper):
         device = self.device
 
         def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
-            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+            used_param_keys = []
 
-            for name, param in params:
+            for name, param in self._parameters.items():
+                if param is None:
+                    continue
+
+                key = prefix + name
+                sd_param = sd.pop(key, None)
+                if sd_param is not None:
+                    state_dict[key] = sd_param
+                    used_param_keys.append(key)
+
                 if param.is_meta:
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+                    dtype = sd_param.dtype if sd_param is not None else param.dtype
+                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+
+            for name in self._buffers:
+                key = prefix + name
+
+                sd_param = sd.pop(key, None)
+                if sd_param is not None:
+                    state_dict[key] = sd_param
+                    used_param_keys.append(key)
 
             original(self, state_dict, prefix, *args, **kwargs)
 
-            for name, _ in params:
-                key = prefix + name
-                if key in sd:
-                    del sd[key]
+            for key in used_param_keys:
+                state_dict.pop(key, None)
+
+        def load_state_dict(original, self, state_dict, strict=True):
+            """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
+            because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
+            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
+
+            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
+
+            The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
+            the function and does not call the original) the state dict will just fail to load because weights
+            would be on the meta device.
+            """
+
+            if state_dict == sd:
+                state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
+
+            original(self, state_dict, strict=strict)
 
+        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
+        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
         linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
         conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
         mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
+        group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.restore()
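
The docstring added to load_state_dict above is the heart of the change. As a standalone illustration of that trick (a rough sketch, not the webui code: load_with_low_ram, the toy model, and the local monkeypatch are made up here, and the real implementation additionally materializes meta parameters and handles dtypes), the idea reduces to handing torch a storage-less dict and substituting the real tensors per module:

import torch
import torch.nn as nn


def load_with_low_ram(model: nn.Module, sd: dict):
    # Hand torch only storage-less meta tensors; however many internal copies of
    # this dict load_state_dict makes, they cost essentially no memory.
    meta_sd = {k: v.to(device="meta") for k, v in sd.items()}

    original = nn.Module._load_from_state_dict

    def patched(self, state_dict, prefix, *args, **kwargs):
        # Swap the meta placeholders for the real tensors, taken from the single
        # source dict `sd`; popping them releases RAM as loading progresses.
        # (If some module never calls this hook, its weights stay on meta and
        # loading fails, which is the risk the docstring warns about.)
        for name in list(self._parameters) + list(self._buffers):
            key = prefix + name
            real = sd.pop(key, None)
            if real is not None:
                state_dict[key] = real
        original(self, state_dict, prefix, *args, **kwargs)

    nn.Module._load_from_state_dict = patched
    try:
        model.load_state_dict(meta_sd)
    finally:
        nn.Module._load_from_state_dict = original


# Usage: the source dict is drained as each module consumes its keys.
model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8))
sd = {k: v.clone() for k, v in model.state_dict().items()}
load_with_low_ram(model, sd)
assert not sd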