feat: Add LoHa/LoKr network support for SDXL and Anima

- networks/network_base.py: shared AdditionalNetwork base class with architecture auto-detection (SDXL/Anima) and generic module injection
- networks/loha.py: LoHa (Low-rank Hadamard Product) module with HadaWeight custom autograd, training/inference classes, and factory functions
- networks/lokr.py: LoKr (Low-rank Kronecker Product) module with factorization, training/inference classes, and factory functions
- library/lora_utils.py: extend weight merge hook to detect and merge LoHa/LoKr weights alongside standard LoRA

Linear and Conv2d 1x1 layers only; Conv2d 3x3 (Tucker decomposition) support will be added separately.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Kohya S
2026-02-15 21:50:50 +09:00
parent ae72efb92b
commit ba72f3b665
6 changed files with 2027 additions and 268 deletions

.gitignore (vendored): +1 line

@@ -11,3 +11,4 @@ GEMINI.md
.claude
.gemini
MagicMock
references

library/lora_utils.py

@@ -1,267 +1,287 @@
import os
import re
from typing import Callable, Dict, List, Optional, Union
import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def filter_lora_state_dict(
weights_sd: Dict[str, torch.Tensor],
include_pattern: Optional[str] = None,
exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
# apply include/exclude patterns
original_key_count = len(weights_sd.keys())
if include_pattern is not None:
regex_include = re.compile(include_pattern)
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
if exclude_pattern is not None:
original_key_count_ex = len(weights_sd.keys())
regex_exclude = re.compile(exclude_pattern)
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")
if len(weights_sd) != original_key_count:
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
remaining_keys.sort()
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
if len(weights_sd) == 0:
logger.warning("No keys left after filtering.")
return weights_sd
def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]],
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
lora_multipliers: Optional[List[float]],
fp8_optimization: bool,
calc_device: torch.device,
move_to_device: bool = False,
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
"""
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
if isinstance(model_files, str):
model_files = [model_files]
extended_model_files = []
for model_file in model_files:
split_filenames = get_split_weight_filenames(model_file)
if split_filenames is not None:
extended_model_files.extend(split_filenames)
else:
extended_model_files.append(model_file)
model_files = extended_model_files
logger.info(f"Loading model files: {model_files}")
# load LoRA weights
weight_hook = None
if lora_weights_list is None or len(lora_weights_list) == 0:
lora_weights_list = []
lora_multipliers = []
list_of_lora_weight_keys = []
else:
list_of_lora_weight_keys = []
for lora_sd in lora_weights_list:
lora_weight_keys = set(lora_sd.keys())
list_of_lora_weight_keys.append(lora_weight_keys)
if lora_multipliers is None:
lora_multipliers = [1.0] * len(lora_weights_list)
while len(lora_multipliers) < len(lora_weights_list):
lora_multipliers.append(1.0)
if len(lora_multipliers) > len(lora_weights_list):
lora_multipliers = lora_multipliers[: len(lora_weights_list)]
# Merge LoRA weights into the state dict
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
return model_weight
original_device = model_weight.device
if original_device != calc_device:
model_weight = model_weight.to(calc_device) # to make calculation faster
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
# check if this weight has LoRA weights
lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
found = False
for prefix in ["lora_unet_", ""]:
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key in lora_weight_keys and up_key in lora_weight_keys:
found = True
break
if found:
# Standard LoRA merge
# get LoRA weights
down_weight = lora_sd[down_key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device)
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
# temporarily convert to float16 for calculation
model_weight = model_weight.to(torch.float16)
down_weight = down_weight.to(torch.float16)
up_weight = up_weight.to(torch.float16)
# W <- W + U * D
if len(model_weight.size()) == 2:
# linear
if len(up_weight.size()) == 4:  # LoRA weights for a linear layer may be stored as 1x1 conv weights; squeeze to 2D
up_weight = up_weight.squeeze(3).squeeze(2)
down_weight = down_weight.squeeze(3).squeeze(2)
model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
model_weight = (
model_weight
+ multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(original_dtype) # convert back to original dtype
# remove LoRA keys from set
lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key)
if alpha_key in lora_weight_keys:
lora_weight_keys.remove(alpha_key)
continue
# Not a standard LoRA module: check for LoHa/LoKr weights using the same prefix search
for prefix in ["lora_unet_", ""]:
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
hada_key = lora_name + ".hada_w1_a"
lokr_key = lora_name + ".lokr_w1"
if hada_key in lora_weight_keys:
# LoHa merge
from networks.loha import merge_weights_to_tensor as loha_merge
model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
break
elif lokr_key in lora_weight_keys:
# LoKr merge
from networks.lokr import merge_weights_to_tensor as lokr_merge
model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
break
if not keep_on_calc_device and original_device != calc_device:
model_weight = model_weight.to(original_device) # move back to original device
return model_weight
weight_hook = weight_hook_func
state_dict = load_safetensors_with_fp8_optimization_and_hook(
model_files,
fp8_optimization,
calc_device,
move_to_device,
dit_weight_dtype,
target_keys,
exclude_keys,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
for lora_weight_keys in list_of_lora_weight_keys:
# check if all LoRA keys are used
if len(lora_weight_keys) > 0:
# if there are still LoRA keys left, it means they are not used in the model
# this is a warning, not an error
logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")
return state_dict
def load_safetensors_with_fp8_optimization_and_hook(
model_files: list[str],
fp8_optimization: bool,
calc_device: torch.device,
move_to_device: bool = False,
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
weight_hook: Optional[Callable] = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
"""
if fp8_optimization:
logger.info(
f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
)
# dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization(
model_files,
calc_device,
target_keys,
exclude_keys,
move_to_device=move_to_device,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
else:
logger.info(
f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
)
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
else:
value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer
if weight_hook is not None:
value = weight_hook(key, value, keep_on_calc_device=move_to_device)
if move_to_device:
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
elif dit_weight_dtype is not None:
value = value.to(dit_weight_dtype)
state_dict[key] = value
if move_to_device:
synchronize_device(calc_device)
return state_dict
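# Usage sketch (hedged): the file names, exclude pattern, and multiplier below
# are hypothetical; only the two function signatures defined above are assumed.
import torch
from safetensors.torch import load_file
from library.lora_utils import filter_lora_state_dict, load_safetensors_with_lora_and_fp8

lora_sd = filter_lora_state_dict(load_file("my_lora.safetensors"), exclude_pattern=r".*time_embed.*")
state_dict = load_safetensors_with_lora_and_fp8(
    model_files="model-00001-of-00004.safetensors",  # sibling split files are discovered automatically
    lora_weights_list=[lora_sd],
    lora_multipliers=[0.8],
    fp8_optimization=False,
    calc_device=torch.device("cuda"),
)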

networks/loha.py (new file): +450 lines

@@ -0,0 +1,450 @@
# LoHa (Low-rank Hadamard Product) network module
# Reference: https://arxiv.org/abs/2108.06098
#
# Based on the LyCORIS project by KohakuBlueleaf
# https://github.com/KohakuBlueleaf/LyCORIS
import ast
import os
import logging
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
class HadaWeight(torch.autograd.Function):
"""Efficient Hadamard product forward/backward for LoHa.
Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with custom backward
that recomputes intermediates instead of storing them.
"""
@staticmethod
def forward(ctx, w1a, w1b, w2a, w2b, scale=None):
if scale is None:
scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype)
ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
return diff_weight
@staticmethod
def backward(ctx, grad_out):
(w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors
grad_out = grad_out * scale
temp = grad_out * (w2a @ w2b)
grad_w1a = temp @ w1b.T
grad_w1b = w1a.T @ temp
temp = grad_out * (w1a @ w1b)
grad_w2a = temp @ w2b.T
grad_w2b = w2a.T @ temp
del temp
return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
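# Sanity-check sketch (hedged, not part of the module above): torch.autograd.gradcheck
# in double precision confirms the hand-written backward matches autograd's result
# for the same expression.
import torch
from networks.loha import HadaWeight

out_dim, rank, in_dim = 6, 3, 5
inputs = (
    torch.randn(out_dim, rank, dtype=torch.double, requires_grad=True),  # w1a
    torch.randn(rank, in_dim, dtype=torch.double, requires_grad=True),   # w1b
    torch.randn(out_dim, rank, dtype=torch.double, requires_grad=True),  # w2a
    torch.randn(rank, in_dim, dtype=torch.double, requires_grad=True),   # w2b
    torch.tensor(0.5, dtype=torch.double),                               # scale (no grad)
)
assert torch.autograd.gradcheck(HadaWeight.apply, inputs)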
class LoHaModule(torch.nn.Module):
"""LoHa module for training. Replaces forward method of the original Linear/Conv2d."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
**kwargs,
):
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
is_conv2d = org_module.__class__.__name__ == "Conv2d"
if is_conv2d:
in_dim = org_module.in_channels
out_dim = org_module.out_channels
self.is_conv = True
if org_module.kernel_size != (1, 1):
raise ValueError("LoHa Conv2d 3x3 (Tucker decomposition) is not supported yet")
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.is_conv = False
# Hadamard product parameters: ΔW = (w1a @ w1b) * (w2a @ w2b)
self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim))
self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim))
self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim))
# Initialization: w1_a normal(0.1), w1_b normal(1.0), w2_a = 0, w2_b normal(1.0)
# Ensures ΔW = 0 at init since w2_a = 0
torch.nn.init.normal_(self.hada_w1_a, std=0.1)
torch.nn.init.normal_(self.hada_w1_b, std=1.0)
torch.nn.init.constant_(self.hada_w2_a, 0)
torch.nn.init.normal_(self.hada_w2_b, std=1.0)
if isinstance(alpha, torch.Tensor):
    alpha = alpha.detach().float().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def get_diff_weight(self):
"""Return materialized weight delta as a 2D matrix."""
scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
diff_weight = self.get_diff_weight()
# rank dropout (applied on output dimension)
if self.rank_dropout is not None and self.training:
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
drop = drop.view(-1, 1)
diff_weight = diff_weight * drop
scale = 1.0 / (1.0 - self.rank_dropout)
else:
scale = 1.0
if self.is_conv:
# Conv2d 1x1: reshape to 4D for conv operation
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale
else:
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoHaInfModule(LoHaModule):
"""LoHa module for inference. Supports merge_to and get_weight."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
self.org_module_ref = [org_module]
self.enabled = True
self.network: Optional[AdditionalNetwork] = None
def set_network(self, network):
self.network = network
def merge_to(self, sd, dtype, device):
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"]
org_dtype = weight.dtype
org_device = weight.device
weight = weight.to(torch.float)
if dtype is None:
dtype = org_dtype
if device is None:
device = org_device
# get LoHa weights
w1a = sd["hada_w1_a"].to(torch.float).to(device)
w1b = sd["hada_w1_b"].to(torch.float).to(device)
w2a = sd["hada_w2_a"].to(torch.float).to(device)
w2b = sd["hada_w2_b"].to(torch.float).to(device)
# compute ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
weight = weight.to(device) + self.multiplier * diff_weight
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
w1a = self.hada_w1_a.to(torch.float)
w1b = self.hada_w1_b.to(torch.float)
w2a = self.hada_w2_a.to(torch.float)
w2b = self.hada_w2_b.to(torch.float)
weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier
if self.is_conv:
weight = weight.unsqueeze(2).unsqueeze(3)
return weight
def default_forward(self, x):
diff_weight = self.get_diff_weight()
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
else:
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
return self.default_forward(x)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae,
text_encoder,
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
"""Create a LoHa network. Called by train_network.py via network_module.create_network()."""
if network_dim is None:
network_dim = 4
if network_alpha is None:
network_alpha = 1.0
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
# train LLM adapter
train_llm_adapter = str(kwargs.get("train_llm_adapter", "false")).lower() == "true"
# exclude patterns
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns from arch config
exclude_patterns.extend(arch_config.default_excludes)
# include patterns
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# conv dim/alpha (for future Conv2d 3x3 support)
conv_lora_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_lora_dim is not None:
conv_lora_dim = int(conv_lora_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# verbose
verbose = str(kwargs.get("verbose", "false")).lower() == "true"
# regex-specific learning rates / dimensions
network_reg_lrs = kwargs.get("network_reg_lrs", None)
reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
network_reg_dims = kwargs.get("network_reg_dims", None)
reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
module_class=LoHaModule,
conv_lora_dim=conv_lora_dim,
conv_alpha=conv_alpha,
train_llm_adapter=train_llm_adapter,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
# LoRA+ support
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
"""Create a LoHa network from saved weights. Called by train_network.py."""
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# detect dim/alpha from weights
modules_dim = {}
modules_alpha = {}
train_llm_adapter = False
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "hada_w1_b" in key:
dim = value.shape[0]
modules_dim[lora_name] = dim
if "llm_adapter" in lora_name:
train_llm_adapter = True
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
module_class = LoHaInfModule if for_inference else LoHaModule
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
train_llm_adapter=train_llm_adapter,
)
return network, weights_sd
def merge_weights_to_tensor(
model_weight: torch.Tensor,
lora_name: str,
lora_sd: Dict[str, torch.Tensor],
lora_weight_keys: set,
multiplier: float,
calc_device: torch.device,
) -> torch.Tensor:
"""Merge LoHa weights directly into a model weight tensor.
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
Returns model_weight unchanged if no matching LoHa keys found.
"""
w1a_key = lora_name + ".hada_w1_a"
w1b_key = lora_name + ".hada_w1_b"
w2a_key = lora_name + ".hada_w2_a"
w2b_key = lora_name + ".hada_w2_b"
alpha_key = lora_name + ".alpha"
if w1a_key not in lora_weight_keys:
return model_weight
w1a = lora_sd[w1a_key].to(calc_device)
w1b = lora_sd[w1b_key].to(calc_device)
w2a = lora_sd[w2a_key].to(calc_device)
w2b = lora_sd[w2b_key].to(calc_device)
dim = w1b.shape[0]
alpha = lora_sd.get(alpha_key, torch.tensor(dim))
if isinstance(alpha, torch.Tensor):
alpha = alpha.item()
scale = alpha / dim
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(torch.float16)
w1a, w1b, w2a, w2b = w1a.to(torch.float16), w1b.to(torch.float16), w2a.to(torch.float16), w2b.to(torch.float16)
# ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
# handle Conv2d 1x1 weights (4D tensors)
if len(model_weight.shape) == 4:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
model_weight = model_weight + multiplier * diff_weight
if original_dtype.itemsize == 1:
model_weight = model_weight.to(original_dtype)
# remove consumed keys
for key in [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]:
lora_weight_keys.discard(key)
return model_weight
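# Numeric sketch (hedged; the module name and shapes are hypothetical) showing that
# merge_weights_to_tensor applies exactly ΔW = ((w1a @ w1b) * (w2a @ w2b)) * (alpha / dim)
# and removes the consumed keys.
import torch
from networks.loha import merge_weights_to_tensor

out_dim, in_dim, dim = 8, 16, 4
name = "lora_unet_some_linear"  # hypothetical LoHa module name
sd = {
    f"{name}.hada_w1_a": torch.randn(out_dim, dim),
    f"{name}.hada_w1_b": torch.randn(dim, in_dim),
    f"{name}.hada_w2_a": torch.randn(out_dim, dim),
    f"{name}.hada_w2_b": torch.randn(dim, in_dim),
    f"{name}.alpha": torch.tensor(2.0),
}
keys = set(sd.keys())
weight = torch.zeros(out_dim, in_dim)
merged = merge_weights_to_tensor(weight, name, sd, keys, multiplier=1.0, calc_device=torch.device("cpu"))
expected = (sd[f"{name}.hada_w1_a"] @ sd[f"{name}.hada_w1_b"]) * (sd[f"{name}.hada_w2_a"] @ sd[f"{name}.hada_w2_b"]) * (2.0 / dim)
assert torch.allclose(merged, expected)
assert not keys  # all consumed keys were removed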

networks/lokr.py (new file): +544 lines

@@ -0,0 +1,544 @@
# LoKr (Low-rank Kronecker Product) network module
# Reference: https://arxiv.org/abs/2309.14859
#
# Based on the LyCORIS project by KohakuBlueleaf
# https://github.com/KohakuBlueleaf/LyCORIS
import ast
import math
import os
import logging
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
def factorization(dimension: int, factor: int = -1) -> tuple:
"""Return a tuple of two values whose product equals dimension,
optimized for balanced factors.
In LoKr, the first value is for the weight scale (smaller),
and the second value is for the weight (larger).
Examples:
factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32)
factor=4: 128 -> (4, 32), 512 -> (4, 128)
"""
if factor > 0 and (dimension % factor) == 0:
m = factor
n = dimension // factor
if m > n:
n, m = m, n
return m, n
if factor < 0:
factor = dimension
m, n = 1, dimension
length = m + n
while m < n:
new_m = m + 1
while dimension % new_m != 0:
new_m += 1
new_n = dimension // new_m
if new_m + new_n > length or new_m > factor:
break
else:
m, n = new_m, new_n
if m > n:
n, m = m, n
return m, n
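# Quick sanity check of the factorization examples documented above:
from networks.lokr import factorization

assert factorization(128) == (8, 16)
assert factorization(512) == (16, 32)
assert factorization(1024) == (32, 32)
assert factorization(128, factor=4) == (4, 32)
assert factorization(512, factor=4) == (4, 128)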
def make_kron(w1, w2, scale):
"""Compute Kronecker product of w1 and w2, scaled by scale."""
if w1.dim() != w2.dim():
for _ in range(w2.dim() - w1.dim()):
w1 = w1.unsqueeze(-1)
w2 = w2.contiguous()
rebuild = torch.kron(w1, w2)
if scale != 1:
rebuild = rebuild * scale
return rebuild
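# Shape sketch (hedged): torch.kron of an (l, m) and a (k, n) matrix yields
# (l*k, m*n), which is how LoKr rebuilds a full out_dim x in_dim weight from
# the two small factors returned by factorization().
import torch
from networks.lokr import make_kron

w1 = torch.randn(8, 4)    # (out_l, in_m)
w2 = torch.randn(64, 32)  # (out_k, in_n)
delta = make_kron(w1, w2, 1.0)
assert delta.shape == (8 * 64, 4 * 32)  # (512, 128)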
class LoKrModule(torch.nn.Module):
"""LoKr module for training. Replaces forward method of the original Linear/Conv2d."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
factor=-1,
**kwargs,
):
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
is_conv2d = org_module.__class__.__name__ == "Conv2d"
if is_conv2d:
in_dim = org_module.in_channels
out_dim = org_module.out_channels
self.is_conv = True
if org_module.kernel_size != (1, 1):
raise ValueError("LoKr Conv2d 3x3 (Tucker decomposition) is not supported yet")
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.is_conv = False
factor = int(factor)
self.use_w2 = False
# Factorize dimensions
in_m, in_n = factorization(in_dim, factor)
out_l, out_k = factorization(out_dim, factor)
# w1 is always a full matrix (the "scale" factor, small)
self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))
# w2: low-rank decomposition if rank is small enough, otherwise full matrix
if lora_dim < max(out_k, in_n) / 2:
    self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
    self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
else:
    self.use_w2 = True
    self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
    logger.warning(
        f"LoKr: lora_dim {lora_dim} is too large for dim={max(in_dim, out_dim)} "
        f"and factor={factor}; using full matrix mode."
    )
if isinstance(alpha, torch.Tensor):
    alpha = alpha.detach().float().numpy()
alpha = lora_dim if alpha is None or alpha == 0 else alpha
# if both w1 and w2 are full matrices, use scale = 1
if self.use_w2:
alpha = lora_dim
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha))
# Initialization
torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
if self.use_w2:
torch.nn.init.constant_(self.lokr_w2, 0)
else:
torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
torch.nn.init.constant_(self.lokr_w2_b, 0)
# Ensures ΔW = kron(w1, 0) = 0 at init
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def get_diff_weight(self):
"""Return materialized weight delta."""
w1 = self.lokr_w1
if self.use_w2:
w2 = self.lokr_w2
else:
w2 = self.lokr_w2_a @ self.lokr_w2_b
return make_kron(w1, w2, self.scale)
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
diff_weight = self.get_diff_weight()
# rank dropout
if self.rank_dropout is not None and self.training:
drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
drop = drop.view(-1, 1)
diff_weight = diff_weight * drop
scale = 1.0 / (1.0 - self.rank_dropout)
else:
scale = 1.0
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale
else:
return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
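# Break-even sketch (hedged) for the low-rank vs. full-matrix switch in __init__,
# using a hypothetical 1024x1024 Linear with factor=-1:
from networks.lokr import factorization

in_m, in_n = factorization(1024)    # (32, 32)
out_l, out_k = factorization(1024)  # (32, 32)
full_w2 = out_k * in_n                          # 1024 params for the full 32x32 w2
low_rank_w2 = lambda r: out_k * r + r * in_n    # 64*r params for w2_a + w2_b
assert low_rank_w2(16) == full_w2  # equal cost exactly at max(out_k, in_n) / 2
# lora_dim=8 halves w2 (512 params); lora_dim >= 16 would cost more, hence full matrix mode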
class LoKrInfModule(LoKrModule):
"""LoKr module for inference. Supports merge_to and get_weight."""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference; pass factor from kwargs if present
factor = kwargs.pop("factor", -1)
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor)
self.org_module_ref = [org_module]
self.enabled = True
self.network: Optional[AdditionalNetwork] = None
def set_network(self, network):
self.network = network
def merge_to(self, sd, dtype, device):
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"]
org_dtype = weight.dtype
org_device = weight.device
weight = weight.to(torch.float)
if dtype is None:
dtype = org_dtype
if device is None:
device = org_device
# get LoKr weights
w1 = sd["lokr_w1"].to(torch.float).to(device)
if "lokr_w2" in sd:
w2 = sd["lokr_w2"].to(torch.float).to(device)
else:
w2a = sd["lokr_w2_a"].to(torch.float).to(device)
w2b = sd["lokr_w2_b"].to(torch.float).to(device)
w2 = w2a @ w2b
# compute ΔW via Kronecker product
diff_weight = make_kron(w1, w2, self.scale)
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
weight = weight.to(device) + self.multiplier * diff_weight
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
w1 = self.lokr_w1.to(torch.float)
if self.use_w2:
w2 = self.lokr_w2.to(torch.float)
else:
w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)
weight = make_kron(w1, w2, self.scale) * multiplier
if self.is_conv:
weight = weight.unsqueeze(2).unsqueeze(3)
return weight
def default_forward(self, x):
diff_weight = self.get_diff_weight()
if self.is_conv:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
else:
return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
return self.default_forward(x)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae,
text_encoder,
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
"""Create a LoKr network. Called by train_network.py via network_module.create_network()."""
if network_dim is None:
network_dim = 4
if network_alpha is None:
network_alpha = 1.0
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
# train LLM adapter
train_llm_adapter = str(kwargs.get("train_llm_adapter", "false")).lower() == "true"
# exclude patterns
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns from arch config
exclude_patterns.extend(arch_config.default_excludes)
# include patterns
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# conv dim/alpha (for future Conv2d 3x3 support)
conv_lora_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_lora_dim is not None:
conv_lora_dim = int(conv_lora_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# factor for LoKr
factor = int(kwargs.get("factor", -1))
# verbose
verbose = str(kwargs.get("verbose", "false")).lower() == "true"
# regex-specific learning rates / dimensions
network_reg_lrs = kwargs.get("network_reg_lrs", None)
reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None
network_reg_dims = kwargs.get("network_reg_dims", None)
reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
module_class=LoKrModule,
module_kwargs={"factor": factor},
conv_lora_dim=conv_lora_dim,
conv_alpha=conv_alpha,
train_llm_adapter=train_llm_adapter,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
# LoRA+ support
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
"""Create a LoKr network from saved weights. Called by train_network.py."""
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# detect dim/alpha from weights
modules_dim = {}
modules_alpha = {}
train_llm_adapter = False
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lokr_w2_a" in key:
# low-rank mode: dim = w2_a.shape[1]
dim = value.shape[1]
modules_dim[lora_name] = dim
elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
# full matrix mode: set dim large enough to trigger full-matrix path
if lora_name not in modules_dim:
modules_dim[lora_name] = max(value.shape)
if "llm_adapter" in lora_name:
train_llm_adapter = True
# handle text_encoder as list
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
# detect architecture
arch_config = detect_arch_config(unet, text_encoders)
# extract factor for LoKr
factor = int(kwargs.get("factor", -1))
module_class = LoKrInfModule if for_inference else LoKrModule
module_kwargs = {"factor": factor}
network = AdditionalNetwork(
text_encoders,
unet,
arch_config=arch_config,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
module_kwargs=module_kwargs,
train_llm_adapter=train_llm_adapter,
)
return network, weights_sd
def merge_weights_to_tensor(
model_weight: torch.Tensor,
lora_name: str,
lora_sd: Dict[str, torch.Tensor],
lora_weight_keys: set,
multiplier: float,
calc_device: torch.device,
) -> torch.Tensor:
"""Merge LoKr weights directly into a model weight tensor.
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
Returns model_weight unchanged if no matching LoKr keys found.
"""
w1_key = lora_name + ".lokr_w1"
w2_key = lora_name + ".lokr_w2"
w2a_key = lora_name + ".lokr_w2_a"
w2b_key = lora_name + ".lokr_w2_b"
alpha_key = lora_name + ".alpha"
if w1_key not in lora_weight_keys:
return model_weight
w1 = lora_sd[w1_key].to(calc_device)
# determine low-rank vs full matrix mode
if w2a_key in lora_weight_keys:
# low-rank: w2 = w2_a @ w2_b
w2a = lora_sd[w2a_key].to(calc_device)
w2b = lora_sd[w2b_key].to(calc_device)
dim = w2a.shape[1]
consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
elif w2_key in lora_weight_keys:
# full matrix mode
w2a = None
w2b = None
dim = None
consumed_keys = [w1_key, w2_key, alpha_key]
else:
return model_weight
alpha = lora_sd.get(alpha_key, None)
if alpha is not None and isinstance(alpha, torch.Tensor):
alpha = alpha.item()
# compute scale
if w2a is not None:
# low-rank mode
if alpha is None:
alpha = dim
scale = alpha / dim
else:
# full matrix mode: scale = 1.0
scale = 1.0
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(torch.float16)
w1 = w1.to(torch.float16)
if w2a is not None:
w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)
# compute w2
if w2a is not None:
w2 = w2a @ w2b
else:
w2 = lora_sd[w2_key].to(calc_device)
if original_dtype.itemsize == 1:
w2 = w2.to(torch.float16)
# ΔW = kron(w1, w2) * scale
diff_weight = make_kron(w1, w2, scale)
# handle Conv2d 1x1 weights (4D tensors)
if len(model_weight.shape) == 4:
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
model_weight = model_weight + multiplier * diff_weight
if original_dtype.itemsize == 1:
model_weight = model_weight.to(original_dtype)
# remove consumed keys
for key in consumed_keys:
lora_weight_keys.discard(key)
return model_weight
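# Numeric sketch (hedged; module name and shapes are hypothetical): in low-rank
# mode, merge_weights_to_tensor applies ΔW = kron(w1, w2_a @ w2_b) with scale 1.0
# when no alpha key is stored (alpha then defaults to dim).
import torch
from networks.lokr import factorization, make_kron, merge_weights_to_tensor

out_dim, in_dim, rank = 512, 128, 4
out_l, out_k = factorization(out_dim)  # (16, 32)
in_m, in_n = factorization(in_dim)     # (8, 16)
name = "lora_unet_some_linear"  # hypothetical LoKr module name
sd = {
    f"{name}.lokr_w1": torch.randn(out_l, in_m),
    f"{name}.lokr_w2_a": torch.randn(out_k, rank),
    f"{name}.lokr_w2_b": torch.randn(rank, in_n),
}
keys = set(sd.keys())
weight = torch.zeros(out_dim, in_dim)
merged = merge_weights_to_tensor(weight, name, sd, keys, multiplier=1.0, calc_device=torch.device("cpu"))
expected = make_kron(sd[f"{name}.lokr_w1"], sd[f"{name}.lokr_w2_a"] @ sd[f"{name}.lokr_w2_b"], 1.0)
assert torch.allclose(merged, expected)
assert not keys  # all consumed keys were removed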

networks/lora_anima.py

@@ -1,11 +1,11 @@
# LoRA network module for Anima
import ast
import math
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
import logging
@@ -13,6 +13,210 @@ setup_logging()
logger = logging.getLogger(__name__)
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
"""
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
if isinstance(alpha, torch.Tensor):
    alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
# same as microsoft's
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# the scale could instead be computed from the mask, but rank_dropout is used this way for its augmentation-like effect
scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # written out explicitly for readability
else:
scale = self.scale
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoRAInfModule(LoRAModule):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
self.org_module_ref = [org_module]  # keep a reference for later access
self.enabled = True
self.network: Optional["LoRANetwork"] = None
def set_network(self, network):
self.network = network
# merge LoRA weights into the frozen original module
def merge_to(self, sd, dtype, device):
# extract weight from org_module
org_sd = self.org_module.state_dict()
weight = org_sd["weight"]
org_dtype = weight.dtype
org_device = weight.device
weight = weight.to(torch.float) # calc in float
if dtype is None:
dtype = org_dtype
if device is None:
device = org_device
# get up/down weight
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
# merge weight
if len(weight.size()) == 2:
# linear
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
# return this module's weight delta so that a merge can be restored later
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
# get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float)
# pre-calculated weight
if len(down_weight.size()) == 2:
# linear
weight = self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = self.multiplier * conved * self.scale
return weight
def default_forward(self, x):
# logger.info(f"default_forward {self.lora_name} {x.size()}")
lx = self.lora_down(x)
lx = self.lora_up(lx)
return self.org_forward(x) + lx * self.multiplier * self.scale
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
return self.default_forward(x)
def create_network(
multiplier: float,
network_dim: Optional[int],

networks/network_base.py (new file): +540 lines

@@ -0,0 +1,540 @@
# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc).
# Provides architecture detection and a generic AdditionalNetwork class.
import ast
import math
import os
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
@dataclass
class ArchConfig:
unet_target_modules: List[str]
te_target_modules: List[str]
unet_prefix: str
te_prefixes: List[str]
default_excludes: List[str] = field(default_factory=list)
adapter_target_modules: List[str] = field(default_factory=list)
def detect_arch_config(unet, text_encoders) -> ArchConfig:
"""Detect architecture from model structure and return ArchConfig."""
from library.sdxl_original_unet import SdxlUNet2DConditionModel
# Check SDXL first
if unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel):
return ArchConfig(
unet_target_modules=["Transformer2DModel"],
te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
unet_prefix="lora_unet",
te_prefixes=["lora_te1", "lora_te2"],
default_excludes=[],
)
# Check Anima: look for Block class in named_modules
module_class_names = set()
if unet is not None:
for module in unet.modules():
module_class_names.add(type(module).__name__)
if "Block" in module_class_names:
return ArchConfig(
unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"],
te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"],
unet_prefix="lora_unet",
te_prefixes=["lora_te"],
default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"],
adapter_target_modules=["LLMAdapterTransformerBlock"],
)
raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}")
def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]:
"""Parse a string of key-value pairs separated by commas."""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
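# Example: _parse_kv_pairs("attn=8, mlp=4", is_int=True) returns {"attn": 8, "mlp": 4};
# malformed pairs such as "attn8" or "mlp=x" are skipped with a warning instead of raising.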
class AdditionalNetwork(torch.nn.Module):
"""Generic Additional network that supports LoHa, LoKr, and similar module types.
Constructed with a module_class parameter to inject the specific module type.
Based on the lora_anima.py LoRANetwork, generalized for multiple architectures.
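Example (illustrative; LoHaModule is a hypothetical stand-in for any compatible module class):
arch_config = detect_arch_config(unet, text_encoders)
network = AdditionalNetwork(text_encoders, unet, arch_config, multiplier=1.0, lora_dim=8, alpha=4.0, module_class=LoHaModule)
network.apply_to(text_encoders, unet)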
"""
def __init__(
self,
text_encoders: list,
unet,
arch_config: ArchConfig,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
module_class: Optional[Type[torch.nn.Module]] = None,
module_kwargs: Optional[Dict] = None,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
reg_dims: Optional[Dict[str, int]] = None,
reg_lrs: Optional[Dict[str, float]] = None,
train_llm_adapter: bool = False,
verbose: bool = False,
) -> None:
super().__init__()
assert module_class is not None, "module_class must be specified"
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.train_llm_adapter = train_llm_adapter
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.arch_config = arch_config
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if module_kwargs is None:
module_kwargs = {}
if modules_dim is not None:
logger.info(f"create {module_class.__name__} network from weights")
else:
logger.info(f"create {module_class.__name__} network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expressions
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
prefix: str,
root_module: torch.nn.Module,
target_replace_modules: List[str],
default_dim: Optional[int] = None,
) -> Tuple[List[torch.nn.Module], List[str]]:
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None:
module = root_module
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
# exclude/include filter
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
if modules_dim is not None:
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
else:
if self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.fullmatch(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
break
# fallback to default dim
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
elif is_conv2d and self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha_val = self.conv_alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha_val,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
**module_kwargs,
)
lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
break
return loras, skipped
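# e.g. a Linear child at "input_blocks.4.1.transformer_blocks.0.attn1.to_q" under prefix
# "lora_unet" gets lora_name "lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q"
# (dots become underscores), the key format used in saved state dicts and modules_dim lookups.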
# Create modules for text encoders
self.text_encoder_loras: List[torch.nn.Module] = []
skipped_te = []
if text_encoders is not None:
for i, text_encoder in enumerate(text_encoders):
if text_encoder is None:
continue
# Determine prefix for this text encoder
if i < len(arch_config.te_prefixes):
te_prefix = arch_config.te_prefixes[i]
else:
te_prefix = arch_config.te_prefixes[0]
logger.info(f"create {module_class.__name__} for Text Encoder {i+1} (prefix={te_prefix}):")
te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules)
logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
# Create modules for UNet/DiT
target_modules = list(arch_config.unet_target_modules)
if train_llm_adapter and arch_config.adapter_target_modules:
target_modules.extend(arch_config.adapter_target_modules)
self.unet_loras: List[torch.nn.Module]
self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules)
logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:")
for name in skipped:
logger.info(f"\t{name}")
# assertion: no duplicate names
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
def is_mergeable(self):
return True
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
te_prefixes = self.arch_config.te_prefixes
unet_prefix = self.arch_config.unet_prefix
for key in weights_sd.keys():
if any(key.startswith(p) for p in te_prefixes):
apply_text_encoder = True
elif key.startswith(unet_prefix):
apply_unet = True
if apply_text_encoder:
logger.info("enable modules for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable modules for UNet/DiT")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr)]
# otherwise text_encoder_lr is already a list; values are applied per encoder, in order
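# e.g. text_encoder_lr=5e-5 becomes [5e-5] and is shared by all encoders, while
# [1e-4, 5e-5] assigns a separate LR per text encoder, indexed by te_prefix order below.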
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
reg_groups = {}
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.fullmatch(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
for name, param in lora.named_parameters():
if matched_reg_lr is not None:
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
# LoRA+ detection: check for "up" weight parameters
if loraplus_ratio is not None and self._is_plus_param(name):
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
else:
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
continue
if loraplus_ratio is not None and self._is_plus_param(name):
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for group_key, group in reg_groups.items():
reg_lr = group["lr"]
for key in ("lora", "plus"):
param_data = {"params": group[key].values()}
if len(param_data["params"]) == 0:
continue
if key == "plus":
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
descriptions.append(desc + (" plus" if key == "plus" else ""))
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
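# assemble_params returns optimizer-ready groups, e.g. with lr=1e-4 and loraplus_ratio=4:
#   params = [{"params": [...], "lr": 1e-4}, {"params": [...], "lr": 4e-4}]
#   descriptions = ["", "plus"]
# reg_lr_N groups (from reg_lrs regex matches) precede these and carry their own LRs.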
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
# Group TE loras by prefix
for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes):
te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)]
if len(te_loras) > 0:
te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0]
logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}")
params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio)
all_params.extend(params)
lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
def _is_plus_param(self, name: str) -> bool:
"""Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+.
For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor).
Override in subclass if needed. Default: check for common 'up' patterns.
"""
return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name
def enable_gradient_checkpointing(self):
pass # not supported
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
loras = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
loras = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
loras = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
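# Typical sampling-time flow (illustrative): call backup_weights() once, pre_calculation()
# to bake each delta into the base weight (the wrapper forwards are disabled), generate,
# then restore_weights() to return the model to its unmerged state.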