Mirror of https://github.com/kohya-ss/sd-scripts.git
* feat: Add LoHa/LoKr network support for SDXL and Anima

  - networks/network_base.py: shared AdditionalNetwork base class with architecture
    auto-detection (SDXL/Anima) and generic module injection
  - networks/loha.py: LoHa (Low-rank Hadamard Product) module with HadaWeight custom
    autograd, training/inference classes, and factory functions
  - networks/lokr.py: LoKr (Low-rank Kronecker Product) module with factorization,
    training/inference classes, and factory functions
  - library/lora_utils.py: extend weight merge hook to detect and merge LoHa/LoKr
    weights alongside standard LoRA

  Linear and Conv2d 1x1 layers only; Conv2d 3x3 (Tucker decomposition) support will
  be added separately.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* feat: Enhance LoHa and LoKr modules with Tucker decomposition support

  - Added Tucker decomposition functionality to LoHa and LoKr modules.
  - Implemented new methods for weight rebuilding using Tucker decomposition.
  - Updated initialization and weight handling for Conv2d 3x3+ layers.
  - Modified get_diff_weight methods to accommodate Tucker and non-Tucker modes.
  - Enhanced network base to include unet_conv_target_modules for architecture detection.

* fix: rank dropout handling in LoRAModule for Conv2d and Linear layers, see #2272 for details
* doc: add dtype comment for load_safetensors_with_lora_and_fp8 function
* fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py
* doc: update model support structure to include Lumina Image 2.0, HunyuanImage-2.1, and Anima-Preview
* doc: add documentation for LoHa and LoKr fine-tuning methods
* Update networks/network_base.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update docs/loha_lokr.md (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* fix: refactor LoHa and LoKr imports for weight merging in load_safetensors_with_lora_and_fp8 function

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

# LoHa (Low-rank Hadamard Product) network module
#
# Reference: https://arxiv.org/abs/2108.06098
#
# Based on the LyCORIS project by KohakuBlueleaf
# https://github.com/KohakuBlueleaf/LyCORIS

import ast
import logging
import os
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
from library.utils import setup_logging

setup_logging()
logger = logging.getLogger(__name__)
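
# The LoHa delta is dW = (w1a @ w1b) * (w2a @ w2b) * (alpha / rank): the elementwise
# (Hadamard) product of two rank-`rank` factorizations. Since rank(A * B) can reach
# rank(A) * rank(B), such a product can express deltas of rank up to rank**2, which is
# the motivation given in the FedPara paper referenced above.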


class HadaWeight(torch.autograd.Function):
    """Efficient Hadamard product forward/backward for LoHa.

    Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with custom backward
    that recomputes intermediates instead of storing them.
    """

    @staticmethod
    def forward(ctx, w1a, w1b, w2a, w2b, scale=None):
        if scale is None:
            scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype)
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
        return diff_weight

    @staticmethod
    def backward(ctx, grad_out):
        (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        temp = grad_out * (w2a @ w2b)
        grad_w1a = temp @ w1b.T
        grad_w1b = w1a.T @ temp

        temp = grad_out * (w1a @ w1b)
        grad_w2a = temp @ w2b.T
        grad_w2b = w2a.T @ temp

        del temp
        return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
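

# A minimal sanity check for the custom backward above (illustrative only; this helper
# is not called anywhere in this file). torch.autograd.gradcheck compares the
# hand-written gradients against numerical ones on small double-precision tensors; all
# shapes below are hypothetical.
def _gradcheck_hada_weight():
    torch.manual_seed(0)
    w1a = torch.randn(8, 4, dtype=torch.double, requires_grad=True)
    w1b = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    w2a = torch.randn(8, 4, dtype=torch.double, requires_grad=True)
    w2b = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
    scale = torch.tensor(0.5, dtype=torch.double)
    # gradcheck returns True when analytic and numeric gradients agree
    assert torch.autograd.gradcheck(HadaWeight.apply, (w1a, w1b, w2a, w2b, scale))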


class HadaWeightTucker(torch.autograd.Function):
    """Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+.

    Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale
    where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa).
    Compatible with LyCORIS parameter naming convention.
    """

    @staticmethod
    def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None):
        if scale is None:
            scale = torch.tensor(1, device=t1.device, dtype=t1.dtype)
        ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Gradients for w1a, w1b, t1 (using rebuild2)
        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T)
        del grad_w, temp

        grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T)
        del grad_temp

        # Gradients for w2a, w2b, t2 (using rebuild1)
        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T)
        del grad_w, temp

        grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T)
        del grad_temp

        return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None
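

# Shape sketch for the Tucker rebuild (hypothetical sizes; this helper is not called
# anywhere): a (rank, rank, k, k) core t plus factors wb (rank, in) and wa (rank, out)
# rebuild a full (out, in, k, k) convolution kernel.
def _tucker_rebuild_shape_example():
    rank, cin, cout, k = 4, 8, 16, 3
    t = torch.randn(rank, rank, k, k)
    wb = torch.randn(rank, cin)
    wa = torch.randn(rank, cout)
    full = torch.einsum("i j ..., j r, i p -> p r ...", t, wb, wa)
    assert full.shape == (cout, cin, k, k)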


class LoHaModule(torch.nn.Module):
    """LoHa module for training. Replaces the forward method of the original Linear/Conv2d."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
        use_tucker=False,
        **kwargs,
    ):
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv2d = org_module.__class__.__name__ == "Conv2d"
        if is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            kernel_size = org_module.kernel_size
            self.is_conv = True
            self.stride = org_module.stride
            self.padding = org_module.padding
            self.dilation = org_module.dilation
            self.groups = org_module.groups
            self.kernel_size = kernel_size

            self.tucker = use_tucker and any(k != 1 for k in kernel_size)

            if kernel_size == (1, 1):
                self.conv_mode = "1x1"
            elif self.tucker:
                self.conv_mode = "tucker"
            else:
                self.conv_mode = "flat"
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.is_conv = False
            self.tucker = False
            self.conv_mode = None
            self.kernel_size = None

        self.in_dim = in_dim
        self.out_dim = out_dim

        # Create parameters based on mode
        if self.conv_mode == "tucker":
            # Tucker decomposition for Conv2d 3x3+
            # Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim)
            self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
            self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
            self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
            self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))

            # LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1)
            torch.nn.init.normal_(self.hada_t1, std=0.1)
            torch.nn.init.normal_(self.hada_t2, std=0.1)
            torch.nn.init.normal_(self.hada_w1_b, std=1.0)
            torch.nn.init.constant_(self.hada_w1_a, 0)
            torch.nn.init.normal_(self.hada_w2_b, std=1.0)
            torch.nn.init.normal_(self.hada_w2_a, std=0.1)
        elif self.conv_mode == "flat":
            # Non-Tucker Conv2d 3x3+: flatten the kernel into in_dim
            k_prod = 1
            for k in kernel_size:
                k_prod *= k
            flat_in = in_dim * k_prod

            self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in))
            self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in))

            torch.nn.init.normal_(self.hada_w1_a, std=0.1)
            torch.nn.init.normal_(self.hada_w1_b, std=1.0)
            torch.nn.init.constant_(self.hada_w2_a, 0)
            torch.nn.init.normal_(self.hada_w2_b, std=1.0)
        else:
            # Linear or Conv2d 1x1
            self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
            self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
            self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))

            torch.nn.init.normal_(self.hada_w1_a, std=0.1)
            torch.nn.init.normal_(self.hada_w1_b, std=1.0)
            torch.nn.init.constant_(self.hada_w2_a, 0)
            torch.nn.init.normal_(self.hada_w2_b, std=1.0)

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))

        self.multiplier = multiplier
        self.org_module = org_module  # deleted in apply_to()
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module
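
    # Usage sketch (hypothetical names): the network constructs one module per target
    # layer and calls apply_to() to monkey-patch it, e.g.
    #   module = LoHaModule("lora_unet_foo", layer, lora_dim=8, alpha=4)
    #   module.apply_to()
    #   # layer.forward(x) now returns the original output plus the LoHa delta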

    def get_diff_weight(self):
        """Return the materialized weight delta.

        Returns:
        - Linear: 2D tensor (out_dim, in_dim)
        - Conv2d 1x1: 2D tensor (out_dim, in_dim); caller should unsqueeze for F.conv2d
        - Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2)
        - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2)
        """
        if self.tucker:
            scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device)
            return HadaWeightTucker.apply(
                self.hada_t1, self.hada_w1_b, self.hada_w1_a,
                self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale
            )
        elif self.conv_mode == "flat":
            scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
            diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
            return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size)
        else:
            scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
            return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout: randomly skip the whole LoHa branch for this step
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        diff_weight = self.get_diff_weight()

        # rank dropout (applied on the output dimension), rescaled to keep the expectation
        if self.rank_dropout is not None and self.training:
            drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
            drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
            diff_weight = diff_weight * drop
            scale = 1.0 / (1.0 - self.rank_dropout)
        else:
            scale = 1.0

        if self.is_conv:
            if self.conv_mode == "1x1":
                diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
            # for Conv2d 3x3+ modes, diff_weight is already 4D from get_diff_weight
            return org_forwarded + F.conv2d(
                x, diff_weight, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups
            ) * self.multiplier * scale
        else:
            return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
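

# Shape/zero-init sketch for a training module (hypothetical Linear layer; this helper
# is not called anywhere): the delta matches the wrapped weight's shape and starts at
# exactly zero because hada_w2_a is zero-initialized.
def _loha_module_example():
    lin = nn.Linear(320, 640)
    mod = LoHaModule("lora_example", lin, multiplier=1.0, lora_dim=4, alpha=1)
    dw = mod.get_diff_weight()
    assert dw.shape == lin.weight.shape  # (640, 320)
    assert torch.count_nonzero(dw) == 0  # (w1a @ w1b) * (w2a @ w2b) with w2a == 0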


class LoHaInfModule(LoHaModule):
    """LoHa module for inference. Supports merge_to and get_weight."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference; pass use_tucker from kwargs
        use_tucker = kwargs.pop("use_tucker", False)
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker)

        self.org_module_ref = [org_module]
        self.enabled = True
        self.network: Optional[AdditionalNetwork] = None

    def set_network(self, network):
        self.network = network

    def merge_to(self, sd, dtype, device):
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get LoHa weights
        w1a = sd["hada_w1_a"].to(torch.float).to(device)
        w1b = sd["hada_w1_b"].to(torch.float).to(device)
        w2a = sd["hada_w2_a"].to(torch.float).to(device)
        w2b = sd["hada_w2_b"].to(torch.float).to(device)

        if self.tucker:
            # Tucker mode
            t1 = sd["hada_t1"].to(torch.float).to(device)
            t2 = sd["hada_t2"].to(torch.float).to(device)
            rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
            rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
            diff_weight = rebuild1 * rebuild2 * self.scale
        else:
            diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale

        # reshape diff_weight to match original weight shape if needed
        if diff_weight.shape != weight.shape:
            diff_weight = diff_weight.reshape(weight.shape)

        weight = weight.to(device) + self.multiplier * diff_weight

        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)
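
    # Usage sketch (hypothetical): `sd` holds this module's tensors keyed without the
    # lora_name prefix, e.g. {"hada_w1_a": ..., "hada_w1_b": ..., "hada_w2_a": ...,
    # "hada_w2_b": ...}. merge_to folds the delta into the wrapped layer in fp32 and
    # casts the result back to `dtype`:
    #   module.merge_to(sd, dtype=torch.bfloat16, device=torch.device("cuda"))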

    def get_weight(self, multiplier=None):
        if multiplier is None:
            multiplier = self.multiplier

        if self.tucker:
            t1 = self.hada_t1.to(torch.float)
            w1a = self.hada_w1_a.to(torch.float)
            w1b = self.hada_w1_b.to(torch.float)
            t2 = self.hada_t2.to(torch.float)
            w2a = self.hada_w2_a.to(torch.float)
            w2b = self.hada_w2_b.to(torch.float)
            rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
            rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
            weight = rebuild1 * rebuild2 * self.scale * multiplier
        else:
            w1a = self.hada_w1_a.to(torch.float)
            w1b = self.hada_w1_b.to(torch.float)
            w2a = self.hada_w2_a.to(torch.float)
            w2b = self.hada_w2_b.to(torch.float)
            weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier

        if self.is_conv:
            if self.conv_mode == "1x1":
                weight = weight.unsqueeze(2).unsqueeze(3)
            elif self.conv_mode == "flat":
                weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)

        return weight

    def default_forward(self, x):
        diff_weight = self.get_diff_weight()
        if self.is_conv:
            if self.conv_mode == "1x1":
                diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
            return self.org_forward(x) + F.conv2d(
                x, diff_weight, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups
            ) * self.multiplier
        else:
            return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier

    def forward(self, x):
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)
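

# Inference-module sketch (hypothetical Conv2d 1x1; this helper is not called anywhere):
# get_weight() returns the delta already shaped like the wrapped layer's weight.
def _inf_module_example():
    conv = nn.Conv2d(320, 320, kernel_size=1)
    mod = LoHaInfModule("lora_example", conv, multiplier=1.0, lora_dim=4, alpha=4)
    w = mod.get_weight()
    assert w.shape == conv.weight.shape  # (320, 320, 1, 1)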


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae,
    text_encoder,
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoHa network. Called by train_network.py via network_module.create_network()."""
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    # handle text_encoder as a list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # train LLM adapter
    train_llm_adapter = kwargs.get("train_llm_adapter", "false")
    if train_llm_adapter is not None:
        train_llm_adapter = str(train_llm_adapter).lower() == "true"

    # exclude patterns
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)
        if not isinstance(exclude_patterns, list):
            exclude_patterns = [exclude_patterns]

    # add default exclude patterns from the arch config
    exclude_patterns.extend(arch_config.default_excludes)

    # include patterns
    include_patterns = kwargs.get("include_patterns", None)
    if include_patterns is not None:
        include_patterns = ast.literal_eval(include_patterns)
        if not isinstance(include_patterns, list):
            include_patterns = [include_patterns]

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # conv dim/alpha for Conv2d 3x3
    conv_lora_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_lora_dim is not None:
        conv_lora_dim = int(conv_lora_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    # Tucker decomposition for Conv2d 3x3
    use_tucker = kwargs.get("use_tucker", "false")
    if use_tucker is not None:
        use_tucker = str(use_tucker).lower() == "true"

    # verbose
    verbose = kwargs.get("verbose", "false")
    if verbose is not None:
        verbose = str(verbose).lower() == "true"

    # regex-specific learning rates / dimensions
    network_reg_lrs = kwargs.get("network_reg_lrs", None)
    reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None

    network_reg_dims = kwargs.get("network_reg_dims", None)
    reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        module_class=LoHaModule,
        module_kwargs={"use_tucker": use_tucker},
        conv_lora_dim=conv_lora_dim,
        conv_alpha=conv_alpha,
        train_llm_adapter=train_llm_adapter,
        exclude_patterns=exclude_patterns,
        include_patterns=include_patterns,
        reg_dims=reg_dims,
        reg_lrs=reg_lrs,
        verbose=verbose,
    )

    # LoRA+ support
    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network
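

# Invocation sketch (assumed sd-scripts conventions; exact flags may differ): the kwargs
# above arrive as key=value strings from --network_args, which is why booleans and
# numbers are parsed from strings here, e.g.
#   --network_module=networks.loha --network_dim=8 --network_alpha=4 \
#   --network_args "use_tucker=True" "conv_dim=4" "conv_alpha=1" "rank_dropout=0.1"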


def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """Create a LoHa network from saved weights. Called by train_network.py."""
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # detect dim/alpha from the weights
    modules_dim = {}
    modules_alpha = {}
    train_llm_adapter = False
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "hada_w1_b" in key:
            dim = value.shape[0]
            modules_dim[lora_name] = dim

        if "llm_adapter" in lora_name:
            train_llm_adapter = True

    # detect Tucker mode from the weights
    use_tucker = any("hada_t1" in key for key in weights_sd.keys())

    # handle text_encoder as a list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    module_class = LoHaInfModule if for_inference else LoHaModule
    module_kwargs = {"use_tucker": use_tucker}

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        module_kwargs=module_kwargs,
        train_llm_adapter=train_llm_adapter,
    )
    return network, weights_sd
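

# Detection sketch on a made-up state dict (this helper is not called anywhere): the
# rank of each module is read from hada_w1_b's first dimension, mirroring the loop above.
def _detect_dims_example():
    sd = {
        "lora_unet_foo.hada_w1_a": torch.empty(640, 8),
        "lora_unet_foo.hada_w1_b": torch.empty(8, 320),
        "lora_unet_foo.alpha": torch.tensor(4.0),
    }
    dims = {k.split(".")[0]: v.shape[0] for k, v in sd.items() if "hada_w1_b" in k}
    assert dims == {"lora_unet_foo": 8}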


def merge_weights_to_tensor(
    model_weight: torch.Tensor,
    lora_name: str,
    lora_sd: Dict[str, torch.Tensor],
    lora_weight_keys: set,
    multiplier: float,
    calc_device: torch.device,
) -> torch.Tensor:
    """Merge LoHa weights directly into a model weight tensor.

    Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
    No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
    Returns model_weight unchanged if no matching LoHa keys found.
    """
    w1a_key = lora_name + ".hada_w1_a"
    w1b_key = lora_name + ".hada_w1_b"
    w2a_key = lora_name + ".hada_w2_a"
    w2b_key = lora_name + ".hada_w2_b"
    t1_key = lora_name + ".hada_t1"
    t2_key = lora_name + ".hada_t2"
    alpha_key = lora_name + ".alpha"

    if w1a_key not in lora_weight_keys:
        return model_weight

    w1a = lora_sd[w1a_key].to(calc_device)
    w1b = lora_sd[w1b_key].to(calc_device)
    w2a = lora_sd[w2a_key].to(calc_device)
    w2b = lora_sd[w2b_key].to(calc_device)

    has_tucker = t1_key in lora_weight_keys

    dim = w1b.shape[0]
    alpha = lora_sd.get(alpha_key, torch.tensor(dim))
    if isinstance(alpha, torch.Tensor):
        alpha = alpha.item()
    scale = alpha / dim

    original_dtype = model_weight.dtype
    if original_dtype.itemsize == 1:  # fp8: compute the merge in fp16
        model_weight = model_weight.to(torch.float16)
        w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16)
        w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)

    if has_tucker:
        # Tucker decomposition: rebuild via einsum
        t1 = lora_sd[t1_key].to(calc_device)
        t2 = lora_sd[t2_key].to(calc_device)
        if original_dtype.itemsize == 1:
            t1, t2 = t1.to(torch.float16), t2.to(torch.float16)
        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
        diff_weight = rebuild1 * rebuild2 * scale
    else:
        # Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
        diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale

    # Reshape diff_weight to match model_weight's shape if needed
    # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
    if diff_weight.shape != model_weight.shape:
        diff_weight = diff_weight.reshape(model_weight.shape)

    model_weight = model_weight + multiplier * diff_weight

    if original_dtype.itemsize == 1:
        model_weight = model_weight.to(original_dtype)

    # remove consumed keys
    consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]
    if has_tucker:
        consumed.extend([t1_key, t2_key])
    for key in consumed:
        lora_weight_keys.discard(key)

    return model_weight
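

# Merge sketch on a hypothetical Linear delta (this helper is not called anywhere): all
# five non-Tucker keys are consumed from the key set after a successful merge.
def _merge_example():
    lora_sd = {
        "lora_unet_foo.hada_w1_a": torch.randn(64, 4),
        "lora_unet_foo.hada_w1_b": torch.randn(4, 32),
        "lora_unet_foo.hada_w2_a": torch.randn(64, 4),
        "lora_unet_foo.hada_w2_b": torch.randn(4, 32),
        "lora_unet_foo.alpha": torch.tensor(4.0),
    }
    keys = set(lora_sd.keys())
    weight = torch.zeros(64, 32)
    merged = merge_weights_to_tensor(weight, "lora_unet_foo", lora_sd, keys, 1.0, torch.device("cpu"))
    assert merged.shape == weight.shape
    assert not keys  # w1a/w1b/w2a/w2b/alpha were all discarded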