mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
feat: Add LoHa/LoKr network support for SDXL and Anima
- networks/network_base.py: shared AdditionalNetwork base class with architecture auto-detection (SDXL/Anima) and generic module injection
- networks/loha.py: LoHa (Low-rank Hadamard Product) module with HadaWeight custom autograd, training/inference classes, and factory functions
- networks/lokr.py: LoKr (Low-rank Kronecker Product) module with factorization, training/inference classes, and factory functions
- library/lora_utils.py: extend weight merge hook to detect and merge LoHa/LoKr weights alongside standard LoRA

Linear and Conv2d 1x1 layers only; Conv2d 3x3 (Tucker decomposition) support will be added separately.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -11,3 +11,4 @@ GEMINI.md
|
||||
.claude
|
||||
.gemini
|
||||
MagicMock
|
||||
references
|
||||
@@ -1,267 +1,287 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from library.device_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_lora_state_dict(
    weights_sd: Dict[str, torch.Tensor],
    include_pattern: Optional[str] = None,
    exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """Filter a LoRA state dict by regex: keep keys matching include_pattern, then drop keys matching exclude_pattern."""
    original_key_count = len(weights_sd.keys())

    if include_pattern is not None:
        compiled = re.compile(include_pattern)
        weights_sd = {key: tensor for key, tensor in weights_sd.items() if compiled.search(key)}
        logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")

    if exclude_pattern is not None:
        count_before_exclude = len(weights_sd.keys())
        compiled = re.compile(exclude_pattern)
        weights_sd = {key: tensor for key, tensor in weights_sd.items() if not compiled.search(key)}
        logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {count_before_exclude} -> {len(weights_sd.keys())}")

    if len(weights_sd) != original_key_count:
        # report which top-level module prefixes survived the filtering
        remaining_keys = sorted({key.split(".", 1)[0] for key in weights_sd.keys()})
        logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
        if len(weights_sd) == 0:
            logger.warning("No keys left after filtering.")

    return weights_sd
|
||||
|
||||
|
||||
def load_safetensors_with_lora_and_fp8(
    model_files: Union[str, List[str]],
    lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
    lora_multipliers: Optional[List[float]],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model with fp8 optimization if needed.

    Args:
        model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
        lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
        lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
        fp8_optimization (bool): Whether to apply FP8 optimization.
        calc_device (torch.device): Device to calculate on.
        move_to_device (bool): Whether to move tensors to the calculation device after loading.
        dit_weight_dtype (Optional[torch.dtype]): Dtype for loaded weights; ignored when fp8_optimization is True.
        target_keys (Optional[List[str]]): Keys to target for optimization.
        exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
        disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
        weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.

    Returns:
        dict[str, torch.Tensor]: Loaded state dict with LoRA deltas merged in place.
    """

    # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
    if isinstance(model_files, str):
        model_files = [model_files]

    extended_model_files = []
    for model_file in model_files:
        split_filenames = get_split_weight_filenames(model_file)
        if split_filenames is not None:
            extended_model_files.extend(split_filenames)
        else:
            extended_model_files.append(model_file)
    model_files = extended_model_files
    logger.info(f"Loading model files: {model_files}")

    # load LoRA weights
    weight_hook = None
    if lora_weights_list is None or len(lora_weights_list) == 0:
        lora_weights_list = []
        lora_multipliers = []
        list_of_lora_weight_keys = []
    else:
        # one mutable key-set per LoRA; keys are removed as they are merged so
        # leftovers can be reported as unused at the end
        list_of_lora_weight_keys = []
        for lora_sd in lora_weights_list:
            lora_weight_keys = set(lora_sd.keys())
            list_of_lora_weight_keys.append(lora_weight_keys)

        # pad/truncate multipliers to match the number of LoRA state dicts (default 1.0)
        if lora_multipliers is None:
            lora_multipliers = [1.0] * len(lora_weights_list)
        while len(lora_multipliers) < len(lora_weights_list):
            lora_multipliers.append(1.0)
        if len(lora_multipliers) > len(lora_weights_list):
            lora_multipliers = lora_multipliers[: len(lora_weights_list)]

        # Merge LoRA weights into the state dict
        logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")

        # make hook for LoRA merging
        def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
            """Per-tensor load hook: merges any matching LoRA delta into model_weight."""
            nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device

            # only ".weight" tensors can have LoRA counterparts
            if not model_weight_key.endswith(".weight"):
                return model_weight

            original_device = model_weight.device
            if original_device != calc_device:
                model_weight = model_weight.to(calc_device)  # to make calculation faster

            for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
                # check if this weight has LoRA weights
                lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0]  # remove trailing ".weight"
                found = False
                # LoRA keys may or may not carry the "lora_unet_" prefix depending on the source
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    down_key = lora_name + ".lora_down.weight"
                    up_key = lora_name + ".lora_up.weight"
                    alpha_key = lora_name + ".alpha"
                    if down_key in lora_weight_keys and up_key in lora_weight_keys:
                        found = True
                        break
                if not found:
                    continue  # no LoRA weights for this model weight

                # get LoRA weights
                down_weight = lora_sd[down_key]
                up_weight = lora_sd[up_key]

                # alpha defaults to the rank, which makes scale == 1.0
                dim = down_weight.size()[0]
                alpha = lora_sd.get(alpha_key, dim)
                scale = alpha / dim

                down_weight = down_weight.to(calc_device)
                up_weight = up_weight.to(calc_device)

                original_dtype = model_weight.dtype
                if original_dtype.itemsize == 1:  # fp8
                    # temporarily convert to float16 for calculation
                    model_weight = model_weight.to(torch.float16)
                    down_weight = down_weight.to(torch.float16)
                    up_weight = up_weight.to(torch.float16)

                # W <- W + U * D
                if len(model_weight.size()) == 2:
                    # linear
                    if len(up_weight.size()) == 4:  # use linear projection mismatch
                        # conv-shaped (out, rank, 1, 1) LoRA applied to a linear weight: drop the 1x1 dims
                        up_weight = up_weight.squeeze(3).squeeze(2)
                        down_weight = down_weight.squeeze(3).squeeze(2)
                    model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
                elif down_weight.size()[2:4] == (1, 1):
                    # conv2d 1x1
                    model_weight = (
                        model_weight
                        + multiplier
                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                        * scale
                    )
                else:
                    # conv2d 3x3
                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                    model_weight = model_weight + multiplier * conved * scale

                if original_dtype.itemsize == 1:  # fp8
                    model_weight = model_weight.to(original_dtype)  # convert back to original dtype

                # remove LoRA keys from set
                lora_weight_keys.remove(down_key)
                lora_weight_keys.remove(up_key)
                if alpha_key in lora_weight_keys:
                    lora_weight_keys.remove(alpha_key)

            if not keep_on_calc_device and original_device != calc_device:
                model_weight = model_weight.to(original_device)  # move back to original device
            return model_weight

        weight_hook = weight_hook_func

    state_dict = load_safetensors_with_fp8_optimization_and_hook(
        model_files,
        fp8_optimization,
        calc_device,
        move_to_device,
        dit_weight_dtype,
        target_keys,
        exclude_keys,
        weight_hook=weight_hook,
        disable_numpy_memmap=disable_numpy_memmap,
        weight_transform_hooks=weight_transform_hooks,
    )

    for lora_weight_keys in list_of_lora_weight_keys:
        # check if all LoRA keys are used
        if len(lora_weight_keys) > 0:
            # if there are still LoRA keys left, it means they are not used in the model
            # this is a warning, not an error
            logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")

    return state_dict
|
||||
|
||||
|
||||
def load_safetensors_with_fp8_optimization_and_hook(
    model_files: list[str],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    weight_hook: callable = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.

    Args:
        model_files (list[str]): Safetensors files to load; all keys are merged into one state dict.
        fp8_optimization (bool): Delegate loading to the FP8-optimizing loader when True.
        calc_device (torch.device): Device used by the hook and, optionally, for final tensor placement.
        move_to_device (bool): Move loaded tensors to calc_device.
        dit_weight_dtype (Optional[torch.dtype]): Dtype for loaded tensors (non-FP8 path only).
        target_keys (Optional[List[str]]): Key filter forwarded to the FP8 loader.
        exclude_keys (Optional[List[str]]): Key filter forwarded to the FP8 loader.
        weight_hook (callable): Optional per-tensor transform applied while loading.
        disable_numpy_memmap (bool): Disable numpy memmap when reading files.
        weight_transform_hooks (Optional[WeightTransformHooks]): Key/tensor adaptation hooks.
    """
    if fp8_optimization:
        logger.info(
            f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        # dit_weight_dtype is not used because we use fp8 optimization
        state_dict = load_safetensors_with_fp8_optimization(
            model_files,
            calc_device,
            target_keys,
            exclude_keys,
            move_to_device=move_to_device,
            weight_hook=weight_hook,
            disable_numpy_memmap=disable_numpy_memmap,
            weight_transform_hooks=weight_transform_hooks,
        )
    else:
        logger.info(
            f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        state_dict = {}
        for model_file in model_files:
            with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
                # wrap the reader only when key/weight transformation hooks are supplied
                f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
                for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
                    if weight_hook is None and move_to_device:
                        # fast path: load directly to the target device/dtype
                        value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
                    else:
                        value = f.get_tensor(key)  # we cannot directly load to device because get_tensor does non-blocking transfer
                        if weight_hook is not None:
                            value = weight_hook(key, value, keep_on_calc_device=move_to_device)
                        if move_to_device:
                            value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
                        elif dit_weight_dtype is not None:
                            value = value.to(dit_weight_dtype)

                    state_dict[key] = value
        if move_to_device:
            # wait for all non-blocking copies to finish before returning
            synchronize_device(calc_device)

    return state_dict
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from library.device_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_lora_state_dict(
    weights_sd: Dict[str, torch.Tensor],
    include_pattern: Optional[str] = None,
    exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """Filter a LoRA state dict by regex: keep keys matching include_pattern, then drop keys matching exclude_pattern."""
    original_key_count = len(weights_sd.keys())

    if include_pattern is not None:
        compiled = re.compile(include_pattern)
        weights_sd = {key: tensor for key, tensor in weights_sd.items() if compiled.search(key)}
        logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")

    if exclude_pattern is not None:
        count_before_exclude = len(weights_sd.keys())
        compiled = re.compile(exclude_pattern)
        weights_sd = {key: tensor for key, tensor in weights_sd.items() if not compiled.search(key)}
        logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {count_before_exclude} -> {len(weights_sd.keys())}")

    if len(weights_sd) != original_key_count:
        # report which top-level module prefixes survived the filtering
        remaining_keys = sorted({key.split(".", 1)[0] for key in weights_sd.keys()})
        logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
        if len(weights_sd) == 0:
            logger.warning("No keys left after filtering.")

    return weights_sd
|
||||
|
||||
|
||||
def load_safetensors_with_lora_and_fp8(
    model_files: Union[str, List[str]],
    lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
    lora_multipliers: Optional[List[float]],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model with fp8 optimization if needed.

    Supports standard LoRA (lora_down/lora_up) as well as LoHa (hada_*) and
    LoKr (lokr_*) weights; the latter two are merged via networks.loha / networks.lokr.

    Args:
        model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
        lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
        lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
        fp8_optimization (bool): Whether to apply FP8 optimization.
        calc_device (torch.device): Device to calculate on.
        move_to_device (bool): Whether to move tensors to the calculation device after loading.
        dit_weight_dtype (Optional[torch.dtype]): Dtype for loaded weights; ignored when fp8_optimization is True.
        target_keys (Optional[List[str]]): Keys to target for optimization.
        exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
        disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
        weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.

    Returns:
        dict[str, torch.Tensor]: Loaded state dict with LoRA/LoHa/LoKr deltas merged in place.
    """

    # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
    if isinstance(model_files, str):
        model_files = [model_files]

    extended_model_files = []
    for model_file in model_files:
        split_filenames = get_split_weight_filenames(model_file)
        if split_filenames is not None:
            extended_model_files.extend(split_filenames)
        else:
            extended_model_files.append(model_file)
    model_files = extended_model_files
    logger.info(f"Loading model files: {model_files}")

    # load LoRA weights
    weight_hook = None
    if lora_weights_list is None or len(lora_weights_list) == 0:
        lora_weights_list = []
        lora_multipliers = []
        list_of_lora_weight_keys = []
    else:
        # one mutable key-set per LoRA; keys are removed as they are merged so
        # leftovers can be reported as unused at the end
        list_of_lora_weight_keys = []
        for lora_sd in lora_weights_list:
            lora_weight_keys = set(lora_sd.keys())
            list_of_lora_weight_keys.append(lora_weight_keys)

        # pad/truncate multipliers to match the number of LoRA state dicts (default 1.0)
        if lora_multipliers is None:
            lora_multipliers = [1.0] * len(lora_weights_list)
        while len(lora_multipliers) < len(lora_weights_list):
            lora_multipliers.append(1.0)
        if len(lora_multipliers) > len(lora_weights_list):
            lora_multipliers = lora_multipliers[: len(lora_weights_list)]

        # Merge LoRA weights into the state dict
        logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")

        # make hook for LoRA merging
        def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
            """Per-tensor load hook: merges any matching LoRA/LoHa/LoKr delta into model_weight."""
            nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device

            # only ".weight" tensors can have LoRA counterparts
            if not model_weight_key.endswith(".weight"):
                return model_weight

            original_device = model_weight.device
            if original_device != calc_device:
                model_weight = model_weight.to(calc_device)  # to make calculation faster

            for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
                # check if this weight has LoRA weights
                lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0]  # remove trailing ".weight"
                found = False
                # LoRA keys may or may not carry the "lora_unet_" prefix depending on the source
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    down_key = lora_name + ".lora_down.weight"
                    up_key = lora_name + ".lora_up.weight"
                    alpha_key = lora_name + ".alpha"
                    if down_key in lora_weight_keys and up_key in lora_weight_keys:
                        found = True
                        break

                if found:
                    # Standard LoRA merge
                    # get LoRA weights
                    down_weight = lora_sd[down_key]
                    up_weight = lora_sd[up_key]

                    # alpha defaults to the rank, which makes scale == 1.0
                    dim = down_weight.size()[0]
                    alpha = lora_sd.get(alpha_key, dim)
                    scale = alpha / dim

                    down_weight = down_weight.to(calc_device)
                    up_weight = up_weight.to(calc_device)

                    original_dtype = model_weight.dtype
                    if original_dtype.itemsize == 1:  # fp8
                        # temporarily convert to float16 for calculation
                        model_weight = model_weight.to(torch.float16)
                        down_weight = down_weight.to(torch.float16)
                        up_weight = up_weight.to(torch.float16)

                    # W <- W + U * D
                    if len(model_weight.size()) == 2:
                        # linear
                        if len(up_weight.size()) == 4:  # use linear projection mismatch
                            # conv-shaped (out, rank, 1, 1) LoRA applied to a linear weight: drop the 1x1 dims
                            up_weight = up_weight.squeeze(3).squeeze(2)
                            down_weight = down_weight.squeeze(3).squeeze(2)
                        model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
                    elif down_weight.size()[2:4] == (1, 1):
                        # conv2d 1x1
                        model_weight = (
                            model_weight
                            + multiplier
                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                            * scale
                        )
                    else:
                        # conv2d 3x3
                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                        model_weight = model_weight + multiplier * conved * scale

                    if original_dtype.itemsize == 1:  # fp8
                        model_weight = model_weight.to(original_dtype)  # convert back to original dtype

                    # remove LoRA keys from set
                    lora_weight_keys.remove(down_key)
                    lora_weight_keys.remove(up_key)
                    if alpha_key in lora_weight_keys:
                        lora_weight_keys.remove(alpha_key)
                    continue

                # Check for LoHa/LoKr weights with same prefix search
                # NOTE(review): unlike the standard LoRA path above, no fp8 -> fp16
                # up-cast is done here; assumed to be handled inside the merge
                # helpers — confirm against networks/loha.py and networks/lokr.py
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    hada_key = lora_name + ".hada_w1_a"
                    lokr_key = lora_name + ".lokr_w1"

                    if hada_key in lora_weight_keys:
                        # LoHa merge
                        # imported lazily so plain-LoRA use has no dependency on networks.loha
                        from networks.loha import merge_weights_to_tensor as loha_merge

                        model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
                        break
                    elif lokr_key in lora_weight_keys:
                        # LoKr merge
                        # imported lazily so plain-LoRA use has no dependency on networks.lokr
                        from networks.lokr import merge_weights_to_tensor as lokr_merge

                        model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
                        break

            if not keep_on_calc_device and original_device != calc_device:
                model_weight = model_weight.to(original_device)  # move back to original device
            return model_weight

        weight_hook = weight_hook_func

    state_dict = load_safetensors_with_fp8_optimization_and_hook(
        model_files,
        fp8_optimization,
        calc_device,
        move_to_device,
        dit_weight_dtype,
        target_keys,
        exclude_keys,
        weight_hook=weight_hook,
        disable_numpy_memmap=disable_numpy_memmap,
        weight_transform_hooks=weight_transform_hooks,
    )

    for lora_weight_keys in list_of_lora_weight_keys:
        # check if all LoRA keys are used
        if len(lora_weight_keys) > 0:
            # if there are still LoRA keys left, it means they are not used in the model
            # this is a warning, not an error
            logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")

    return state_dict
|
||||
|
||||
|
||||
def load_safetensors_with_fp8_optimization_and_hook(
    model_files: list[str],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    weight_hook: callable = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.

    Args:
        model_files (list[str]): Safetensors files to load; all keys are merged into one state dict.
        fp8_optimization (bool): Delegate loading to the FP8-optimizing loader when True.
        calc_device (torch.device): Device used by the hook and, optionally, for final tensor placement.
        move_to_device (bool): Move loaded tensors to calc_device.
        dit_weight_dtype (Optional[torch.dtype]): Dtype for loaded tensors (non-FP8 path only).
        target_keys (Optional[List[str]]): Key filter forwarded to the FP8 loader.
        exclude_keys (Optional[List[str]]): Key filter forwarded to the FP8 loader.
        weight_hook (callable): Optional per-tensor transform applied while loading.
        disable_numpy_memmap (bool): Disable numpy memmap when reading files.
        weight_transform_hooks (Optional[WeightTransformHooks]): Key/tensor adaptation hooks.
    """
    if fp8_optimization:
        logger.info(
            f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        # dit_weight_dtype is not used because we use fp8 optimization
        state_dict = load_safetensors_with_fp8_optimization(
            model_files,
            calc_device,
            target_keys,
            exclude_keys,
            move_to_device=move_to_device,
            weight_hook=weight_hook,
            disable_numpy_memmap=disable_numpy_memmap,
            weight_transform_hooks=weight_transform_hooks,
        )
    else:
        logger.info(
            f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        state_dict = {}
        for model_file in model_files:
            with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
                # wrap the reader only when key/weight transformation hooks are supplied
                f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
                for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
                    if weight_hook is None and move_to_device:
                        # fast path: load directly to the target device/dtype
                        value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
                    else:
                        value = f.get_tensor(key)  # we cannot directly load to device because get_tensor does non-blocking transfer
                        if weight_hook is not None:
                            value = weight_hook(key, value, keep_on_calc_device=move_to_device)
                        if move_to_device:
                            value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
                        elif dit_weight_dtype is not None:
                            value = value.to(dit_weight_dtype)

                    state_dict[key] = value
        if move_to_device:
            # wait for all non-blocking copies to finish before returning
            synchronize_device(calc_device)

    return state_dict
|
||||
|
||||
450
networks/loha.py
Normal file
450
networks/loha.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# LoHa (Low-rank Hadamard Product) network module
|
||||
# Reference: https://arxiv.org/abs/2108.06098
|
||||
#
|
||||
# Based on the LyCORIS project by KohakuBlueleaf
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS
|
||||
|
||||
import ast
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
    """Efficient Hadamard product forward/backward for LoHa.

    Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with a custom backward
    that recomputes the intermediate products instead of storing them.
    """

    @staticmethod
    def forward(ctx, w1a, w1b, w2a, w2b, scale=None):
        # default scale of 1, kept as a tensor so save_for_backward accepts it
        if scale is None:
            scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype)
        ctx.save_for_backward(w1a, w1b, w2a, w2b, scale)
        return ((w1a @ w1b) * (w2a @ w2b)) * scale

    @staticmethod
    def backward(ctx, grad_out):
        w1a, w1b, w2a, w2b, scale = ctx.saved_tensors
        scaled = grad_out * scale

        # gradients flowing through the first low-rank pair
        through_w1 = scaled * (w2a @ w2b)
        grad_w1a = through_w1 @ w1b.T
        grad_w1b = w1a.T @ through_w1

        # gradients flowing through the second low-rank pair
        through_w2 = scaled * (w1a @ w1b)
        grad_w2a = through_w2 @ w2b.T
        grad_w2b = w2a.T @ through_w2

        # scale receives no gradient
        return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None
|
||||
|
||||
|
||||
class LoHaModule(torch.nn.Module):
    """LoHa (Hadamard-product low-rank) adapter for training.

    Wraps a Linear or 1x1 Conv2d and replaces its forward so the output becomes
    org(x) + (dW x) * multiplier, where dW = (w1a @ w1b) * (w2a @ w2b) * scale.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
        **kwargs,
    ):
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        # determine fan-in/fan-out of the wrapped layer
        if org_module.__class__.__name__ == "Conv2d":
            fan_in = org_module.in_channels
            fan_out = org_module.out_channels
            self.is_conv = True
            # Tucker decomposition for larger kernels is not implemented
            if org_module.kernel_size != (1, 1):
                raise ValueError("LoHa Conv2d 3x3 (Tucker decomposition) is not supported yet")
        else:
            fan_in = org_module.in_features
            fan_out = org_module.out_features
            self.is_conv = False

        # Hadamard product parameters: dW = (w1a @ w1b) * (w2a @ w2b)
        self.hada_w1_a = nn.Parameter(torch.empty(fan_out, lora_dim))
        self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, fan_in))
        self.hada_w2_a = nn.Parameter(torch.empty(fan_out, lora_dim))
        self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, fan_in))

        # w2_a starts at zero so dW == 0 at initialization
        torch.nn.init.normal_(self.hada_w1_a, std=0.1)
        torch.nn.init.normal_(self.hada_w1_b, std=1.0)
        torch.nn.init.constant_(self.hada_w2_a, 0)
        torch.nn.init.normal_(self.hada_w2_b, std=1.0)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # tensor alpha comes from loaded state dicts
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # persisted so saved weights carry alpha

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Hijack the wrapped module's forward and drop the strong reference to it."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def get_diff_weight(self):
        """Return materialized weight delta as a 2D matrix."""
        scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device)
        return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)

    def forward(self, x):
        base_out = self.org_forward(x)

        # module dropout: occasionally bypass the adapter entirely while training
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return base_out

        delta = self.get_diff_weight()

        # rank dropout: randomly zero output rows of the delta, rescaling to keep expectation
        if self.rank_dropout is not None and self.training:
            keep = (torch.rand(delta.size(0), device=delta.device) > self.rank_dropout).to(delta.dtype)
            delta = delta * keep.view(-1, 1)
            drop_scale = 1.0 / (1.0 - self.rank_dropout)
        else:
            drop_scale = 1.0

        if self.is_conv:
            # 1x1 conv: lift the 2D delta to a 4D kernel
            return base_out + F.conv2d(x, delta.unsqueeze(2).unsqueeze(3)) * self.multiplier * drop_scale
        return base_out + F.linear(x, delta) * self.multiplier * drop_scale

    @property
    def device(self):
        """Device of this module's parameters."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of this module's parameters."""
        return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoHaInfModule(LoHaModule):
    """LoHa module for inference. Supports merge_to and get_weight."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)

        # keep a list-wrapped reference so the original module stays reachable
        # even after apply_to() deletes self.org_module
        self.org_module_ref = [org_module]
        self.enabled = True
        self.network: AdditionalNetwork = None

    def set_network(self, network):
        # back-reference to the owning network; stored but not read in this class
        self.network = network

    def merge_to(self, sd, dtype, device):
        """Merge this module's LoHa delta directly into the original module's weight.

        Args:
            sd: state dict local to this module; expects keys "hada_w1_a",
                "hada_w1_b", "hada_w2_a", "hada_w2_b" (no lora_name prefix).
            dtype: dtype for the merged weight; None keeps the original dtype.
            device: device for the merge computation; None uses the weight's device.

        NOTE(review): reads self.org_module (not org_module_ref), so this
        assumes apply_to() has not been called yet — confirm call order.
        """
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)  # merge is computed in float32

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get LoHa weights
        w1a = sd["hada_w1_a"].to(torch.float).to(device)
        w1b = sd["hada_w1_b"].to(torch.float).to(device)
        w2a = sd["hada_w2_a"].to(torch.float).to(device)
        w2b = sd["hada_w2_b"].to(torch.float).to(device)

        # compute ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
        diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale

        if self.is_conv:
            # lift the 2D delta to the 4D shape of the 1x1 conv weight
            diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)

        weight = weight.to(device) + self.multiplier * diff_weight

        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    def get_weight(self, multiplier=None):
        """Return the LoHa delta ΔW * multiplier (2D for linear, 4D for 1x1 conv)."""
        if multiplier is None:
            multiplier = self.multiplier

        # compute in float32 regardless of the parameters' storage dtype
        w1a = self.hada_w1_a.to(torch.float)
        w1b = self.hada_w1_b.to(torch.float)
        w2a = self.hada_w2_a.to(torch.float)
        w2b = self.hada_w2_b.to(torch.float)

        weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier

        if self.is_conv:
            weight = weight.unsqueeze(2).unsqueeze(3)

        return weight

    def default_forward(self, x):
        """Dropout-free inference forward: org(x) + ΔW x * multiplier."""
        diff_weight = self.get_diff_weight()
        if self.is_conv:
            diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
            return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
        else:
            return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier

    def forward(self, x):
        # `enabled` allows toggling the adapter on/off at inference time
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)
|
||||
|
||||
|
||||
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae,
    text_encoder,
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoHa network. Called by train_network.py via network_module.create_network()."""
    # fall back to rank 4 / alpha 1.0 when the caller does not specify them
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    # handle text_encoder as list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # train LLM adapter
    # network_args values arrive as strings; only (case-insensitive) "true" enables it
    train_llm_adapter = kwargs.get("train_llm_adapter", "false")
    if train_llm_adapter is not None:
        train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False

    # exclude patterns
    # the value is a Python-literal string (e.g. "['pat1', 'pat2']"); parse it safely
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)
        if not isinstance(exclude_patterns, list):
            exclude_patterns = [exclude_patterns]

    # add default exclude patterns from arch config
    exclude_patterns.extend(arch_config.default_excludes)

    # include patterns
    include_patterns = kwargs.get("include_patterns", None)
    if include_patterns is not None:
        include_patterns = ast.literal_eval(include_patterns)
        if not isinstance(include_patterns, list):
            include_patterns = [include_patterns]

    # rank/module dropout (string kwargs converted to float)
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # conv dim/alpha (for future Conv2d 3x3 support)
    conv_lora_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_lora_dim is not None:
        conv_lora_dim = int(conv_lora_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    # verbose
    verbose = kwargs.get("verbose", "false")
    if verbose is not None:
        verbose = True if str(verbose).lower() == "true" else False

    # regex-specific learning rates / dimensions ("pattern=value" pairs)
    network_reg_lrs = kwargs.get("network_reg_lrs", None)
    reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None

    network_reg_dims = kwargs.get("network_reg_dims", None)
    reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        module_class=LoHaModule,
        conv_lora_dim=conv_lora_dim,
        conv_alpha=conv_alpha,
        train_llm_adapter=train_llm_adapter,
        exclude_patterns=exclude_patterns,
        include_patterns=include_patterns,
        reg_dims=reg_dims,
        reg_lrs=reg_lrs,
        verbose=verbose,
    )

    # LoRA+ support: ratios arrive as strings or None; forwarded to the network
    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """Create a LoHa network from saved weights. Called by train_network.py."""
    # load the state dict from disk unless the caller already provides it
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # detect dim/alpha from weights
    # per-module rank is the first dimension of hada_w1_b; alpha is stored alongside
    modules_dim = {}
    modules_alpha = {}
    train_llm_adapter = False
    for key, value in weights_sd.items():
        if "." not in key:
            continue  # not a per-module weight key

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "hada_w1_b" in key:
            dim = value.shape[0]
            modules_dim[lora_name] = dim

        # any module named after the LLM adapter implies the adapter was trained
        if "llm_adapter" in lora_name:
            train_llm_adapter = True

    # handle text_encoder as list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # inference modules additionally support merge_to/get_weight
    module_class = LoHaInfModule if for_inference else LoHaModule

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        train_llm_adapter=train_llm_adapter,
    )
    return network, weights_sd
|
||||
|
||||
|
||||
def merge_weights_to_tensor(
|
||||
model_weight: torch.Tensor,
|
||||
lora_name: str,
|
||||
lora_sd: Dict[str, torch.Tensor],
|
||||
lora_weight_keys: set,
|
||||
multiplier: float,
|
||||
calc_device: torch.device,
|
||||
) -> torch.Tensor:
|
||||
"""Merge LoHa weights directly into a model weight tensor.
|
||||
|
||||
No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
|
||||
Returns model_weight unchanged if no matching LoHa keys found.
|
||||
"""
|
||||
w1a_key = lora_name + ".hada_w1_a"
|
||||
w1b_key = lora_name + ".hada_w1_b"
|
||||
w2a_key = lora_name + ".hada_w2_a"
|
||||
w2b_key = lora_name + ".hada_w2_b"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
|
||||
if w1a_key not in lora_weight_keys:
|
||||
return model_weight
|
||||
|
||||
w1a = lora_sd[w1a_key].to(calc_device)
|
||||
w1b = lora_sd[w1b_key].to(calc_device)
|
||||
w2a = lora_sd[w2a_key].to(calc_device)
|
||||
w2b = lora_sd[w2b_key].to(calc_device)
|
||||
|
||||
dim = w1b.shape[0]
|
||||
alpha = lora_sd.get(alpha_key, torch.tensor(dim))
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha = alpha.item()
|
||||
scale = alpha / dim
|
||||
|
||||
original_dtype = model_weight.dtype
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
model_weight = model_weight.to(torch.float16)
|
||||
w1a, w1b, w2a, w2b = w1a.to(torch.float16), w1b.to(torch.float16), w2a.to(torch.float16), w2b.to(torch.float16)
|
||||
|
||||
# ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale
|
||||
|
||||
# handle Conv2d 1x1 weights (4D tensors)
|
||||
if len(model_weight.shape) == 4:
|
||||
diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
model_weight = model_weight + multiplier * diff_weight
|
||||
|
||||
if original_dtype.itemsize == 1:
|
||||
model_weight = model_weight.to(original_dtype)
|
||||
|
||||
# remove consumed keys
|
||||
for key in [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key]:
|
||||
lora_weight_keys.discard(key)
|
||||
|
||||
return model_weight
|
||||
544
networks/lokr.py
Normal file
544
networks/lokr.py
Normal file
@@ -0,0 +1,544 @@
|
||||
# LoKr (Low-rank Kronecker Product) network module
|
||||
# Reference: https://arxiv.org/abs/2309.14859
|
||||
#
|
||||
# Based on the LyCORIS project by KohakuBlueleaf
|
||||
# https://github.com/KohakuBlueleaf/LyCORIS
|
||||
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def factorization(dimension: int, factor: int = -1) -> tuple:
    """Split dimension into two factors (small, large) with small * large == dimension.

    In LoKr, the first value is for the weight scale (smaller),
    and the second value is for the weight (larger).

    Examples:
        factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32)
        factor=4:  128 -> (4, 32), 512 -> (4, 128)
    """
    # exact factor requested and it divides dimension: honor it directly
    if factor > 0 and dimension % factor == 0:
        small, large = factor, dimension // factor
        if small > large:
            small, large = large, small
        return small, large

    # no usable factor given: search up to the most balanced split
    if factor < 0:
        factor = dimension

    small, large = 1, dimension
    budget = small + large  # candidate pairs must not exceed this sum
    while small < large:
        candidate = small + 1
        while dimension % candidate:
            candidate += 1
        partner = dimension // candidate
        if candidate + partner > budget or candidate > factor:
            break
        small, large = candidate, partner

    if small > large:
        small, large = large, small
    return small, large
|
||||
|
||||
|
||||
def make_kron(w1, w2, scale):
    """Compute the Kronecker product of w1 and w2, multiplied by scale."""
    # pad w1 with trailing singleton dims so both operands have the same rank
    while w1.dim() < w2.dim():
        w1 = w1.unsqueeze(-1)
    result = torch.kron(w1, w2.contiguous())
    return result if scale == 1 else result * scale
|
||||
|
||||
|
||||
class LoKrModule(torch.nn.Module):
    """LoKr module for training. Replaces forward method of the original Linear/Conv2d.

    The weight delta is ΔW = kron(w1, w2) * scale where:
      - w1 is a small full matrix of shape (out_l, in_m),
      - w2 has shape (out_k, in_n), stored as a low-rank product w2_a @ w2_b
        when lora_dim is small enough, otherwise as a single full matrix,
    with in_dim = in_m * in_n and out_dim = out_l * out_k via factorization().
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
        factor=-1,
        **kwargs,
    ):
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv2d = org_module.__class__.__name__ == "Conv2d"
        if is_conv2d:
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
            self.is_conv = True
            if org_module.kernel_size != (1, 1):
                raise ValueError("LoKr Conv2d 3x3 (Tucker decomposition) is not supported yet")
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.is_conv = False

        factor = int(factor)
        self.use_w2 = False  # True when w2 is stored as a single full matrix

        # Factorize dimensions
        in_m, in_n = factorization(in_dim, factor)
        out_l, out_k = factorization(out_dim, factor)

        # w1 is always a full matrix (the "scale" factor, small)
        self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m))

        # w2: low-rank decomposition if rank is small enough, otherwise full matrix
        if lora_dim < max(out_k, in_n) / 2:
            self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim))
            self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n))
        else:
            self.use_w2 = True
            self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n))
            # fix: the previous inner `if lora_dim >= max(out_k, in_n) / 2` check was
            # vacuously true inside this branch, so the warning is emitted directly
            logger.warning(
                f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} "
                f"and factor={factor}, using full matrix mode."
            )

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        # if both w1 and w2 are full matrices, use scale = 1
        if self.use_w2:
            alpha = lora_dim
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # persisted in the state dict

        # Initialization
        torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5))
        if self.use_w2:
            torch.nn.init.constant_(self.lokr_w2, 0)
        else:
            torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5))
            torch.nn.init.constant_(self.lokr_w2_b, 0)
            # Ensures ΔW = kron(w1, 0) = 0 at init

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Hijack the original module's forward and drop the strong reference to it."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def get_diff_weight(self):
        """Return materialized 2D weight delta ΔW = kron(w1, w2) * scale."""
        w1 = self.lokr_w1
        if self.use_w2:
            w2 = self.lokr_w2
        else:
            w2 = self.lokr_w2_a @ self.lokr_w2_b
        return make_kron(w1, w2, self.scale)

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout: randomly skip the whole LoKr branch while training
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        diff_weight = self.get_diff_weight()

        # rank dropout: randomly zero output rows of ΔW and rescale to compensate
        if self.rank_dropout is not None and self.training:
            drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype)
            drop = drop.view(-1, 1)
            diff_weight = diff_weight * drop
            scale = 1.0 / (1.0 - self.rank_dropout)
        else:
            scale = 1.0

        if self.is_conv:
            # Conv2d 1x1: promote the 2D delta to a (out, in, 1, 1) kernel
            diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
            return org_forwarded + F.conv2d(x, diff_weight) * self.multiplier * scale
        else:
            return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoKrInfModule(LoKrModule):
    """LoKr module for inference. Supports merge_to and get_weight."""

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference; pass factor from kwargs if present
        factor = kwargs.pop("factor", -1)
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor)

        # keep a reference so the original module stays reachable after apply_to()
        self.org_module_ref = [org_module]
        self.enabled = True
        self.network: AdditionalNetwork = None

    def set_network(self, network):
        # back-reference to the owning network (set after construction)
        self.network = network

    def merge_to(self, sd, dtype, device):
        """Permanently merge this module's ΔW into the wrapped module's weight.

        sd holds this module's LoKr tensors (lokr_w1 plus either lokr_w2 or
        lokr_w2_a/lokr_w2_b). dtype/device of None keep the original's.
        NOTE(review): reads self.org_module, so this must run before apply_to().
        """
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)  # merge in float32 for precision

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get LoKr weights
        w1 = sd["lokr_w1"].to(torch.float).to(device)

        if "lokr_w2" in sd:
            # full matrix mode
            w2 = sd["lokr_w2"].to(torch.float).to(device)
        else:
            # low-rank mode: rebuild w2 from its two factors
            w2a = sd["lokr_w2_a"].to(torch.float).to(device)
            w2b = sd["lokr_w2_b"].to(torch.float).to(device)
            w2 = w2a @ w2b

        # compute ΔW via Kronecker product
        diff_weight = make_kron(w1, w2, self.scale)

        if self.is_conv:
            # Conv2d 1x1: expand the 2D delta to a (out, in, 1, 1) kernel
            diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)

        weight = weight.to(device) + self.multiplier * diff_weight

        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    def get_weight(self, multiplier=None):
        """Return this module's ΔW (times multiplier) without merging it."""
        if multiplier is None:
            multiplier = self.multiplier

        w1 = self.lokr_w1.to(torch.float)
        if self.use_w2:
            w2 = self.lokr_w2.to(torch.float)
        else:
            w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float)

        weight = make_kron(w1, w2, self.scale) * multiplier

        if self.is_conv:
            weight = weight.unsqueeze(2).unsqueeze(3)

        return weight

    def default_forward(self, x):
        # inference-time forward: no dropout, no rank-drop scaling
        diff_weight = self.get_diff_weight()
        if self.is_conv:
            diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)
            return self.org_forward(x) + F.conv2d(x, diff_weight) * self.multiplier
        else:
            return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier

    def forward(self, x):
        # `enabled` allows temporarily bypassing the LoKr contribution
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)
|
||||
|
||||
|
||||
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae,
    text_encoder,
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create a LoKr network. Called by train_network.py via network_module.create_network()."""
    # fall back to rank 4 / alpha 1.0 when the caller does not specify them
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    # handle text_encoder as list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # train LLM adapter
    # network_args values arrive as strings; only (case-insensitive) "true" enables it
    train_llm_adapter = kwargs.get("train_llm_adapter", "false")
    if train_llm_adapter is not None:
        train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False

    # exclude patterns
    # the value is a Python-literal string (e.g. "['pat1', 'pat2']"); parse it safely
    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)
        if not isinstance(exclude_patterns, list):
            exclude_patterns = [exclude_patterns]

    # add default exclude patterns from arch config
    exclude_patterns.extend(arch_config.default_excludes)

    # include patterns
    include_patterns = kwargs.get("include_patterns", None)
    if include_patterns is not None:
        include_patterns = ast.literal_eval(include_patterns)
        if not isinstance(include_patterns, list):
            include_patterns = [include_patterns]

    # rank/module dropout (string kwargs converted to float)
    rank_dropout = kwargs.get("rank_dropout", None)
    if rank_dropout is not None:
        rank_dropout = float(rank_dropout)
    module_dropout = kwargs.get("module_dropout", None)
    if module_dropout is not None:
        module_dropout = float(module_dropout)

    # conv dim/alpha (for future Conv2d 3x3 support)
    conv_lora_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_lora_dim is not None:
        conv_lora_dim = int(conv_lora_dim)
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    # factor for LoKr (controls the dimension factorization; -1 = most balanced)
    factor = int(kwargs.get("factor", -1))

    # verbose
    verbose = kwargs.get("verbose", "false")
    if verbose is not None:
        verbose = True if str(verbose).lower() == "true" else False

    # regex-specific learning rates / dimensions ("pattern=value" pairs)
    network_reg_lrs = kwargs.get("network_reg_lrs", None)
    reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None

    network_reg_dims = kwargs.get("network_reg_dims", None)
    reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        module_class=LoKrModule,
        module_kwargs={"factor": factor},
        conv_lora_dim=conv_lora_dim,
        conv_alpha=conv_alpha,
        train_llm_adapter=train_llm_adapter,
        exclude_patterns=exclude_patterns,
        include_patterns=include_patterns,
        reg_dims=reg_dims,
        reg_lrs=reg_lrs,
        verbose=verbose,
    )

    # LoRA+ support: ratios arrive as strings or None; forwarded to the network
    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """Create a LoKr network from saved weights. Called by train_network.py."""
    # load the state dict from disk unless the caller already provides it
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # detect dim/alpha from weights
    modules_dim = {}
    modules_alpha = {}
    train_llm_adapter = False
    for key, value in weights_sd.items():
        if "." not in key:
            continue  # not a per-module weight key

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lokr_w2_a" in key:
            # low-rank mode: dim = w2_a.shape[1]
            dim = value.shape[1]
            modules_dim[lora_name] = dim
        elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
            # full matrix mode: set dim large enough to trigger full-matrix path
            # (LoKrModule uses a full w2 when lora_dim >= max(out_k, in_n) / 2)
            if lora_name not in modules_dim:
                modules_dim[lora_name] = max(value.shape)

        # any module named after the LLM adapter implies the adapter was trained
        if "llm_adapter" in lora_name:
            train_llm_adapter = True

    # handle text_encoder as list
    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]

    # detect architecture
    arch_config = detect_arch_config(unet, text_encoders)

    # extract factor for LoKr
    factor = int(kwargs.get("factor", -1))

    # inference modules additionally support merge_to/get_weight
    module_class = LoKrInfModule if for_inference else LoKrModule
    module_kwargs = {"factor": factor}

    network = AdditionalNetwork(
        text_encoders,
        unet,
        arch_config=arch_config,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        module_kwargs=module_kwargs,
        train_llm_adapter=train_llm_adapter,
    )
    return network, weights_sd
|
||||
|
||||
|
||||
def merge_weights_to_tensor(
    model_weight: torch.Tensor,
    lora_name: str,
    lora_sd: Dict[str, torch.Tensor],
    lora_weight_keys: set,
    multiplier: float,
    calc_device: torch.device,
) -> torch.Tensor:
    """Merge LoKr weights directly into a model weight tensor.

    No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
    Returns model_weight unchanged if no matching LoKr keys found.
    """
    w1_key = lora_name + ".lokr_w1"
    w2_key = lora_name + ".lokr_w2"
    w2a_key = lora_name + ".lokr_w2_a"
    w2b_key = lora_name + ".lokr_w2_b"
    alpha_key = lora_name + ".alpha"

    if w1_key not in lora_weight_keys:
        return model_weight

    w1 = lora_sd[w1_key].to(calc_device)

    # determine low-rank vs full matrix mode
    if w2a_key in lora_weight_keys:
        # low-rank: w2 = w2_a @ w2_b; rank is the inner dimension
        w2a = lora_sd[w2a_key].to(calc_device)
        w2b = lora_sd[w2b_key].to(calc_device)
        dim = w2a.shape[1]
        consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key]
    elif w2_key in lora_weight_keys:
        # full matrix mode (w2 loaded later, after the fp8 decision)
        w2a = None
        w2b = None
        dim = None
        consumed_keys = [w1_key, w2_key, alpha_key]
    else:
        return model_weight

    alpha = lora_sd.get(alpha_key, None)
    if alpha is not None and isinstance(alpha, torch.Tensor):
        alpha = alpha.item()

    # compute scale
    if w2a is not None:
        # low-rank mode: alpha defaults to the rank (scale 1.0)
        if alpha is None:
            alpha = dim
        scale = alpha / dim
    else:
        # full matrix mode: scale = 1.0 (matches LoKrModule, which forces alpha = lora_dim)
        scale = 1.0

    original_dtype = model_weight.dtype
    if original_dtype.itemsize == 1:  # fp8: 1-byte dtypes, computed in fp16 and cast back
        model_weight = model_weight.to(torch.float16)
        w1 = w1.to(torch.float16)
        if w2a is not None:
            w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16)

    # compute w2
    if w2a is not None:
        w2 = w2a @ w2b
    else:
        w2 = lora_sd[w2_key].to(calc_device)
        if original_dtype.itemsize == 1:
            w2 = w2.to(torch.float16)

    # ΔW = kron(w1, w2) * scale
    diff_weight = make_kron(w1, w2, scale)

    # handle Conv2d 1x1 weights (4D tensors)
    if len(model_weight.shape) == 4:
        diff_weight = diff_weight.unsqueeze(2).unsqueeze(3)

    model_weight = model_weight + multiplier * diff_weight

    if original_dtype.itemsize == 1:
        model_weight = model_weight.to(original_dtype)

    # remove consumed keys (alpha may be absent; discard is a no-op then)
    for key in consumed_keys:
        lora_weight_keys.discard(key)

    return model_weight
|
||||
@@ -1,11 +1,11 @@
|
||||
# LoRA network module for Anima
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
import torch
|
||||
from library.utils import setup_logging
|
||||
from networks.lora_flux import LoRAModule, LoRAInfModule
|
||||
|
||||
import logging
|
||||
|
||||
@@ -13,6 +13,210 @@ setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
    """Adds a low-rank residual to a Linear/Conv2d by replacing the original
    module's forward method, instead of replacing the module itself.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        rank_dropout=None,
        module_dropout=None,
    ):
        """alpha == 0 or None means alpha = rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        is_conv = org_module.__class__.__name__ == "Conv2d"
        if is_conv:
            in_dim, out_dim = org_module.in_channels, org_module.out_channels
            # down projection mirrors the original kernel/stride/padding; up is 1x1
            self.lora_down = torch.nn.Conv2d(
                in_dim, self.lora_dim, org_module.kernel_size, org_module.stride, org_module.padding, bias=False
            )
            self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
        else:
            in_dim, out_dim = org_module.in_features, org_module.out_features
            self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
            self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)

        # up starts at zero so the residual is exactly zero at init
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # treated as a constant

        # same as microsoft's
        self.multiplier = multiplier
        self.org_module = org_module  # removed when applied
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        """Hijack the original module's forward and drop the strong reference."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward

        del self.org_module

    def forward(self, x):
        base_out = self.org_forward(x)

        # occasionally skip the whole LoRA branch during training
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return base_out

        lx = self.lora_down(x)

        # element-wise dropout on the low-rank activations
        if self.dropout is not None and self.training:
            lx = torch.nn.functional.dropout(lx, p=self.dropout)

        scale = self.scale
        if self.rank_dropout is not None and self.training:
            keep = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
            if lx.dim() == 3:
                keep = keep.unsqueeze(1)  # for Text Encoder
            elif lx.dim() == 4:
                keep = keep.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
            lx = lx * keep
            # rescale as if the effective rank were reduced (augmentation-like effect)
            scale = self.scale * (1.0 / (1.0 - self.rank_dropout))

        return base_out + self.lora_up(lx) * self.multiplier * scale

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
    """Inference-time LoRA module.

    Differs from the training module: dropout is disabled, the module can
    be toggled via ``enabled``, and its weights can be merged into the
    wrapped module (``merge_to``) or returned as a delta (``get_weight``).
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)

        self.org_module_ref = [org_module]  # keep a reference so it can be accessed later
        self.enabled = True
        self.network: LoRANetwork = None

    def set_network(self, network):
        self.network = network

    # freeze this module and merge its weights into the original module
    def merge_to(self, sd, dtype, device):
        """Merge ``sd``'s up/down weights into the wrapped module's weight.

        Calculation is done in float32; the result is cast to ``dtype`` (or
        the original dtype when None) and placed on ``device`` (or the
        original device when None).
        """
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)  # calc in float

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        # get up/down weight
        down_weight = sd["lora_down.weight"].to(torch.float).to(device)
        up_weight = sd["lora_up.weight"].to(torch.float).to(device)

        # merge weight
        if len(weight.size()) == 2:
            # linear
            weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                weight
                + self.multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # conv2d 3x3: compose down/up kernels via a conv over the kernel itself
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            # logger.info(conved.size(), weight.size(), module.stride, module.padding)
            weight = weight + self.multiplier * conved * self.scale

        # set weight to org_module
        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    # return this module's weight delta so a merge can be undone later
    def get_weight(self, multiplier=None):
        """Return ``multiplier * (up @ down) * scale`` as a full-size weight delta.

        Bug fix: the explicit ``multiplier`` argument was previously ignored
        (``self.multiplier`` was always used); it is now honored. Behavior is
        unchanged when the argument is omitted.
        """
        if multiplier is None:
            multiplier = self.multiplier

        # get up/down weight from module
        up_weight = self.lora_up.weight.to(torch.float)
        down_weight = self.lora_down.weight.to(torch.float)

        # pre-calculated weight
        if len(down_weight.size()) == 2:
            # linear
            weight = multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            weight = multiplier * conved * self.scale

        return weight

    def default_forward(self, x):
        # logger.info(f"default_forward {self.lora_name} {x.size()}")
        lx = self.lora_down(x)
        lx = self.lora_up(lx)
        return self.org_forward(x) + lx * self.multiplier * self.scale

    def forward(self, x):
        # when disabled, behave exactly like the original module
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
network_dim: Optional[int],
|
||||
|
||||
540
networks/network_base.py
Normal file
540
networks/network_base.py
Normal file
@@ -0,0 +1,540 @@
|
||||
# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc).
|
||||
# Provides architecture detection and a generic AdditionalNetwork class.
|
||||
|
||||
import ast
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ArchConfig:
    """Architecture-specific wiring for an additional network.

    Describes which module classes get wrapped and which state-dict key
    prefixes are used for the UNet/DiT and the text encoder(s).
    """

    # class names inside the UNet/DiT whose Linear/Conv2d children get wrapped
    unet_target_modules: List[str]
    # class names inside the text encoder(s) whose children get wrapped
    te_target_modules: List[str]
    # state-dict key prefix for UNet/DiT modules (e.g. "lora_unet")
    unet_prefix: str
    # state-dict key prefixes, one per text encoder (e.g. ["lora_te1", "lora_te2"])
    te_prefixes: List[str]
    # regex patterns excluded by default (an include pattern can override)
    default_excludes: List[str] = field(default_factory=list)
    # extra target classes wrapped only when training the LLM adapter
    adapter_target_modules: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def detect_arch_config(unet, text_encoders) -> ArchConfig:
    """Inspect the model structure and return the matching ArchConfig.

    Recognizes SDXL (by the UNet class) and Anima (by the presence of a
    ``Block`` submodule class); raises ValueError for anything else.
    """
    from library.sdxl_original_unet import SdxlUNet2DConditionModel

    # SDXL: the UNet class itself is decisive.
    if unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel):
        sdxl_config = ArchConfig(
            unet_target_modules=["Transformer2DModel"],
            te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"],
            unet_prefix="lora_unet",
            te_prefixes=["lora_te1", "lora_te2"],
            default_excludes=[],
        )
        return sdxl_config

    # Anima: detected via the set of submodule class names.
    module_class_names = set()
    if unet is not None:
        module_class_names = {type(submodule).__name__ for submodule in unet.modules()}

    if "Block" in module_class_names:
        anima_config = ArchConfig(
            unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"],
            te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"],
            unet_prefix="lora_unet",
            te_prefixes=["lora_te"],
            default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"],
            adapter_target_modules=["LLMAdapterTransformerBlock"],
        )
        return anima_config

    raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}")
|
||||
|
||||
|
||||
def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]:
|
||||
"""Parse a string of key-value pairs separated by commas."""
|
||||
pairs = {}
|
||||
for pair in kv_pair_str.split(","):
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
if "=" not in pair:
|
||||
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
|
||||
continue
|
||||
key, value = pair.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
try:
|
||||
pairs[key] = int(value) if is_int else float(value)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid value for {key}: {value}")
|
||||
return pairs
|
||||
|
||||
|
||||
class AdditionalNetwork(torch.nn.Module):
    """Generic Additional network that supports LoHa, LoKr, and similar module types.

    Constructed with a module_class parameter to inject the specific module type.
    Based on the lora_anima.py LoRANetwork, generalized for multiple architectures.
    """

    def __init__(
        self,
        text_encoders: list,
        unet,
        arch_config: ArchConfig,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        module_class: Type[torch.nn.Module] = None,
        module_kwargs: Optional[Dict] = None,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        exclude_patterns: Optional[List[str]] = None,
        include_patterns: Optional[List[str]] = None,
        reg_dims: Optional[Dict[str, int]] = None,
        reg_lrs: Optional[Dict[str, float]] = None,
        train_llm_adapter: bool = False,
        verbose: bool = False,
    ) -> None:
        """Create network modules for the text encoder(s) and UNet/DiT.

        Args:
            text_encoders: list of text encoder models (entries may be None).
            unet: UNet/DiT model whose target submodules get wrapped.
            arch_config: architecture wiring (target classes and key prefixes).
            multiplier: output scale passed to every created module.
            lora_dim / alpha: default rank and alpha.
            dropout / rank_dropout / module_dropout: forwarded to each module.
            module_class: concrete module type (LoHa, LoKr, ...); required.
            module_kwargs: extra keyword arguments for module_class.
            modules_dim / modules_alpha: per-module dim/alpha when creating
                from saved weights (modules not listed are skipped).
            conv_lora_dim / conv_alpha: rank/alpha for Conv2d other than 1x1
                (such Conv2d modules are skipped when conv_lora_dim is None).
            exclude_patterns / include_patterns: regex filters on original
                module names; an include match overrides an exclude match.
            reg_dims: regex -> dim overrides applied before the default dim.
            reg_lrs: regex -> learning-rate overrides used at optimizer setup.
            train_llm_adapter: also wrap arch_config.adapter_target_modules.
            verbose: log per-module details.
        """
        super().__init__()
        assert module_class is not None, "module_class must be specified"

        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.train_llm_adapter = train_llm_adapter
        self.reg_dims = reg_dims
        self.reg_lrs = reg_lrs
        self.arch_config = arch_config

        # LoRA+ ratios are set later via set_loraplus_lr_ratio
        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if module_kwargs is None:
            module_kwargs = {}

        if modules_dim is not None:
            logger.info(f"create {module_class.__name__} network from weights")
        else:
            logger.info(f"create {module_class.__name__} network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
            )

        # compile regular expressions
        def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
            # invalid patterns are logged and skipped rather than raising
            re_patterns = []
            if patterns is not None:
                for pattern in patterns:
                    try:
                        re_pattern = re.compile(pattern)
                    except re.error as e:
                        logger.error(f"Invalid pattern '{pattern}': {e}")
                        continue
                    re_patterns.append(re_pattern)
            return re_patterns

        exclude_re_patterns = str_to_re_patterns(exclude_patterns)
        include_re_patterns = str_to_re_patterns(include_patterns)

        # create module instances
        def create_modules(
            prefix: str,
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            default_dim: Optional[int] = None,
        ) -> Tuple[List[torch.nn.Module], List[str]]:
            # Returns (created modules, lora_names skipped because dim was 0/None).
            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
                    if target_replace_modules is None:
                        # no target classes given: treat the whole root as the target
                        module = root_module

                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            original_name = (name + "." if name else "") + child_name
                            lora_name = f"{prefix}.{original_name}".replace(".", "_")

                            # exclude/include filter: an include match overrides exclude
                            excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
                            included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
                            if excluded and not included:
                                if verbose:
                                    logger.info(f"exclude: {original_name}")
                                continue

                            dim = None
                            alpha_val = None

                            if modules_dim is not None:
                                # creating from weights: use only the stored per-module dim/alpha
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha_val = modules_alpha[lora_name]
                            else:
                                # regex-based dim overrides take precedence over defaults
                                if self.reg_dims is not None:
                                    for reg, d in self.reg_dims.items():
                                        if re.fullmatch(reg, original_name):
                                            dim = d
                                            alpha_val = self.alpha
                                            logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
                                            break
                                # fallback to default dim
                                if dim is None:
                                    if is_linear or is_conv2d_1x1:
                                        dim = default_dim if default_dim is not None else self.lora_dim
                                        alpha_val = self.alpha
                                    elif is_conv2d and self.conv_lora_dim is not None:
                                        dim = self.conv_lora_dim
                                        alpha_val = self.conv_alpha

                            if dim is None or dim == 0:
                                # only Linear/Conv2d-1x1 skips are reported; other Conv2d
                                # modules are skipped silently (no conv rank configured)
                                if is_linear or is_conv2d_1x1:
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha_val,
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
                                **module_kwargs,
                            )
                            # remember the dotted module path for regex LR matching later
                            lora.original_name = original_name
                            loras.append(lora)

                if target_replace_modules is None:
                    break  # the root module itself was processed as a whole
            return loras, skipped

        # Create modules for text encoders
        self.text_encoder_loras: List[torch.nn.Module] = []
        skipped_te = []
        if text_encoders is not None:
            for i, text_encoder in enumerate(text_encoders):
                if text_encoder is None:
                    continue

                # Determine prefix for this text encoder
                if i < len(arch_config.te_prefixes):
                    te_prefix = arch_config.te_prefixes[i]
                else:
                    # more encoders than prefixes: fall back to the first prefix
                    te_prefix = arch_config.te_prefixes[0]

                logger.info(f"create {module_class.__name__} for Text Encoder {i+1} (prefix={te_prefix}):")
                te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules)
                logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.")
                self.text_encoder_loras.extend(te_loras)
                skipped_te += te_skipped

        # Create modules for UNet/DiT
        target_modules = list(arch_config.unet_target_modules)
        if train_llm_adapter and arch_config.adapter_target_modules:
            target_modules.extend(arch_config.adapter_target_modules)

        self.unet_loras: List[torch.nn.Module]
        self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules)
        logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.")

        if verbose:
            for lora in self.unet_loras:
                logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")

        skipped = skipped_te + skipped_un
        if verbose and len(skipped) > 0:
            logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:")
            for name in skipped:
                logger.info(f"\t{name}")

        # assertion: no duplicate names
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        """Set the output multiplier on the network and on every module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def set_enabled(self, is_enabled):
        """Enable or disable every module (inference modules honor ``enabled``)."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enabled = is_enabled

    def load_weights(self, file):
        """Load module weights from a .safetensors or torch checkpoint file."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        # strict=False so missing/unexpected keys are reported in the result
        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
        """Hook all modules into their wrapped layers and register them as submodules.

        Disabling a part drops its module list entirely.
        """
        if apply_text_encoder:
            logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    def is_mergeable(self):
        # weights can be merged into the base models via merge_to
        return True

    def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
        """Merge the given state dict's weights directly into the base models."""
        apply_text_encoder = apply_unet = False
        te_prefixes = self.arch_config.te_prefixes
        unet_prefix = self.arch_config.unet_prefix

        # decide which parts the state dict covers from its key prefixes
        for key in weights_sd.keys():
            if any(key.startswith(p) for p in te_prefixes):
                apply_text_encoder = True
            elif key.startswith(unet_prefix):
                apply_unet = True

        if apply_text_encoder:
            logger.info("enable modules for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info("enable modules for UNet/DiT")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            # collect this module's keys, stripped of the "<lora_name>." prefix
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        logger.info("weights are merged")

    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        """Store LoRA+ LR ratios; per-part ratios take precedence over the generic one."""
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")

    def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
        """Build optimizer param groups honoring per-TE LRs, LoRA+ and regex LR overrides.

        Returns (all_params, lr_descriptions): parallel lists of param-group
        dicts and human-readable group labels.
        """
        # normalize text_encoder_lr into a non-empty list
        if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
            text_encoder_lr = [default_lr]
        elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
            text_encoder_lr = [float(text_encoder_lr)]
        elif len(text_encoder_lr) == 1:
            pass  # already a list with one element

        self.requires_grad_(True)

        all_params = []
        lr_descriptions = []

        def assemble_params(loras, lr, loraplus_ratio):
            # Split params into "lora"/"plus" (LoRA+) groups, plus one group per
            # matched reg_lrs regex (those override the part-level lr).
            param_groups = {"lora": {}, "plus": {}}
            reg_groups = {}
            reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []

            for lora in loras:
                matched_reg_lr = None
                for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
                    if re.fullmatch(regex_str, lora.original_name):
                        matched_reg_lr = (i, reg_lr)
                        logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
                        break

                for name, param in lora.named_parameters():
                    if matched_reg_lr is not None:
                        reg_idx, reg_lr = matched_reg_lr
                        group_key = f"reg_lr_{reg_idx}"
                        if group_key not in reg_groups:
                            reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
                        # LoRA+ detection: check for "up" weight parameters
                        if loraplus_ratio is not None and self._is_plus_param(name):
                            reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
                        else:
                            reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
                        continue

                    if loraplus_ratio is not None and self._is_plus_param(name):
                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                    else:
                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param

            params = []
            descriptions = []
            # regex-LR groups first
            for group_key, group in reg_groups.items():
                reg_lr = group["lr"]
                for key in ("lora", "plus"):
                    param_data = {"params": group[key].values()}
                    if len(param_data["params"]) == 0:
                        continue
                    if key == "plus":
                        param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
                    else:
                        param_data["lr"] = reg_lr
                    # groups with no effective LR are dropped entirely
                    if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                        logger.info("NO LR skipping!")
                        continue
                    params.append(param_data)
                    desc = f"reg_lr_{group_key.split('_')[-1]}"
                    descriptions.append(desc + (" plus" if key == "plus" else ""))

            # then the regular lora/plus groups at the part-level LR
            for key in param_groups.keys():
                param_data = {"params": param_groups[key].values()}
                if len(param_data["params"]) == 0:
                    continue
                if lr is not None:
                    if key == "plus":
                        param_data["lr"] = lr * loraplus_ratio
                    else:
                        param_data["lr"] = lr
                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                    logger.info("NO LR skipping!")
                    continue
                params.append(param_data)
                descriptions.append("plus" if key == "plus" else "")
            return params, descriptions

        if self.text_encoder_loras:
            loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
            # Group TE loras by prefix
            for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes):
                te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)]
                if len(te_loras) > 0:
                    # fall back to the first LR when fewer LRs than encoders were given
                    te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0]
                    logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}")
                    params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio)
                    all_params.extend(params)
                    lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions])

        if self.unet_loras:
            params, descriptions = assemble_params(
                self.unet_loras,
                unet_lr if unet_lr is not None else default_lr,
                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

        return all_params, lr_descriptions

    def _is_plus_param(self, name: str) -> bool:
        """Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+.

        For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor).
        Override in subclass if needed. Default: check for common 'up' patterns.
        """
        return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name

    def enable_gradient_checkpointing(self):
        pass  # not supported

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the network state dict, optionally cast to ``dtype``, with hash metadata."""
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # embed model hashes into the safetensors metadata
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        """Snapshot each wrapped module's weight so a later merge can be undone."""
        loras = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                # flag meaning "current weight matches the backup"
                org_module._lora_restored = True

    def restore_weights(self):
        """Restore each wrapped module's weight from the backup_weights snapshot."""
        loras = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        """Merge each module's delta into its wrapped weight and disable the module."""
        loras = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            sd = org_module.state_dict()

            org_weight = sd["weight"]
            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            sd["weight"] = org_weight + lora_weight
            assert sd["weight"].shape == org_weight.shape
            org_module.load_state_dict(sd)

            # merged weight no longer matches the backup; forward path disabled
            org_module._lora_restored = False
            lora.enabled = False
|
||||
Reference in New Issue
Block a user