refactor: move factorization to lyco_helpers, separate calc_updown for kohya and kb

This commit is contained in:
v0xie 2023-11-03 19:35:15 -07:00
parent fe1967a4c4
commit f6c8201e56
2 changed files with 77 additions and 101 deletions

View File

@ -19,3 +19,50 @@ def rebuild_cp_decomposition(up, down, mid):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
'''
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
In LoRA with Kroneckor Product, first value is a value for weight scale.
secon value is a value for weight.
Becuase of non-commutative property, AB BA. Meaning of two matrices is slightly different.
examples)
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
'''
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m<n:
new_m = m + 1
while dimension%new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m>factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n

View File

@ -1,7 +1,7 @@
import torch
import network
from lyco_helpers import factorization
from einops import rearrange
from modules import devices
class ModuleTypeOFT(network.ModuleType):
@ -11,7 +11,8 @@ class ModuleTypeOFT(network.ModuleType):
return None
# adapted from kohya's implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# adapted from kohya-ss' implementation https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py
# and KohakuBlueleaf's implementation https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py
class NetworkModuleOFT(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
@ -19,6 +20,7 @@ class NetworkModuleOFT(network.NetworkModule):
self.lin_module = None
self.org_module: list[torch.Module] = [self.sd_module]
# kohya-ss
if "oft_blocks" in weights.w.keys():
self.is_kohya = True
@ -37,61 +39,31 @@ class NetworkModuleOFT(network.NetworkModule):
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
#if "Linear" in self.sd_module.__class__.__name__ or is_linear:
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
if is_linear:
self.out_dim = self.sd_module.out_features
#elif hasattr(self.sd_module, "embed_dim"):
# self.out_dim = self.sd_module.embed_dim
#else:
# raise ValueError("Linear sd_module must have out_features or embed_dim")
elif is_other_linear:
self.out_dim = self.sd_module.embed_dim
#self.org_weight = self.org_module[0].weight
# if hasattr(self.sd_module, "in_proj_weight"):
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
# if hasattr(self.sd_module, "out_proj_weight"):
# self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
# self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
elif is_conv:
self.out_dim = self.sd_module.out_channels
else:
raise ValueError("sd_module must be Linear or Conv")
if self.is_kohya:
self.num_blocks = self.dim
self.block_size = self.out_dim // self.num_blocks
self.constraint = self.alpha * self.out_dim
#elif is_linear or is_conv:
else:
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
self.constraint = None
# if is_other_linear:
# weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
# module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
# with torch.no_grad():
# if weight.shape != module.weight.shape:
# weight = weight.reshape(module.weight.shape)
# module.weight.copy_(weight)
# module.to(device=devices.cpu, dtype=devices.dtype)
# module.weight.requires_grad_(False)
# self.lin_module = module
#return module
def merge_weight(self, R_weight, org_weight):
R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R_weight)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R_weight)
#weight = torch.einsum(
# "k n m, k n ... -> k m ...",
# self.oft_diag * scale + torch.eye(self.block_size, device=device),
# org_weight
#)
return weight
def get_weight(self, oft_blocks, multiplier=None):
@ -111,48 +83,51 @@ class NetworkModuleOFT(network.NetworkModule):
block_R_weighted = multiplier * block_R + (1 - multiplier) * m_I
R = torch.block_diag(*block_R_weighted)
return R
#return self.oft_blocks
def calc_updown_kohya(self, orig_weight, multiplier):
R = self.get_weight(self.oft_blocks, multiplier)
merged_weight = self.merge_weight(R, orig_weight)
def calc_updown(self, orig_weight):
multiplier = self.multiplier() * self.calc_scale()
is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
if self.is_kohya and not is_other_linear:
R = self.get_weight(self.oft_blocks, multiplier)
#R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
merged_weight = self.merge_weight(R, orig_weight)
elif not self.is_kohya and not is_other_linear:
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
orig_weight = orig_weight
return self.finalize_updown(updown, orig_weight, output_shape)
def calc_updown_kb(self, orig_weight, multiplier):
is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
if not is_other_linear:
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
orig_weight=orig_weight.permute(1, 0)
R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
#orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)
merged_weight = torch.einsum(
'k n m, k n ... -> k m ...',
R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
merged_weight
merged_weight
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
orig_weight=orig_weight.permute(1, 0)
#merged_weight=merged_weight.permute(1, 0)
updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
#updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
output_shape = orig_weight.shape
else:
# skip for now
# FIXME: skip MultiheadAttention for now
updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
output_shape = (orig_weight.shape[1], orig_weight.shape[1])
#if self.lin_module is not None:
# R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
# weight = torch.mul(torch.mul(R, multiplier), orig_weight)
#else:
orig_weight = orig_weight
return self.finalize_updown(updown, orig_weight, output_shape)
def calc_updown(self, orig_weight):
multiplier = self.multiplier() * self.calc_scale()
if self.is_kohya:
return self.calc_updown_kohya(orig_weight, multiplier)
else:
return self.calc_updown_kb(orig_weight, multiplier)
# override to remove the multiplier/scale factor; it's already multiplied in get_weight
def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
#return super().finalize_updown(updown, orig_weight, output_shape, ex_bias)
@ -172,49 +147,3 @@ class NetworkModuleOFT(network.NetworkModule):
ex_bias = ex_bias * self.multiplier()
return updown, ex_bias
# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py
def factorization(dimension: int, factor:int=-1) -> tuple[int, int]:
'''
return a tuple of two value of input dimension decomposed by the number closest to factor
second value is higher or equal than first value.
In LoRA with Kroneckor Product, first value is a value for weight scale.
secon value is a value for weight.
Becuase of non-commutative property, AB BA. Meaning of two matrices is slightly different.
examples)
factor
-1 2 4 8 16 ...
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
'''
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m<n:
new_m = m + 1
while dimension%new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m>factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n