diff --git a/networks/loha.py b/networks/loha.py index 355e3734..8734f9c5 100644 --- a/networks/loha.py +++ b/networks/loha.py @@ -51,6 +51,63 @@ class HadaWeight(torch.autograd.Function): return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None +class HadaWeightTucker(torch.autograd.Function): + """Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+. + + Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale + where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa). + Compatible with LyCORIS parameter naming convention. + """ + + @staticmethod + def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None): + if scale is None: + scale = torch.tensor(1, device=t1.device, dtype=t1.dtype) + ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + # Gradients for w1a, w1b, t1 (using rebuild2) + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T) + del grad_w, temp + + grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T) + del grad_temp + + # Gradients for w2a, w2b, t2 (using rebuild1) + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T) + del grad_w, temp + + grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T) + del grad_temp + + return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None + + class LoHaModule(torch.nn.Module): """LoHa module for training. Replaces forward method of the original Linear/Conv2d.""" @@ -64,6 +121,7 @@ class LoHaModule(torch.nn.Module): dropout=None, rank_dropout=None, module_dropout=None, + use_tucker=False, **kwargs, ): super().__init__() @@ -74,26 +132,78 @@ class LoHaModule(torch.nn.Module): if is_conv2d: in_dim = org_module.in_channels out_dim = org_module.out_channels + kernel_size = org_module.kernel_size self.is_conv = True - if org_module.kernel_size != (1, 1): - raise ValueError("LoHa Conv2d 3x3 (Tucker decomposition) is not supported yet") + self.stride = org_module.stride + self.padding = org_module.padding + self.dilation = org_module.dilation + self.groups = org_module.groups + self.kernel_size = kernel_size + + self.tucker = use_tucker and any(k != 1 for k in kernel_size) + + if kernel_size == (1, 1): + self.conv_mode = "1x1" + elif self.tucker: + self.conv_mode = "tucker" + else: + self.conv_mode = "flat" else: in_dim = org_module.in_features out_dim = org_module.out_features self.is_conv = False + self.tucker = False + self.conv_mode = None + self.kernel_size = None - # Hadamard product parameters: ΔW = (w1a @ w1b) * (w2a @ w2b) - self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim)) - self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim)) - self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim)) - self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + self.in_dim = in_dim + self.out_dim = out_dim - # Initialization: w1_a normal(0.1), w1_b normal(1.0), w2_a = 0, w2_b normal(1.0) - # Ensures ΔW = 0 at init since w2_a = 0 - torch.nn.init.normal_(self.hada_w1_a, std=0.1) - torch.nn.init.normal_(self.hada_w1_b, std=1.0) - torch.nn.init.constant_(self.hada_w2_a, 0) - torch.nn.init.normal_(self.hada_w2_b, std=1.0) + # Create parameters based on mode + if self.conv_mode == "tucker": + # Tucker decomposition for Conv2d 3x3+ + # Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim) + self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size)) + self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size)) + self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + + # LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1) + torch.nn.init.normal_(self.hada_t1, std=0.1) + torch.nn.init.normal_(self.hada_t2, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w1_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + torch.nn.init.normal_(self.hada_w2_a, std=0.1) + elif self.conv_mode == "flat": + # Non-Tucker Conv2d 3x3+: flatten kernel into in_dim + k_prod = 1 + for k in kernel_size: + k_prod *= k + flat_in = in_dim * k_prod + + self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in)) + self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in)) + + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w2_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + else: + # Linear or Conv2d 1x1 + self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w2_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() @@ -113,9 +223,27 @@ class LoHaModule(torch.nn.Module): del self.org_module def get_diff_weight(self): - """Return materialized weight delta as a 2D matrix.""" - scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device) - return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + """Return materialized weight delta. + + Returns: + - Linear: 2D tensor (out_dim, in_dim) + - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d + - Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2) + - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) + """ + if self.tucker: + scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device) + return HadaWeightTucker.apply( + self.hada_t1, self.hada_w1_b, self.hada_w1_a, + self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale + ) + elif self.conv_mode == "flat": + scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device) + diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size) + else: + scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device) + return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) def forward(self, x): org_forwarded = self.org_forward(x) @@ -130,16 +258,25 @@ class LoHaModule(torch.nn.Module): # rank dropout (applied on output dimension) if self.rank_dropout is not None and self.training: drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype) - drop = drop.view(-1, 1) + drop = drop.view(-1, *([1] * (diff_weight.dim() - 1))) diff_weight = diff_weight * drop scale = 1.0 / (1.0 - self.rank_dropout) else: scale = 1.0 if self.is_conv: - # Conv2d 1x1: reshape to 4D for conv operation - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) - return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale else: return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale @@ -164,8 +301,9 @@ class LoHaInfModule(LoHaModule): alpha=1, **kwargs, ): - # no dropout for inference - super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + # no dropout for inference; pass use_tucker from kwargs + use_tucker = kwargs.pop("use_tucker", False) + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker) self.org_module_ref = [org_module] self.enabled = True @@ -193,11 +331,18 @@ class LoHaInfModule(LoHaModule): w2a = sd["hada_w2_a"].to(torch.float).to(device) w2b = sd["hada_w2_b"].to(torch.float).to(device) - # compute ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale - diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale - - if self.is_conv: - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + if self.tucker: + # Tucker mode + t1 = sd["hada_t1"].to(torch.float).to(device) + t2 = sd["hada_t2"].to(torch.float).to(device) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + diff_weight = rebuild1 * rebuild2 * self.scale + else: + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale + # reshape diff_weight to match original weight shape if needed + if diff_weight.shape != weight.shape: + diff_weight = diff_weight.reshape(weight.shape) weight = weight.to(device) + self.multiplier * diff_weight @@ -208,23 +353,40 @@ class LoHaInfModule(LoHaModule): if multiplier is None: multiplier = self.multiplier - w1a = self.hada_w1_a.to(torch.float) - w1b = self.hada_w1_b.to(torch.float) - w2a = self.hada_w2_a.to(torch.float) - w2b = self.hada_w2_b.to(torch.float) + if self.tucker: + t1 = self.hada_t1.to(torch.float) + w1a = self.hada_w1_a.to(torch.float) + w1b = self.hada_w1_b.to(torch.float) + t2 = self.hada_t2.to(torch.float) + w2a = self.hada_w2_a.to(torch.float) + w2b = self.hada_w2_b.to(torch.float) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + weight = rebuild1 * rebuild2 * self.scale * multiplier + else: + w1a = self.hada_w1_a.to(torch.float) + w1b = self.hada_w1_b.to(torch.float) + w2a = self.hada_w2_a.to(torch.float) + w2b = self.hada_w2_b.to(torch.float) + weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier - weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier - - if self.is_conv: - weight = weight.unsqueeze(2).unsqueeze(3) + if self.is_conv: + if self.conv_mode == "1x1": + weight = weight.unsqueeze(2).unsqueeze(3) + elif self.conv_mode == "flat": + weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size) return weight def default_forward(self, x): diff_weight = self.get_diff_weight() if self.is_conv: - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) - return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return self.org_forward(x) + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier else: return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier @@ -288,7 +450,7 @@ def create_network( if module_dropout is not None: module_dropout = float(module_dropout) - # conv dim/alpha (for future Conv2d 3x3 support) + # conv dim/alpha for Conv2d 3x3 conv_lora_dim = kwargs.get("conv_dim", None) conv_alpha = kwargs.get("conv_alpha", None) if conv_lora_dim is not None: @@ -298,6 +460,11 @@ def create_network( else: conv_alpha = float(conv_alpha) + # Tucker decomposition for Conv2d 3x3 + use_tucker = kwargs.get("use_tucker", "false") + if use_tucker is not None: + use_tucker = True if str(use_tucker).lower() == "true" else False + # verbose verbose = kwargs.get("verbose", "false") if verbose is not None: @@ -321,6 +488,7 @@ def create_network( rank_dropout=rank_dropout, module_dropout=module_dropout, module_class=LoHaModule, + module_kwargs={"use_tucker": use_tucker}, conv_lora_dim=conv_lora_dim, conv_alpha=conv_alpha, train_llm_adapter=train_llm_adapter, @@ -372,6 +540,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh if "llm_adapter" in lora_name: train_llm_adapter = True + # detect Tucker mode from weights + use_tucker = any("hada_t1" in key for key in weights_sd.keys()) + # handle text_encoder as list text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] @@ -379,6 +550,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh arch_config = detect_arch_config(unet, text_encoders) module_class = LoHaInfModule if for_inference else LoHaModule + module_kwargs = {"use_tucker": use_tucker} network = AdditionalNetwork( text_encoders, @@ -388,6 +560,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class, + module_kwargs=module_kwargs, train_llm_adapter=train_llm_adapter, ) return network, weights_sd @@ -403,6 +576,7 @@ def merge_weights_to_tensor( ) -> torch.Tensor: """Merge LoHa weights directly into a model weight tensor. + Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3. No Module/Network creation needed. Consumed keys are removed from lora_weight_keys. Returns model_weight unchanged if no matching LoHa keys found. """ @@ -410,6 +584,8 @@ def merge_weights_to_tensor( w1b_key = lora_name + ".hada_w1_b" w2a_key = lora_name + ".hada_w2_a" w2b_key = lora_name + ".hada_w2_b" + t1_key = lora_name + ".hada_t1" + t2_key = lora_name + ".hada_t2" alpha_key = lora_name + ".alpha" if w1a_key not in lora_weight_keys: @@ -420,6 +596,8 @@ def merge_weights_to_tensor( w2a = lora_sd[w2a_key].to(calc_device) w2b = lora_sd[w2b_key].to(calc_device) + has_tucker = t1_key in lora_weight_keys + dim = w1b.shape[0] alpha = lora_sd.get(alpha_key, torch.tensor(dim)) if isinstance(alpha, torch.Tensor): @@ -429,14 +607,26 @@ def merge_weights_to_tensor( original_dtype = model_weight.dtype if original_dtype.itemsize == 1: # fp8 model_weight = model_weight.to(torch.float16) - w1a, w1b, w2a, w2b = w1a.to(torch.float16), w1b.to(torch.float16), w2a.to(torch.float16), w2b.to(torch.float16) + w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16) + w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16) - # ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale - diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale + if has_tucker: + # Tucker decomposition: rebuild via einsum + t1 = lora_sd[t1_key].to(calc_device) + t2 = lora_sd[t2_key].to(calc_device) + if original_dtype.itemsize == 1: + t1, t2 = t1.to(torch.float16), t2.to(torch.float16) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + diff_weight = rebuild1 * rebuild2 * scale + else: + # Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale - # handle Conv2d 1x1 weights (4D tensors) - if len(model_weight.shape) == 4: - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + # Reshape diff_weight to match model_weight shape if needed + # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.) + if diff_weight.shape != model_weight.shape: + diff_weight = diff_weight.reshape(model_weight.shape) model_weight = model_weight + multiplier * diff_weight @@ -444,7 +634,10 @@ def merge_weights_to_tensor( model_weight = model_weight.to(original_dtype) # remove consumed keys - for key in [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]: + consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key] + if has_tucker: + consumed.extend([t1_key, t2_key]) + for key in consumed: lora_weight_keys.discard(key) return model_weight diff --git a/networks/lokr.py b/networks/lokr.py index 8f2457a5..03b50ca0 100644 --- a/networks/lokr.py +++ b/networks/lokr.py @@ -68,6 +68,14 @@ def make_kron(w1, w2, scale): return rebuild +def rebuild_tucker(t, wa, wb): + """Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb). + + Compatible with LyCORIS convention. + """ + return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb) + + class LoKrModule(torch.nn.Module): """LoKr module for training. Replaces forward method of the original Linear/Conv2d.""" @@ -82,6 +90,7 @@ class LoKrModule(torch.nn.Module): rank_dropout=None, module_dropout=None, factor=-1, + use_tucker=False, **kwargs, ): super().__init__() @@ -92,13 +101,32 @@ class LoKrModule(torch.nn.Module): if is_conv2d: in_dim = org_module.in_channels out_dim = org_module.out_channels + kernel_size = org_module.kernel_size self.is_conv = True - if org_module.kernel_size != (1, 1): - raise ValueError("LoKr Conv2d 3x3 (Tucker decomposition) is not supported yet") + self.stride = org_module.stride + self.padding = org_module.padding + self.dilation = org_module.dilation + self.groups = org_module.groups + self.kernel_size = kernel_size + + self.tucker = use_tucker and any(k != 1 for k in kernel_size) + + if kernel_size == (1, 1): + self.conv_mode = "1x1" + elif self.tucker: + self.conv_mode = "tucker" + else: + self.conv_mode = "flat" else: in_dim = org_module.in_features out_dim = org_module.out_features self.is_conv = False + self.tucker = False + self.conv_mode = None + self.kernel_size = None + + self.in_dim = in_dim + self.out_dim = out_dim factor = int(factor) self.use_w2 = False @@ -110,18 +138,44 @@ class LoKrModule(torch.nn.Module): # w1 is always a full matrix (the "scale" factor, small) self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m)) - # w2: low-rank decomposition if rank is small enough, otherwise full matrix - if lora_dim < max(out_k, in_n) / 2: - self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim)) - self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n)) - else: - self.use_w2 = True - self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n)) + # w2: depends on mode + if self.conv_mode in ("tucker", "flat"): + # Conv2d 3x3+ modes + k_size = kernel_size + if lora_dim >= max(out_k, in_n) / 2: + # Full matrix mode (includes kernel dimensions) + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size)) logger.warning( f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} " - f"and factor={factor}, using full matrix mode." + f"and factor={factor}, using full matrix mode for Conv2d." ) + elif self.tucker: + # Tucker mode: separate kernel into t2 tensor + self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size)) + self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n)) + else: + # Non-Tucker: flatten kernel into w2_b + k_prod = 1 + for k in k_size: + k_prod *= k + self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod)) + else: + # Linear or Conv2d 1x1 + if lora_dim < max(out_k, in_n) / 2: + self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n)) + else: + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n)) + if lora_dim >= max(out_k, in_n) / 2: + logger.warning( + f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} " + f"and factor={factor}, using full matrix mode." + ) if type(alpha) == torch.Tensor: alpha = alpha.detach().float().numpy() @@ -137,6 +191,8 @@ class LoKrModule(torch.nn.Module): if self.use_w2: torch.nn.init.constant_(self.lokr_w2, 0) else: + if self.tucker: + torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5)) torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5)) torch.nn.init.constant_(self.lokr_w2_b, 0) # Ensures ΔW = kron(w1, 0) = 0 at init @@ -153,13 +209,30 @@ class LoKrModule(torch.nn.Module): del self.org_module def get_diff_weight(self): - """Return materialized weight delta.""" + """Return materialized weight delta. + + Returns: + - Linear: 2D tensor (out_dim, in_dim) + - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d + - Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2) + - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D + """ w1 = self.lokr_w1 + if self.use_w2: w2 = self.lokr_w2 + elif self.tucker: + w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) else: w2 = self.lokr_w2_a @ self.lokr_w2_b - return make_kron(w1, w2, self.scale) + + result = make_kron(w1, w2, self.scale) + + # For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D + if self.conv_mode == "flat" and result.dim() == 2: + result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size) + + return result def forward(self, x): org_forwarded = self.org_forward(x) @@ -174,15 +247,25 @@ class LoKrModule(torch.nn.Module): # rank dropout if self.rank_dropout is not None and self.training: drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype) - drop = drop.view(-1, 1) + drop = drop.view(-1, *([1] * (diff_weight.dim() - 1))) diff_weight = diff_weight * drop scale = 1.0 / (1.0 - self.rank_dropout) else: scale = 1.0 if self.is_conv: - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) - return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale else: return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale @@ -207,9 +290,10 @@ class LoKrInfModule(LoKrModule): alpha=1, **kwargs, ): - # no dropout for inference; pass factor from kwargs if present + # no dropout for inference; pass factor and use_tucker from kwargs factor = kwargs.pop("factor", -1) - super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor) + use_tucker = kwargs.pop("use_tucker", False) + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker) self.org_module_ref = [org_module] self.enabled = True @@ -236,6 +320,12 @@ class LoKrInfModule(LoKrModule): if "lokr_w2" in sd: w2 = sd["lokr_w2"].to(torch.float).to(device) + elif "lokr_t2" in sd: + # Tucker mode + t2 = sd["lokr_t2"].to(torch.float).to(device) + w2a = sd["lokr_w2_a"].to(torch.float).to(device) + w2b = sd["lokr_w2_b"].to(torch.float).to(device) + w2 = rebuild_tucker(t2, w2a, w2b) else: w2a = sd["lokr_w2_a"].to(torch.float).to(device) w2b = sd["lokr_w2_b"].to(torch.float).to(device) @@ -244,8 +334,9 @@ class LoKrInfModule(LoKrModule): # compute ΔW via Kronecker product diff_weight = make_kron(w1, w2, self.scale) - if self.is_conv: - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + # reshape diff_weight to match original weight shape if needed + if diff_weight.shape != weight.shape: + diff_weight = diff_weight.reshape(weight.shape) weight = weight.to(device) + self.multiplier * diff_weight @@ -257,23 +348,39 @@ class LoKrInfModule(LoKrModule): multiplier = self.multiplier w1 = self.lokr_w1.to(torch.float) + if self.use_w2: w2 = self.lokr_w2.to(torch.float) + elif self.tucker: + w2 = rebuild_tucker( + self.lokr_t2.to(torch.float), + self.lokr_w2_a.to(torch.float), + self.lokr_w2_b.to(torch.float), + ) else: w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float) weight = make_kron(w1, w2, self.scale) * multiplier + # reshape to match original weight shape if needed if self.is_conv: - weight = weight.unsqueeze(2).unsqueeze(3) + if self.conv_mode == "1x1": + weight = weight.unsqueeze(2).unsqueeze(3) + elif self.conv_mode == "flat" and weight.dim() == 2: + weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size) + # Tucker and full matrix modes: already 4D from kron return weight def default_forward(self, x): diff_weight = self.get_diff_weight() if self.is_conv: - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) - return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return self.org_forward(x) + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier else: return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier @@ -337,7 +444,7 @@ def create_network( if module_dropout is not None: module_dropout = float(module_dropout) - # conv dim/alpha (for future Conv2d 3x3 support) + # conv dim/alpha for Conv2d 3x3 conv_lora_dim = kwargs.get("conv_dim", None) conv_alpha = kwargs.get("conv_alpha", None) if conv_lora_dim is not None: @@ -347,6 +454,11 @@ def create_network( else: conv_alpha = float(conv_alpha) + # Tucker decomposition for Conv2d 3x3 + use_tucker = kwargs.get("use_tucker", "false") + if use_tucker is not None: + use_tucker = True if str(use_tucker).lower() == "true" else False + # factor for LoKr factor = int(kwargs.get("factor", -1)) @@ -373,7 +485,7 @@ def create_network( rank_dropout=rank_dropout, module_dropout=module_dropout, module_class=LoKrModule, - module_kwargs={"factor": factor}, + module_kwargs={"factor": factor, "use_tucker": use_tucker}, conv_lora_dim=conv_lora_dim, conv_alpha=conv_alpha, train_llm_adapter=train_llm_adapter, @@ -411,6 +523,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh modules_dim = {} modules_alpha = {} train_llm_adapter = False + use_tucker = False for key, value in weights_sd.items(): if "." not in key: continue @@ -419,13 +532,21 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh if "alpha" in key: modules_alpha[lora_name] = value elif "lokr_w2_a" in key: - # low-rank mode: dim = w2_a.shape[1] - dim = value.shape[1] + # low-rank mode: dim detection depends on Tucker vs non-Tucker + if "lokr_t2" in key.replace("lokr_w2_a", "lokr_t2") and lora_name + ".lokr_t2" in weights_sd: + # Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0] + dim = value.shape[0] + else: + # Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1] + dim = value.shape[1] modules_dim[lora_name] = dim elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key: # full matrix mode: set dim large enough to trigger full-matrix path if lora_name not in modules_dim: - modules_dim[lora_name] = max(value.shape) + modules_dim[lora_name] = max(value.shape[0], value.shape[1]) + + if "lokr_t2" in key: + use_tucker = True if "llm_adapter" in lora_name: train_llm_adapter = True @@ -440,7 +561,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh factor = int(kwargs.get("factor", -1)) module_class = LoKrInfModule if for_inference else LoKrModule - module_kwargs = {"factor": factor} + module_kwargs = {"factor": factor, "use_tucker": use_tucker} network = AdditionalNetwork( text_encoders, @@ -466,6 +587,7 @@ def merge_weights_to_tensor( ) -> torch.Tensor: """Merge LoKr weights directly into a model weight tensor. + Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3. No Module/Network creation needed. Consumed keys are removed from lora_weight_keys. Returns model_weight unchanged if no matching LoKr keys found. """ @@ -473,6 +595,7 @@ def merge_weights_to_tensor( w2_key = lora_name + ".lokr_w2" w2a_key = lora_name + ".lokr_w2_a" w2b_key = lora_name + ".lokr_w2_b" + t2_key = lora_name + ".lokr_t2" alpha_key = lora_name + ".alpha" if w1_key not in lora_weight_keys: @@ -480,13 +603,23 @@ def merge_weights_to_tensor( w1 = lora_sd[w1_key].to(calc_device) - # determine low-rank vs full matrix mode + # determine mode: full matrix vs Tucker vs low-rank + has_tucker = t2_key in lora_weight_keys + if w2a_key in lora_weight_keys: - # low-rank: w2 = w2_a @ w2_b w2a = lora_sd[w2a_key].to(calc_device) w2b = lora_sd[w2b_key].to(calc_device) - dim = w2a.shape[1] + + if has_tucker: + # Tucker: w2a = (rank, out_k), dim = rank + dim = w2a.shape[0] + else: + # Non-Tucker low-rank: w2a = (out_k, rank), dim = rank + dim = w2a.shape[1] + consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key] + if has_tucker: + consumed_keys.append(t2_key) elif w2_key in lora_weight_keys: # full matrix mode w2a = None @@ -502,7 +635,6 @@ def merge_weights_to_tensor( # compute scale if w2a is not None: - # low-rank mode if alpha is None: alpha = dim scale = alpha / dim @@ -519,7 +651,13 @@ def merge_weights_to_tensor( # compute w2 if w2a is not None: - w2 = w2a @ w2b + if has_tucker: + t2 = lora_sd[t2_key].to(calc_device) + if original_dtype.itemsize == 1: + t2 = t2.to(torch.float16) + w2 = rebuild_tucker(t2, w2a, w2b) + else: + w2 = w2a @ w2b else: w2 = lora_sd[w2_key].to(calc_device) if original_dtype.itemsize == 1: @@ -528,9 +666,10 @@ def merge_weights_to_tensor( # ΔW = kron(w1, w2) * scale diff_weight = make_kron(w1, w2, scale) - # handle Conv2d 1x1 weights (4D tensors) - if len(model_weight.shape) == 4: - diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + # Reshape diff_weight to match model_weight shape if needed + # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.) + if diff_weight.shape != model_weight.shape: + diff_weight = diff_weight.reshape(model_weight.shape) model_weight = model_weight + multiplier * diff_weight diff --git a/networks/network_base.py b/networks/network_base.py index 6a05c203..ab8d25ab 100644 --- a/networks/network_base.py +++ b/networks/network_base.py @@ -25,6 +25,7 @@ class ArchConfig: te_prefixes: List[str] default_excludes: List[str] = field(default_factory=list) adapter_target_modules: List[str] = field(default_factory=list) + unet_conv_target_modules: List[str] = field(default_factory=list) def detect_arch_config(unet, text_encoders) -> ArchConfig: @@ -39,6 +40,7 @@ def detect_arch_config(unet, text_encoders) -> ArchConfig: unet_prefix="lora_unet", te_prefixes=["lora_te1", "lora_te2"], default_excludes=[], + unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"], ) # Check Anima: look for Block class in named_modules @@ -258,6 +260,8 @@ class AdditionalNetwork(torch.nn.Module): # Create modules for UNet/DiT target_modules = list(arch_config.unet_target_modules) + if modules_dim is not None or conv_lora_dim is not None: + target_modules.extend(arch_config.unet_conv_target_modules) if train_llm_adapter and arch_config.adapter_target_modules: target_modules.extend(arch_config.adapter_target_modules)