feat: Enhance LoHa and LoKr modules with Tucker decomposition support

- Added Tucker decomposition functionality to LoHa and LoKr modules.
- Implemented new methods for weight rebuilding using Tucker decomposition.
- Updated initialization and weight handling for Conv2d 3x3+ layers.
- Modified get_diff_weight methods to accommodate Tucker and non-Tucker modes.
- Enhanced network base to include unet_conv_target_modules for architecture detection.
This commit is contained in:
Kohya S
2026-02-15 23:35:17 +09:00
parent ba72f3b665
commit aff4a2e732
3 changed files with 417 additions and 81 deletions

View File

@@ -51,6 +51,63 @@ class HadaWeight(torch.autograd.Function):
return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
class HadaWeightTucker(torch.autograd.Function):
"""Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+.
Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale
where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa).
Compatible with LyCORIS parameter naming convention.
"""
@staticmethod
def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None):
if scale is None:
scale = torch.tensor(1, device=t1.device, dtype=t1.dtype)
ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
return rebuild1 * rebuild2 * scale
@staticmethod
def backward(ctx, grad_out):
(t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors
grad_out = grad_out * scale
# Gradients for w1a, w1b, t1 (using rebuild2)
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a)
grad_w = rebuild * grad_out
del rebuild
grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T)
del grad_w, temp
grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T)
del grad_temp
# Gradients for w2a, w2b, t2 (using rebuild1)
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b)
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a)
grad_w = rebuild * grad_out
del rebuild
grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T)
del grad_w, temp
grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T)
del grad_temp
return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None
class LoHaModule(torch.nn.Module):
"""LoHa module for training. Replaces forward method of the original Linear/Conv2d."""
@@ -64,6 +121,7 @@ class LoHaModule(torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
use_tucker=False,
**kwargs,
):
super().__init__()
@@ -74,26 +132,78 @@ class LoHaModule(torch.nn.Module):
if is_conv2d:
in_dim = org_module.in_channels
out_dim = org_module.out_channels
kernel_size = org_module.kernel_size
self.is_conv = True
if org_module.kernel_size != (1, 1):
raise ValueError("LoHa Conv2d 3x3 (Tucker decomposition) is not supported yet")
self.stride = org_module.stride
self.padding = org_module.padding
self.dilation = org_module.dilation
self.groups = org_module.groups
self.kernel_size = kernel_size
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
if kernel_size == (1, 1):
self.conv_mode = "1x1"
elif self.tucker:
self.conv_mode = "tucker"
else:
self.conv_mode = "flat"
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.is_conv = False
self.tucker = False
self.conv_mode = None
self.kernel_size = None
# Hadamard product parameters: ΔW = (w1a @ w1b) * (w2a @ w2b)
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
self.in_dim = in_dim
self.out_dim = out_dim
# Initialization: w1_a normal(0.1), w1_b normal(1.0), w2_a = 0, w2_b normal(1.0)
# Ensures ΔW = 0 at init since w2_a = 0
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w2_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
# Create parameters based on mode
if self.conv_mode == "tucker":
# Tucker decomposition for Conv2d 3x3+
# Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim)
self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
# LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1)
torch.nn.init.normal_(self.hada_t1, std=0.1)
torch.nn.init.normal_(self.hada_t2, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w1_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
torch.nn.init.normal_(self.hada_w2_a, std=0.1)
elif self.conv_mode == "flat":
# Non-Tucker Conv2d 3x3+: flatten kernel into in_dim
k_prod = 1
for k in kernel_size:
k_prod *= k
flat_in = in_dim * k_prod
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in))
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in))
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w2_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
else:
# Linear or Conv2d 1x1
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w2_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy()
@@ -113,9 +223,27 @@ class LoHaModule(torch.nn.Module):
del self.org_module
def get_diff_weight(self):
"""Return materialized weight delta as a 2D matrix."""
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
"""Return materialized weight delta.
Returns:
- Linear: 2D tensor (out_dim, in_dim)
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
- Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2)
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2)
"""
if self.tucker:
scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device)
return HadaWeightTucker.apply(
self.hada_t1, self.hada_w1_b, self.hada_w1_a,
self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale
)
elif self.conv_mode == "flat":
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size)
else:
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
def forward(self, x):
org_forwarded = self.org_forward(x)
@@ -130,16 +258,25 @@ class LoHaModule(torch.nn.Module):
# rank dropout (applied on output dimension)
if self.rank_dropout is not None and self.training:
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
drop = drop.view(-1, 1)
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
diff_weight = diff_weight * drop
scale = 1.0 / (1.0 - self.rank_dropout)
else:
scale = 1.0
if self.is_conv:
# Conv2d 1x1: reshape to 4D for conv operation
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
@@ -164,8 +301,9 @@ class LoHaInfModule(LoHaModule):
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
# no dropout for inference; pass use_tucker from kwargs
use_tucker = kwargs.pop("use_tucker", False)
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker)
self.org_module_ref = [org_module]
self.enabled = True
@@ -193,11 +331,18 @@ class LoHaInfModule(LoHaModule):
w2a = sd["hada_w2_a"].to(torch.float).to(device)
w2b = sd["hada_w2_b"].to(torch.float).to(device)
# compute ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
if self.tucker:
# Tucker mode
t1 = sd["hada_t1"].to(torch.float).to(device)
t2 = sd["hada_t2"].to(torch.float).to(device)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
diff_weight = rebuild1 * rebuild2 * self.scale
else:
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
# reshape diff_weight to match original weight shape if needed
if diff_weight.shape != weight.shape:
diff_weight = diff_weight.reshape(weight.shape)
weight = weight.to(device) + self.multiplier * diff_weight
@@ -208,23 +353,40 @@ class LoHaInfModule(LoHaModule):
if multiplier is None:
multiplier = self.multiplier
w1a = self.hada_w1_a.to(torch.float)
w1b = self.hada_w1_b.to(torch.float)
w2a = self.hada_w2_a.to(torch.float)
w2b = self.hada_w2_b.to(torch.float)
if self.tucker:
t1 = self.hada_t1.to(torch.float)
w1a = self.hada_w1_a.to(torch.float)
w1b = self.hada_w1_b.to(torch.float)
t2 = self.hada_t2.to(torch.float)
w2a = self.hada_w2_a.to(torch.float)
w2b = self.hada_w2_b.to(torch.float)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
weight = rebuild1 * rebuild2 * self.scale * multiplier
else:
w1a = self.hada_w1_a.to(torch.float)
w1b = self.hada_w1_b.to(torch.float)
w2a = self.hada_w2_a.to(torch.float)
w2b = self.hada_w2_b.to(torch.float)
weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
if self.is_conv:
weight = weight.unsqueeze(2).unsqueeze(3)
if self.is_conv:
if self.conv_mode == "1x1":
weight = weight.unsqueeze(2).unsqueeze(3)
elif self.conv_mode == "flat":
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
return weight
def default_forward(self, x):
diff_weight = self.get_diff_weight()
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier
else:
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
@@ -288,7 +450,7 @@ def create_network(
if module_dropout is not None:
module_dropout = float(module_dropout)
# conv dim/alpha (for future Conv2d 3x3 support)
# conv dim/alpha for Conv2d 3x3
conv_lora_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_lora_dim is not None:
@@ -298,6 +460,11 @@ def create_network(
else:
conv_alpha = float(conv_alpha)
# Tucker decomposition for Conv2d 3x3
use_tucker = kwargs.get("use_tucker", "false")
if use_tucker is not None:
use_tucker = True if str(use_tucker).lower() == "true" else False
# verbose
verbose = kwargs.get("verbose", "false")
if verbose is not None:
@@ -321,6 +488,7 @@ def create_network(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
module_class=LoHaModule,
module_kwargs={"use_tucker": use_tucker},
conv_lora_dim=conv_lora_dim,
conv_alpha=conv_alpha,
train_llm_adapter=train_llm_adapter,
@@ -372,6 +540,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
if "llm_adapter" in lora_name:
train_llm_adapter = True
# detect Tucker mode from weights
use_tucker = any("hada_t1" in key for key in weights_sd.keys())
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
@@ -379,6 +550,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
arch_config = detect_arch_config(unet, text_encoders)
module_class = LoHaInfModule if for_inference else LoHaModule
module_kwargs = {"use_tucker": use_tucker}
network = AdditionalNetwork(
text_encoders,
@@ -388,6 +560,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
module_kwargs=module_kwargs,
train_llm_adapter=train_llm_adapter,
)
return network, weights_sd
@@ -403,6 +576,7 @@ def merge_weights_to_tensor(
) -> torch.Tensor:
"""Merge LoHa weights directly into a model weight tensor.
Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
Returns model_weight unchanged if no matching LoHa keys found.
"""
@@ -410,6 +584,8 @@ def merge_weights_to_tensor(
w1b_key = lora_name + ".hada_w1_b"
w2a_key = lora_name + ".hada_w2_a"
w2b_key = lora_name + ".hada_w2_b"
t1_key = lora_name + ".hada_t1"
t2_key = lora_name + ".hada_t2"
alpha_key = lora_name + ".alpha"
if w1a_key not in lora_weight_keys:
@@ -420,6 +596,8 @@ def merge_weights_to_tensor(
w2a = lora_sd[w2a_key].to(calc_device)
w2b = lora_sd[w2b_key].to(calc_device)
has_tucker = t1_key in lora_weight_keys
dim = w1b.shape[0]
alpha = lora_sd.get(alpha_key, torch.tensor(dim))
if isinstance(alpha, torch.Tensor):
@@ -429,14 +607,26 @@ def merge_weights_to_tensor(
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(torch.float16)
w1a, w1b, w2a, w2b = w1a.to(torch.float16), w1b.to(torch.float16), w2a.to(torch.float16), w2b.to(torch.float16)
w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16)
w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
# ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
if has_tucker:
# Tucker decomposition: rebuild via einsum
t1 = lora_sd[t1_key].to(calc_device)
t2 = lora_sd[t2_key].to(calc_device)
if original_dtype.itemsize == 1:
t1, t2 = t1.to(torch.float16), t2.to(torch.float16)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
diff_weight = rebuild1 * rebuild2 * scale
else:
# Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
# handle Conv2d 1x1 weights (4D tensors)
if len(model_weight.shape) == 4:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
# Reshape diff_weight to match model_weight shape if needed
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
if diff_weight.shape != model_weight.shape:
diff_weight = diff_weight.reshape(model_weight.shape)
model_weight = model_weight + multiplier * diff_weight
@@ -444,7 +634,10 @@ def merge_weights_to_tensor(
model_weight = model_weight.to(original_dtype)
# remove consumed keys
for key in [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]:
consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]
if has_tucker:
consumed.extend([t1_key, t2_key])
for key in consumed:
lora_weight_keys.discard(key)
return model_weight

View File

@@ -68,6 +68,14 @@ def make_kron(w1, w2, scale):
return rebuild
def rebuild_tucker(t, wa, wb):
"""Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb).
Compatible with LyCORIS convention.
"""
return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)
class LoKrModule(torch.nn.Module):
"""LoKr module for training. Replaces forward method of the original Linear/Conv2d."""
@@ -82,6 +90,7 @@ class LoKrModule(torch.nn.Module):
rank_dropout=None,
module_dropout=None,
factor=-1,
use_tucker=False,
**kwargs,
):
super().__init__()
@@ -92,13 +101,32 @@ class LoKrModule(torch.nn.Module):
if is_conv2d:
in_dim = org_module.in_channels
out_dim = org_module.out_channels
kernel_size = org_module.kernel_size
self.is_conv = True
if org_module.kernel_size != (1, 1):
raise ValueError("LoKr Conv2d 3x3 (Tucker decomposition) is not supported yet")
self.stride = org_module.stride
self.padding = org_module.padding
self.dilation = org_module.dilation
self.groups = org_module.groups
self.kernel_size = kernel_size
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
if kernel_size == (1, 1):
self.conv_mode = "1x1"
elif self.tucker:
self.conv_mode = "tucker"
else:
self.conv_mode = "flat"
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.is_conv = False
self.tucker = False
self.conv_mode = None
self.kernel_size = None
self.in_dim = in_dim
self.out_dim = out_dim
factor = int(factor)
self.use_w2 = False
@@ -110,18 +138,44 @@ class LoKrModule(torch.nn.Module):
# w1 is always a full matrix (the "scale" factor, small)
self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))
# w2: low-rank decomposition if rank is small enough, otherwise full matrix
if lora_dim < max(out_k, in_n) / 2:
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
else:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
# w2: depends on mode
if self.conv_mode in ("tucker", "flat"):
# Conv2d 3x3+ modes
k_size = kernel_size
if lora_dim >= max(out_k, in_n) / 2:
# Full matrix mode (includes kernel dimensions)
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size))
logger.warning(
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
f"and factor={factor}, using full matrix mode."
f"and factor={factor}, using full matrix mode for Conv2d."
)
elif self.tucker:
# Tucker mode: separate kernel into t2 tensor
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size))
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
else:
# Non-Tucker: flatten kernel into w2_b
k_prod = 1
for k in k_size:
k_prod *= k
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod))
else:
# Linear or Conv2d 1x1
if lora_dim < max(out_k, in_n) / 2:
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
else:
self.use_w2 = True
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
if lora_dim >= max(out_k, in_n) / 2:
logger.warning(
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
f"and factor={factor}, using full matrix mode."
)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy()
@@ -137,6 +191,8 @@ class LoKrModule(torch.nn.Module):
if self.use_w2:
torch.nn.init.constant_(self.lokr_w2, 0)
else:
if self.tucker:
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
torch.nn.init.constant_(self.lokr_w2_b, 0)
# Ensures ΔW = kron(w1, 0) = 0 at init
@@ -153,13 +209,30 @@ class LoKrModule(torch.nn.Module):
del self.org_module
def get_diff_weight(self):
"""Return materialized weight delta."""
"""Return materialized weight delta.
Returns:
- Linear: 2D tensor (out_dim, in_dim)
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
- Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2)
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D
"""
w1 = self.lokr_w1
if self.use_w2:
w2 = self.lokr_w2
elif self.tucker:
w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b)
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
return make_kron(w1, w2, self.scale)
result = make_kron(w1, w2, self.scale)
# For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D
if self.conv_mode == "flat" and result.dim() == 2:
result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size)
return result
def forward(self, x):
org_forwarded = self.org_forward(x)
@@ -174,15 +247,25 @@ class LoKrModule(torch.nn.Module):
# rank dropout
if self.rank_dropout is not None and self.training:
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
drop = drop.view(-1, 1)
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
diff_weight = diff_weight * drop
scale = 1.0 / (1.0 - self.rank_dropout)
else:
scale = 1.0
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
return org_forwarded + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier * scale
else:
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
@@ -207,9 +290,10 @@ class LoKrInfModule(LoKrModule):
alpha=1,
**kwargs,
):
# no dropout for inference; pass factor from kwargs if present
# no dropout for inference; pass factor and use_tucker from kwargs
factor = kwargs.pop("factor", -1)
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor)
use_tucker = kwargs.pop("use_tucker", False)
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker)
self.org_module_ref = [org_module]
self.enabled = True
@@ -236,6 +320,12 @@ class LoKrInfModule(LoKrModule):
if "lokr_w2" in sd:
w2 = sd["lokr_w2"].to(torch.float).to(device)
elif "lokr_t2" in sd:
# Tucker mode
t2 = sd["lokr_t2"].to(torch.float).to(device)
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
w2 = rebuild_tucker(t2, w2a, w2b)
else:
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
@@ -244,8 +334,9 @@ class LoKrInfModule(LoKrModule):
# compute ΔW via Kronecker product
diff_weight = make_kron(w1, w2, self.scale)
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
# reshape diff_weight to match original weight shape if needed
if diff_weight.shape != weight.shape:
diff_weight = diff_weight.reshape(weight.shape)
weight = weight.to(device) + self.multiplier * diff_weight
@@ -257,23 +348,39 @@ class LoKrInfModule(LoKrModule):
multiplier = self.multiplier
w1 = self.lokr_w1.to(torch.float)
if self.use_w2:
w2 = self.lokr_w2.to(torch.float)
elif self.tucker:
w2 = rebuild_tucker(
self.lokr_t2.to(torch.float),
self.lokr_w2_a.to(torch.float),
self.lokr_w2_b.to(torch.float),
)
else:
w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)
weight = make_kron(w1, w2, self.scale) * multiplier
# reshape to match original weight shape if needed
if self.is_conv:
weight = weight.unsqueeze(2).unsqueeze(3)
if self.conv_mode == "1x1":
weight = weight.unsqueeze(2).unsqueeze(3)
elif self.conv_mode == "flat" and weight.dim() == 2:
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
# Tucker and full matrix modes: already 4D from kron
return weight
def default_forward(self, x):
diff_weight = self.get_diff_weight()
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
if self.conv_mode == "1x1":
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(
x, diff_weight, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups
) * self.multiplier
else:
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
@@ -337,7 +444,7 @@ def create_network(
if module_dropout is not None:
module_dropout = float(module_dropout)
# conv dim/alpha (for future Conv2d 3x3 support)
# conv dim/alpha for Conv2d 3x3
conv_lora_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_lora_dim is not None:
@@ -347,6 +454,11 @@ def create_network(
else:
conv_alpha = float(conv_alpha)
# Tucker decomposition for Conv2d 3x3
use_tucker = kwargs.get("use_tucker", "false")
if use_tucker is not None:
use_tucker = True if str(use_tucker).lower() == "true" else False
# factor for LoKr
factor = int(kwargs.get("factor", -1))
@@ -373,7 +485,7 @@ def create_network(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
module_class=LoKrModule,
module_kwargs={"factor": factor},
module_kwargs={"factor": factor, "use_tucker": use_tucker},
conv_lora_dim=conv_lora_dim,
conv_alpha=conv_alpha,
train_llm_adapter=train_llm_adapter,
@@ -411,6 +523,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
modules_dim = {}
modules_alpha = {}
train_llm_adapter = False
use_tucker = False
for key, value in weights_sd.items():
if "." not in key:
continue
@@ -419,13 +532,21 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lokr_w2_a" in key:
# low-rank mode: dim = w2_a.shape[1]
dim = value.shape[1]
# low-rank mode: dim detection depends on Tucker vs non-Tucker
if "lokr_t2" in key.replace("lokr_w2_a", "lokr_t2") and lora_name + ".lokr_t2" in weights_sd:
# Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
dim = value.shape[0]
else:
# Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
dim = value.shape[1]
modules_dim[lora_name] = dim
elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
# full matrix mode: set dim large enough to trigger full-matrix path
if lora_name not in modules_dim:
modules_dim[lora_name] = max(value.shape)
modules_dim[lora_name] = max(value.shape[0], value.shape[1])
if "lokr_t2" in key:
use_tucker = True
if "llm_adapter" in lora_name:
train_llm_adapter = True
@@ -440,7 +561,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
factor = int(kwargs.get("factor", -1))
module_class = LoKrInfModule if for_inference else LoKrModule
module_kwargs = {"factor": factor}
module_kwargs = {"factor": factor, "use_tucker": use_tucker}
network = AdditionalNetwork(
text_encoders,
@@ -466,6 +587,7 @@ def merge_weights_to_tensor(
) -> torch.Tensor:
"""Merge LoKr weights directly into a model weight tensor.
Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
Returns model_weight unchanged if no matching LoKr keys found.
"""
@@ -473,6 +595,7 @@ def merge_weights_to_tensor(
w2_key = lora_name + ".lokr_w2"
w2a_key = lora_name + ".lokr_w2_a"
w2b_key = lora_name + ".lokr_w2_b"
t2_key = lora_name + ".lokr_t2"
alpha_key = lora_name + ".alpha"
if w1_key not in lora_weight_keys:
@@ -480,13 +603,23 @@ def merge_weights_to_tensor(
w1 = lora_sd[w1_key].to(calc_device)
# determine low-rank vs full matrix mode
# determine mode: full matrix vs Tucker vs low-rank
has_tucker = t2_key in lora_weight_keys
if w2a_key in lora_weight_keys:
# low-rank: w2 = w2_a @ w2_b
w2a = lora_sd[w2a_key].to(calc_device)
w2b = lora_sd[w2b_key].to(calc_device)
dim = w2a.shape[1]
if has_tucker:
# Tucker: w2a = (rank, out_k), dim = rank
dim = w2a.shape[0]
else:
# Non-Tucker low-rank: w2a = (out_k, rank), dim = rank
dim = w2a.shape[1]
consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
if has_tucker:
consumed_keys.append(t2_key)
elif w2_key in lora_weight_keys:
# full matrix mode
w2a = None
@@ -502,7 +635,6 @@ def merge_weights_to_tensor(
# compute scale
if w2a is not None:
# low-rank mode
if alpha is None:
alpha = dim
scale = alpha / dim
@@ -519,7 +651,13 @@ def merge_weights_to_tensor(
# compute w2
if w2a is not None:
w2 = w2a @ w2b
if has_tucker:
t2 = lora_sd[t2_key].to(calc_device)
if original_dtype.itemsize == 1:
t2 = t2.to(torch.float16)
w2 = rebuild_tucker(t2, w2a, w2b)
else:
w2 = w2a @ w2b
else:
w2 = lora_sd[w2_key].to(calc_device)
if original_dtype.itemsize == 1:
@@ -528,9 +666,10 @@ def merge_weights_to_tensor(
# ΔW = kron(w1, w2) * scale
diff_weight = make_kron(w1, w2, scale)
# handle Conv2d 1x1 weights (4D tensors)
if len(model_weight.shape) == 4:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
# Reshape diff_weight to match model_weight shape if needed
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
if diff_weight.shape != model_weight.shape:
diff_weight = diff_weight.reshape(model_weight.shape)
model_weight = model_weight + multiplier * diff_weight

View File

@@ -25,6 +25,7 @@ class ArchConfig:
te_prefixes: List[str]
default_excludes: List[str] = field(default_factory=list)
adapter_target_modules: List[str] = field(default_factory=list)
unet_conv_target_modules: List[str] = field(default_factory=list)
def detect_arch_config(unet, text_encoders) -> ArchConfig:
@@ -39,6 +40,7 @@ def detect_arch_config(unet, text_encoders) -> ArchConfig:
unet_prefix="lora_unet",
te_prefixes=["lora_te1", "lora_te2"],
default_excludes=[],
unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"],
)
# Check Anima: look for Block class in named_modules
@@ -258,6 +260,8 @@ class AdditionalNetwork(torch.nn.Module):
# Create modules for UNet/DiT
target_modules = list(arch_config.unet_target_modules)
if modules_dim is not None or conv_lora_dim is not None:
target_modules.extend(arch_config.unet_conv_target_modules)
if train_llm_adapter and arch_config.adapter_target_modules:
target_modules.extend(arch_config.adapter_target_modules)