Mirror of https://github.com/kohya-ss/sd-scripts.git
feat: Enhance LoHa and LoKr modules with Tucker decomposition support
- Added Tucker decomposition functionality to LoHa and LoKr modules.
- Implemented new methods for weight rebuilding using Tucker decomposition.
- Updated initialization and weight handling for Conv2d 3x3+ layers.
- Modified get_diff_weight methods to accommodate Tucker and non-Tucker modes.
- Enhanced network base to include unet_conv_target_modules for architecture detection.
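The core operation added by this commit is the Tucker rebuild used in HadaWeightTucker and rebuild_tucker below. As a quick orientation, here is a minimal, illustrative sketch of that rebuild for the LoHa case. It is not code from the diff; the shapes follow the LyCORIS convention described in the docstrings, and all sizes and values are invented.

import torch

# Example sizes (illustrative only)
rank, in_dim, out_dim, kh, kw = 4, 8, 16, 3, 3

# Tucker factors in the LyCORIS layout assumed by the diff:
#   t:  (rank, rank, kh, kw)  core tensor
#   wb: (rank, in_dim)        "b" factor
#   wa: (rank, out_dim)       "a" factor
t1, w1b, w1a = torch.randn(rank, rank, kh, kw), torch.randn(rank, in_dim), torch.randn(rank, out_dim)
t2, w2b, w2a = torch.randn(rank, rank, kh, kw), torch.randn(rank, in_dim), torch.randn(rank, out_dim)
scale = 1.0

# rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa) -> (out_dim, in_dim, kh, kw)
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)

# LoHa delta weight: elementwise (Hadamard) product of the two rebuilt tensors
delta_w = rebuild1 * rebuild2 * scale
print(delta_w.shape)  # torch.Size([16, 8, 3, 3])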
networks/loha.py (283 lines changed)
@@ -51,6 +51,63 @@ class HadaWeight(torch.autograd.Function):
        return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None


class HadaWeightTucker(torch.autograd.Function):
    """Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+.

    Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale
    where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa).
    Compatible with LyCORIS parameter naming convention.
    """

    @staticmethod
    def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None):
        if scale is None:
            scale = torch.tensor(1, device=t1.device, dtype=t1.dtype)
        ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale)

        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)

        return rebuild1 * rebuild2 * scale

    @staticmethod
    def backward(ctx, grad_out):
        (t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors
        grad_out = grad_out * scale

        # Gradients for w1a, w1b, t1 (using rebuild2)
        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T)
        del grad_w, temp

        grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T)
        del grad_temp

        # Gradients for w2a, w2b, t2 (using rebuild1)
        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b)
        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a)

        grad_w = rebuild * grad_out
        del rebuild

        grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T)
        del grad_w, temp

        grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T)
        del grad_temp

        return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None

class LoHaModule(torch.nn.Module):
|
||||
"""LoHa module for training. Replaces forward method of the original Linear/Conv2d."""
|
||||
|
||||
@@ -64,6 +121,7 @@ class LoHaModule(torch.nn.Module):
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
use_tucker=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -74,26 +132,78 @@ class LoHaModule(torch.nn.Module):
|
||||
if is_conv2d:
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
kernel_size = org_module.kernel_size
|
||||
self.is_conv = True
|
||||
if org_module.kernel_size != (1, 1):
|
||||
raise ValueError("LoHa Conv2d 3x3 (Tucker decomposition) is not supported yet")
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.dilation = org_module.dilation
|
||||
self.groups = org_module.groups
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
|
||||
|
||||
if kernel_size == (1, 1):
|
||||
self.conv_mode = "1x1"
|
||||
elif self.tucker:
|
||||
self.conv_mode = "tucker"
|
||||
else:
|
||||
self.conv_mode = "flat"
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.is_conv = False
|
||||
self.tucker = False
|
||||
self.conv_mode = None
|
||||
self.kernel_size = None
|
||||
|
||||
# Hadamard product parameters: ΔW = (w1a @ w1b) * (w2a @ w2b)
|
||||
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# Initialization: w1_a normal(0.1), w1_b normal(1.0), w2_a = 0, w2_b normal(1.0)
|
||||
# Ensures ΔW = 0 at init since w2_a = 0
|
||||
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
|
||||
torch.nn.init.constant_(self.hada_w2_a, 0)
|
||||
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
|
||||
# Create parameters based on mode
|
||||
if self.conv_mode == "tucker":
|
||||
# Tucker decomposition for Conv2d 3x3+
|
||||
# Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim)
|
||||
self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
|
||||
self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim))
|
||||
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size))
|
||||
self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim))
|
||||
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
|
||||
# LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 normal(0.1)
|
||||
torch.nn.init.normal_(self.hada_t1, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_t2, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
|
||||
torch.nn.init.constant_(self.hada_w1_a, 0)
|
||||
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
|
||||
torch.nn.init.normal_(self.hada_w2_a, std=0.1)
|
||||
elif self.conv_mode == "flat":
|
||||
# Non-Tucker Conv2d 3x3+: flatten kernel into in_dim
|
||||
k_prod = 1
|
||||
for k in kernel_size:
|
||||
k_prod *= k
|
||||
flat_in = in_dim * k_prod
|
||||
|
||||
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in))
|
||||
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in))
|
||||
|
||||
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
|
||||
torch.nn.init.constant_(self.hada_w2_a, 0)
|
||||
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
|
||||
else:
|
||||
# Linear or Conv2d 1x1
|
||||
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
|
||||
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
|
||||
|
||||
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
|
||||
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
|
||||
torch.nn.init.constant_(self.hada_w2_a, 0)
|
||||
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy()
|
||||
@@ -113,9 +223,27 @@ class LoHaModule(torch.nn.Module):
|
||||
del self.org_module
|
||||
|
||||
def get_diff_weight(self):
|
||||
"""Return materialized weight delta as a 2D matrix."""
|
||||
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
|
||||
return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
"""Return materialized weight delta.
|
||||
|
||||
Returns:
|
||||
- Linear: 2D tensor (out_dim, in_dim)
|
||||
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
|
||||
- Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2)
|
||||
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2)
|
||||
"""
|
||||
if self.tucker:
|
||||
scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device)
|
||||
return HadaWeightTucker.apply(
|
||||
self.hada_t1, self.hada_w1_b, self.hada_w1_a,
|
||||
self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale
|
||||
)
|
||||
elif self.conv_mode == "flat":
|
||||
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
|
||||
diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
else:
|
||||
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
|
||||
return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
|
||||
def forward(self, x):
|
||||
org_forwarded = self.org_forward(x)
|
||||
@@ -130,16 +258,25 @@ class LoHaModule(torch.nn.Module):
|
||||
# rank dropout (applied on output dimension)
|
||||
if self.rank_dropout is not None and self.training:
|
||||
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
|
||||
drop = drop.view(-1, 1)
|
||||
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
|
||||
diff_weight = diff_weight * drop
|
||||
scale = 1.0 / (1.0 - self.rank_dropout)
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
if self.is_conv:
|
||||
# Conv2d 1x1: reshape to 4D for conv operation
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
|
||||
|
||||
@@ -164,8 +301,9 @@ class LoHaInfModule(LoHaModule):
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||
# no dropout for inference; pass use_tucker from kwargs
|
||||
use_tucker = kwargs.pop("use_tucker", False)
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker)
|
||||
|
||||
self.org_module_ref = [org_module]
|
||||
self.enabled = True
|
||||
@@ -193,11 +331,18 @@ class LoHaInfModule(LoHaModule):
|
||||
w2a = sd["hada_w2_a"].to(torch.float).to(device)
|
||||
w2b = sd["hada_w2_b"].to(torch.float).to(device)
|
||||
|
||||
# compute ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
|
||||
|
||||
if self.is_conv:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
if self.tucker:
|
||||
# Tucker mode
|
||||
t1 = sd["hada_t1"].to(torch.float).to(device)
|
||||
t2 = sd["hada_t2"].to(torch.float).to(device)
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
|
||||
diff_weight = rebuild1 * rebuild2 * self.scale
|
||||
else:
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
|
||||
# reshape diff_weight to match original weight shape if needed
|
||||
if diff_weight.shape != weight.shape:
|
||||
diff_weight = diff_weight.reshape(weight.shape)
|
||||
|
||||
weight = weight.to(device) + self.multiplier * diff_weight
|
||||
|
||||
@@ -208,23 +353,40 @@ class LoHaInfModule(LoHaModule):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
w1a = self.hada_w1_a.to(torch.float)
|
||||
w1b = self.hada_w1_b.to(torch.float)
|
||||
w2a = self.hada_w2_a.to(torch.float)
|
||||
w2b = self.hada_w2_b.to(torch.float)
|
||||
if self.tucker:
|
||||
t1 = self.hada_t1.to(torch.float)
|
||||
w1a = self.hada_w1_a.to(torch.float)
|
||||
w1b = self.hada_w1_b.to(torch.float)
|
||||
t2 = self.hada_t2.to(torch.float)
|
||||
w2a = self.hada_w2_a.to(torch.float)
|
||||
w2b = self.hada_w2_b.to(torch.float)
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
|
||||
weight = rebuild1 * rebuild2 * self.scale * multiplier
|
||||
else:
|
||||
w1a = self.hada_w1_a.to(torch.float)
|
||||
w1b = self.hada_w1_b.to(torch.float)
|
||||
w2a = self.hada_w2_a.to(torch.float)
|
||||
w2b = self.hada_w2_b.to(torch.float)
|
||||
weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
|
||||
|
||||
weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
|
||||
|
||||
if self.is_conv:
|
||||
weight = weight.unsqueeze(2).unsqueeze(3)
|
||||
if self.is_conv:
|
||||
if self.conv_mode == "1x1":
|
||||
weight = weight.unsqueeze(2).unsqueeze(3)
|
||||
elif self.conv_mode == "flat":
|
||||
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
|
||||
return weight
|
||||
|
||||
def default_forward(self, x):
|
||||
diff_weight = self.get_diff_weight()
|
||||
if self.is_conv:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return self.org_forward(x) + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier
|
||||
else:
|
||||
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
|
||||
|
||||
@@ -288,7 +450,7 @@ def create_network(
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# conv dim/alpha (for future Conv2d 3x3 support)
|
||||
# conv dim/alpha for Conv2d 3x3
|
||||
conv_lora_dim = kwargs.get("conv_dim", None)
|
||||
conv_alpha = kwargs.get("conv_alpha", None)
|
||||
if conv_lora_dim is not None:
|
||||
@@ -298,6 +460,11 @@ def create_network(
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
# Tucker decomposition for Conv2d 3x3
|
||||
use_tucker = kwargs.get("use_tucker", "false")
|
||||
if use_tucker is not None:
|
||||
use_tucker = True if str(use_tucker).lower() == "true" else False
|
||||
|
||||
# verbose
|
||||
verbose = kwargs.get("verbose", "false")
|
||||
if verbose is not None:
|
||||
@@ -321,6 +488,7 @@ def create_network(
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
module_class=LoHaModule,
|
||||
module_kwargs={"use_tucker": use_tucker},
|
||||
conv_lora_dim=conv_lora_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
@@ -372,6 +540,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
if "llm_adapter" in lora_name:
|
||||
train_llm_adapter = True
|
||||
|
||||
# detect Tucker mode from weights
|
||||
use_tucker = any("hada_t1" in key for key in weights_sd.keys())
|
||||
|
||||
# handle text_encoder as list
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
|
||||
|
||||
@@ -379,6 +550,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
arch_config = detect_arch_config(unet, text_encoders)
|
||||
|
||||
module_class = LoHaInfModule if for_inference else LoHaModule
|
||||
module_kwargs = {"use_tucker": use_tucker}
|
||||
|
||||
network = AdditionalNetwork(
|
||||
text_encoders,
|
||||
@@ -388,6 +560,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
modules_dim=modules_dim,
|
||||
modules_alpha=modules_alpha,
|
||||
module_class=module_class,
|
||||
module_kwargs=module_kwargs,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
)
|
||||
return network, weights_sd
|
||||
@@ -403,6 +576,7 @@ def merge_weights_to_tensor(
|
||||
) -> torch.Tensor:
|
||||
"""Merge LoHa weights directly into a model weight tensor.
|
||||
|
||||
Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
|
||||
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
|
||||
Returns model_weight unchanged if no matching LoHa keys found.
|
||||
"""
|
||||
@@ -410,6 +584,8 @@ def merge_weights_to_tensor(
|
||||
w1b_key = lora_name + ".hada_w1_b"
|
||||
w2a_key = lora_name + ".hada_w2_a"
|
||||
w2b_key = lora_name + ".hada_w2_b"
|
||||
t1_key = lora_name + ".hada_t1"
|
||||
t2_key = lora_name + ".hada_t2"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
|
||||
if w1a_key not in lora_weight_keys:
|
||||
@@ -420,6 +596,8 @@ def merge_weights_to_tensor(
|
||||
w2a = lora_sd[w2a_key].to(calc_device)
|
||||
w2b = lora_sd[w2b_key].to(calc_device)
|
||||
|
||||
has_tucker = t1_key in lora_weight_keys
|
||||
|
||||
dim = w1b.shape[0]
|
||||
alpha = lora_sd.get(alpha_key, torch.tensor(dim))
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
@@ -429,14 +607,26 @@ def merge_weights_to_tensor(
|
||||
original_dtype = model_weight.dtype
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
model_weight = model_weight.to(torch.float16)
|
||||
w1a, w1b, w2a, w2b = w1a.to(torch.float16), w1b.to(torch.float16), w2a.to(torch.float16), w2b.to(torch.float16)
|
||||
w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16)
|
||||
w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
|
||||
|
||||
# ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
if has_tucker:
|
||||
# Tucker decomposition: rebuild via einsum
|
||||
t1 = lora_sd[t1_key].to(calc_device)
|
||||
t2 = lora_sd[t2_key].to(calc_device)
|
||||
if original_dtype.itemsize == 1:
|
||||
t1, t2 = t1.to(torch.float16), t2.to(torch.float16)
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a)
|
||||
diff_weight = rebuild1 * rebuild2 * scale
|
||||
else:
|
||||
# Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
|
||||
# handle Conv2d 1x1 weights (4D tensors)
|
||||
if len(model_weight.shape) == 4:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
# Reshape diff_weight to match model_weight shape if needed
|
||||
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
|
||||
if diff_weight.shape != model_weight.shape:
|
||||
diff_weight = diff_weight.reshape(model_weight.shape)
|
||||
|
||||
model_weight = model_weight + multiplier * diff_weight
|
||||
|
||||
@@ -444,7 +634,10 @@ def merge_weights_to_tensor(
|
||||
model_weight = model_weight.to(original_dtype)
|
||||
|
||||
# remove consumed keys
|
||||
for key in [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]:
|
||||
consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]
|
||||
if has_tucker:
|
||||
consumed.extend([t1_key, t2_key])
|
||||
for key in consumed:
|
||||
lora_weight_keys.discard(key)
|
||||
|
||||
return model_weight
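Before the networks/lokr.py changes, here is a small illustrative sketch of the non-Tucker "flat" Conv2d path used by the LoHa code above: the kernel is folded into the input dimension, the 2D Hadamard-product delta is computed as in HadaWeight, and the result is reshaped to a 4D convolution weight. This is not code from the diff; sizes are example values only.

import torch

out_dim, in_dim, rank, kh, kw = 16, 8, 4, 3, 3
flat_in = in_dim * kh * kw  # kernel folded into the input dimension

w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, flat_in)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, flat_in)
scale = 1.0

diff = ((w1a @ w1b) * (w2a @ w2b)) * scale       # (out_dim, flat_in), as in HadaWeight
diff_4d = diff.reshape(out_dim, in_dim, kh, kw)  # 4D weight usable with F.conv2d
print(diff_4d.shape)  # torch.Size([16, 8, 3, 3])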
networks/lokr.py (211 lines changed)
@@ -68,6 +68,14 @@ def make_kron(w1, w2, scale):
    return rebuild


def rebuild_tucker(t, wa, wb):
    """Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb).

    Compatible with LyCORIS convention.
    """
    return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)

class LoKrModule(torch.nn.Module):
|
||||
"""LoKr module for training. Replaces forward method of the original Linear/Conv2d."""
|
||||
|
||||
@@ -82,6 +90,7 @@ class LoKrModule(torch.nn.Module):
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
factor=-1,
|
||||
use_tucker=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -92,13 +101,32 @@ class LoKrModule(torch.nn.Module):
|
||||
if is_conv2d:
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
kernel_size = org_module.kernel_size
|
||||
self.is_conv = True
|
||||
if org_module.kernel_size != (1, 1):
|
||||
raise ValueError("LoKr Conv2d 3x3 (Tucker decomposition) is not supported yet")
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.dilation = org_module.dilation
|
||||
self.groups = org_module.groups
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
self.tucker = use_tucker and any(k != 1 for k in kernel_size)
|
||||
|
||||
if kernel_size == (1, 1):
|
||||
self.conv_mode = "1x1"
|
||||
elif self.tucker:
|
||||
self.conv_mode = "tucker"
|
||||
else:
|
||||
self.conv_mode = "flat"
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
self.is_conv = False
|
||||
self.tucker = False
|
||||
self.conv_mode = None
|
||||
self.kernel_size = None
|
||||
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
factor = int(factor)
|
||||
self.use_w2 = False
|
||||
@@ -110,18 +138,44 @@ class LoKrModule(torch.nn.Module):
|
||||
# w1 is always a full matrix (the "scale" factor, small)
|
||||
self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))
|
||||
|
||||
# w2: low-rank decomposition if rank is small enough, otherwise full matrix
|
||||
if lora_dim < max(out_k, in_n) / 2:
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
|
||||
else:
|
||||
self.use_w2 = True
|
||||
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
|
||||
# w2: depends on mode
|
||||
if self.conv_mode in ("tucker", "flat"):
|
||||
# Conv2d 3x3+ modes
|
||||
k_size = kernel_size
|
||||
|
||||
if lora_dim >= max(out_k, in_n) / 2:
|
||||
# Full matrix mode (includes kernel dimensions)
|
||||
self.use_w2 = True
|
||||
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size))
|
||||
logger.warning(
|
||||
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
|
||||
f"and factor={factor}, using full matrix mode."
|
||||
f"and factor={factor}, using full matrix mode for Conv2d."
|
||||
)
|
||||
elif self.tucker:
|
||||
# Tucker mode: separate kernel into t2 tensor
|
||||
self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size))
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
|
||||
else:
|
||||
# Non-Tucker: flatten kernel into w2_b
|
||||
k_prod = 1
|
||||
for k in k_size:
|
||||
k_prod *= k
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod))
|
||||
else:
|
||||
# Linear or Conv2d 1x1
|
||||
if lora_dim < max(out_k, in_n) / 2:
|
||||
self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
|
||||
self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
|
||||
else:
|
||||
self.use_w2 = True
|
||||
self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
|
||||
if lora_dim >= max(out_k, in_n) / 2:
|
||||
logger.warning(
|
||||
f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
|
||||
f"and factor={factor}, using full matrix mode."
|
||||
)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy()
|
||||
@@ -137,6 +191,8 @@ class LoKrModule(torch.nn.Module):
|
||||
if self.use_w2:
|
||||
torch.nn.init.constant_(self.lokr_w2, 0)
|
||||
else:
|
||||
if self.tucker:
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5))
|
||||
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
|
||||
torch.nn.init.constant_(self.lokr_w2_b, 0)
|
||||
# Ensures ΔW = kron(w1, 0) = 0 at init
|
||||
@@ -153,13 +209,30 @@ class LoKrModule(torch.nn.Module):
|
||||
del self.org_module
|
||||
|
||||
def get_diff_weight(self):
|
||||
"""Return materialized weight delta."""
|
||||
"""Return materialized weight delta.
|
||||
|
||||
Returns:
|
||||
- Linear: 2D tensor (out_dim, in_dim)
|
||||
- Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d
|
||||
- Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2)
|
||||
- Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D
|
||||
"""
|
||||
w1 = self.lokr_w1
|
||||
|
||||
if self.use_w2:
|
||||
w2 = self.lokr_w2
|
||||
elif self.tucker:
|
||||
w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
return make_kron(w1, w2, self.scale)
|
||||
|
||||
result = make_kron(w1, w2, self.scale)
|
||||
|
||||
# For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D
|
||||
if self.conv_mode == "flat" and result.dim() == 2:
|
||||
result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
|
||||
return result
|
||||
|
||||
def forward(self, x):
|
||||
org_forwarded = self.org_forward(x)
|
||||
@@ -174,15 +247,25 @@ class LoKrModule(torch.nn.Module):
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
|
||||
drop = drop.view(-1, 1)
|
||||
drop = drop.view(-1, *([1] * (diff_weight.dim() - 1)))
|
||||
diff_weight = diff_weight * drop
|
||||
scale = 1.0 / (1.0 - self.rank_dropout)
|
||||
else:
|
||||
scale = 1.0
|
||||
|
||||
if self.is_conv:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
# Conv2d 3x3+: diff_weight is already 4D from get_diff_weight
|
||||
return org_forwarded + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier * scale
|
||||
else:
|
||||
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
|
||||
|
||||
@@ -207,9 +290,10 @@ class LoKrInfModule(LoKrModule):
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference; pass factor from kwargs if present
|
||||
# no dropout for inference; pass factor and use_tucker from kwargs
|
||||
factor = kwargs.pop("factor", -1)
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor)
|
||||
use_tucker = kwargs.pop("use_tucker", False)
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker)
|
||||
|
||||
self.org_module_ref = [org_module]
|
||||
self.enabled = True
|
||||
@@ -236,6 +320,12 @@ class LoKrInfModule(LoKrModule):
|
||||
|
||||
if "lokr_w2" in sd:
|
||||
w2 = sd["lokr_w2"].to(torch.float).to(device)
|
||||
elif "lokr_t2" in sd:
|
||||
# Tucker mode
|
||||
t2 = sd["lokr_t2"].to(torch.float).to(device)
|
||||
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
|
||||
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
|
||||
w2 = rebuild_tucker(t2, w2a, w2b)
|
||||
else:
|
||||
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
|
||||
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
|
||||
@@ -244,8 +334,9 @@ class LoKrInfModule(LoKrModule):
|
||||
# compute ΔW via Kronecker product
|
||||
diff_weight = make_kron(w1, w2, self.scale)
|
||||
|
||||
if self.is_conv:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
# reshape diff_weight to match original weight shape if needed
|
||||
if diff_weight.shape != weight.shape:
|
||||
diff_weight = diff_weight.reshape(weight.shape)
|
||||
|
||||
weight = weight.to(device) + self.multiplier * diff_weight
|
||||
|
||||
@@ -257,23 +348,39 @@ class LoKrInfModule(LoKrModule):
|
||||
multiplier = self.multiplier
|
||||
|
||||
w1 = self.lokr_w1.to(torch.float)
|
||||
|
||||
if self.use_w2:
|
||||
w2 = self.lokr_w2.to(torch.float)
|
||||
elif self.tucker:
|
||||
w2 = rebuild_tucker(
|
||||
self.lokr_t2.to(torch.float),
|
||||
self.lokr_w2_a.to(torch.float),
|
||||
self.lokr_w2_b.to(torch.float),
|
||||
)
|
||||
else:
|
||||
w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)
|
||||
|
||||
weight = make_kron(w1, w2, self.scale) * multiplier
|
||||
|
||||
# reshape to match original weight shape if needed
|
||||
if self.is_conv:
|
||||
weight = weight.unsqueeze(2).unsqueeze(3)
|
||||
if self.conv_mode == "1x1":
|
||||
weight = weight.unsqueeze(2).unsqueeze(3)
|
||||
elif self.conv_mode == "flat" and weight.dim() == 2:
|
||||
weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size)
|
||||
# Tucker and full matrix modes: already 4D from kron
|
||||
|
||||
return weight
|
||||
|
||||
def default_forward(self, x):
|
||||
diff_weight = self.get_diff_weight()
|
||||
if self.is_conv:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
|
||||
if self.conv_mode == "1x1":
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
return self.org_forward(x) + F.conv2d(
|
||||
x, diff_weight, stride=self.stride, padding=self.padding,
|
||||
dilation=self.dilation, groups=self.groups
|
||||
) * self.multiplier
|
||||
else:
|
||||
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
|
||||
|
||||
@@ -337,7 +444,7 @@ def create_network(
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# conv dim/alpha (for future Conv2d 3x3 support)
|
||||
# conv dim/alpha for Conv2d 3x3
|
||||
conv_lora_dim = kwargs.get("conv_dim", None)
|
||||
conv_alpha = kwargs.get("conv_alpha", None)
|
||||
if conv_lora_dim is not None:
|
||||
@@ -347,6 +454,11 @@ def create_network(
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
# Tucker decomposition for Conv2d 3x3
|
||||
use_tucker = kwargs.get("use_tucker", "false")
|
||||
if use_tucker is not None:
|
||||
use_tucker = True if str(use_tucker).lower() == "true" else False
|
||||
|
||||
# factor for LoKr
|
||||
factor = int(kwargs.get("factor", -1))
|
||||
|
||||
@@ -373,7 +485,7 @@ def create_network(
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
module_class=LoKrModule,
|
||||
module_kwargs={"factor": factor},
|
||||
module_kwargs={"factor": factor, "use_tucker": use_tucker},
|
||||
conv_lora_dim=conv_lora_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
@@ -411,6 +523,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
train_llm_adapter = False
|
||||
use_tucker = False
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
@@ -419,13 +532,21 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "lokr_w2_a" in key:
|
||||
# low-rank mode: dim = w2_a.shape[1]
|
||||
dim = value.shape[1]
|
||||
# low-rank mode: dim detection depends on Tucker vs non-Tucker
|
||||
if "lokr_t2" in key.replace("lokr_w2_a", "lokr_t2") and lora_name + ".lokr_t2" in weights_sd:
|
||||
# Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
|
||||
dim = value.shape[0]
|
||||
else:
|
||||
# Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
|
||||
dim = value.shape[1]
|
||||
modules_dim[lora_name] = dim
|
||||
elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
|
||||
# full matrix mode: set dim large enough to trigger full-matrix path
|
||||
if lora_name not in modules_dim:
|
||||
modules_dim[lora_name] = max(value.shape)
|
||||
modules_dim[lora_name] = max(value.shape[0], value.shape[1])
|
||||
|
||||
if "lokr_t2" in key:
|
||||
use_tucker = True
|
||||
|
||||
if "llm_adapter" in lora_name:
|
||||
train_llm_adapter = True
|
||||
@@ -440,7 +561,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
factor = int(kwargs.get("factor", -1))
|
||||
|
||||
module_class = LoKrInfModule if for_inference else LoKrModule
|
||||
module_kwargs = {"factor": factor}
|
||||
module_kwargs = {"factor": factor, "use_tucker": use_tucker}
|
||||
|
||||
network = AdditionalNetwork(
|
||||
text_encoders,
|
||||
@@ -466,6 +587,7 @@ def merge_weights_to_tensor(
|
||||
) -> torch.Tensor:
|
||||
"""Merge LoKr weights directly into a model weight tensor.
|
||||
|
||||
Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
|
||||
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
|
||||
Returns model_weight unchanged if no matching LoKr keys found.
|
||||
"""
|
||||
@@ -473,6 +595,7 @@ def merge_weights_to_tensor(
|
||||
w2_key = lora_name + ".lokr_w2"
|
||||
w2a_key = lora_name + ".lokr_w2_a"
|
||||
w2b_key = lora_name + ".lokr_w2_b"
|
||||
t2_key = lora_name + ".lokr_t2"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
|
||||
if w1_key not in lora_weight_keys:
|
||||
@@ -480,13 +603,23 @@ def merge_weights_to_tensor(
|
||||
|
||||
w1 = lora_sd[w1_key].to(calc_device)
|
||||
|
||||
# determine low-rank vs full matrix mode
|
||||
# determine mode: full matrix vs Tucker vs low-rank
|
||||
has_tucker = t2_key in lora_weight_keys
|
||||
|
||||
if w2a_key in lora_weight_keys:
|
||||
# low-rank: w2 = w2_a @ w2_b
|
||||
w2a = lora_sd[w2a_key].to(calc_device)
|
||||
w2b = lora_sd[w2b_key].to(calc_device)
|
||||
dim = w2a.shape[1]
|
||||
|
||||
if has_tucker:
|
||||
# Tucker: w2a = (rank, out_k), dim = rank
|
||||
dim = w2a.shape[0]
|
||||
else:
|
||||
# Non-Tucker low-rank: w2a = (out_k, rank), dim = rank
|
||||
dim = w2a.shape[1]
|
||||
|
||||
consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
|
||||
if has_tucker:
|
||||
consumed_keys.append(t2_key)
|
||||
elif w2_key in lora_weight_keys:
|
||||
# full matrix mode
|
||||
w2a = None
|
||||
@@ -502,7 +635,6 @@ def merge_weights_to_tensor(
|
||||
|
||||
# compute scale
|
||||
if w2a is not None:
|
||||
# low-rank mode
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
scale = alpha / dim
|
||||
@@ -519,7 +651,13 @@ def merge_weights_to_tensor(
|
||||
|
||||
# compute w2
|
||||
if w2a is not None:
|
||||
w2 = w2a @ w2b
|
||||
if has_tucker:
|
||||
t2 = lora_sd[t2_key].to(calc_device)
|
||||
if original_dtype.itemsize == 1:
|
||||
t2 = t2.to(torch.float16)
|
||||
w2 = rebuild_tucker(t2, w2a, w2b)
|
||||
else:
|
||||
w2 = w2a @ w2b
|
||||
else:
|
||||
w2 = lora_sd[w2_key].to(calc_device)
|
||||
if original_dtype.itemsize == 1:
|
||||
@@ -528,9 +666,10 @@ def merge_weights_to_tensor(
|
||||
# ΔW = kron(w1, w2) * scale
|
||||
diff_weight = make_kron(w1, w2, scale)
|
||||
|
||||
# handle Conv2d 1x1 weights (4D tensors)
|
||||
if len(model_weight.shape) == 4:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
# Reshape diff_weight to match model_weight shape if needed
|
||||
# (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.)
|
||||
if diff_weight.shape != model_weight.shape:
|
||||
diff_weight = diff_weight.reshape(model_weight.shape)
|
||||
|
||||
model_weight = model_weight + multiplier * diff_weight
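To connect the pieces in networks/lokr.py above: in Tucker mode, the w2 factor is rebuilt from (lokr_t2, lokr_w2_a, lokr_w2_b) and then combined with lokr_w1 via a Kronecker product. The sketch below is illustrative only; the make_kron body is a simplified stand-in for the repository's helper (only its signature appears in the diff), and all sizes are invented.

import torch

def rebuild_tucker(t, wa, wb):
    # same einsum as the helper added in networks/lokr.py
    return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb)

def make_kron(w1, w2, scale):
    # simplified stand-in: broadcast w1 over w2's kernel dims, then Kronecker product
    if w2.dim() == 4:
        w1 = w1.unsqueeze(2).unsqueeze(3)
    return torch.kron(w1, w2) * scale

out_l, out_k, in_m, in_n, rank, kh, kw = 4, 8, 2, 4, 3, 3, 3
w1 = torch.randn(out_l, in_m)                                  # lokr_w1
t2 = torch.randn(rank, rank, kh, kw)                           # lokr_t2
w2a, w2b = torch.randn(rank, out_k), torch.randn(rank, in_n)   # lokr_w2_a / lokr_w2_b

w2 = rebuild_tucker(t2, w2a, w2b)        # (out_k, in_n, kh, kw)
delta_w = make_kron(w1, w2, scale=1.0)   # (out_l*out_k, in_m*in_n, kh, kw)
print(w2.shape, delta_w.shape)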
@@ -25,6 +25,7 @@ class ArchConfig:
    te_prefixes: List[str]
    default_excludes: List[str] = field(default_factory=list)
    adapter_target_modules: List[str] = field(default_factory=list)
    unet_conv_target_modules: List[str] = field(default_factory=list)


def detect_arch_config(unet, text_encoders) -> ArchConfig:
@@ -39,6 +40,7 @@ def detect_arch_config(unet, text_encoders) -> ArchConfig:
            unet_prefix="lora_unet",
            te_prefixes=["lora_te1", "lora_te2"],
            default_excludes=[],
            unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"],
        )

    # Check Anima: look for Block class in named_modules
@@ -258,6 +260,8 @@ class AdditionalNetwork(torch.nn.Module):

        # Create modules for UNet/DiT
        target_modules = list(arch_config.unet_target_modules)
        if modules_dim is not None or conv_lora_dim is not None:
            target_modules.extend(arch_config.unet_conv_target_modules)
        if train_llm_adapter and arch_config.adapter_target_modules:
            target_modules.extend(arch_config.adapter_target_modules)
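Finally, a hypothetical example of the network_args that the create_network parsing shown in both modules would accept to enable Conv2d 3x3 training with Tucker decomposition. The keys match the kwargs read in the diff (conv_dim, conv_alpha, use_tucker); how these values reach create_network depends on the training script and is not shown here.

network_args = {
    "conv_dim": "4",       # rank for Conv2d 3x3 layers; setting this also pulls in unet_conv_target_modules
    "conv_alpha": "1",     # alpha for Conv2d 3x3 layers
    "use_tucker": "true",  # parsed via str(...).lower() == "true"
}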