feat: simplify target module selection by regular expression patterns

Author: kohya-ss
Date: 2026-02-08 13:45:52 +09:00
parent d992037984
commit c3556d455f

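This commit replaces the per-type dimension knobs (type_dims, emb_dims, train_block_indices) with two regular-expression lists, exclude_patterns and include_patterns. Both are parsed from string kwargs with ast.literal_eval (in the sd-scripts trainers these typically arrive via --network_args as Python-literal strings), a default exclude pattern covering modulation/norm/embedder/final-layer modules is always appended, and a module is skipped only when it matches an exclude pattern and no include pattern. The sketch below is an illustrative reimplementation of that selection rule, not code from this commit; the module names are made up.

import re
from typing import List

# Default exclude pattern appended by create_network in this commit.
DEFAULT_EXCLUDE = r".*(_modulation|_norm|_embedder|final_layer).*"

def is_selected(original_name: str, exclude_patterns: List[str], include_patterns: List[str]) -> bool:
    # Compile the user-supplied patterns; the real code skips invalid ones with a logged error.
    exclude_res = [re.compile(p) for p in exclude_patterns]
    include_res = [re.compile(p) for p in include_patterns]
    excluded = any(p.match(original_name) for p in exclude_res)
    included = any(p.match(original_name) for p in include_res)
    # A module is dropped only if it matches an exclude pattern and no include pattern.
    return not (excluded and not included)

# Illustrative module names (not taken from the Anima model definition):
print(is_selected("blocks.0.self_attn.q_proj", [DEFAULT_EXCLUDE], []))           # True: trained
print(is_selected("final_layer.linear", [DEFAULT_EXCLUDE], []))                  # False: excluded by default
print(is_selected("final_layer.linear", [DEFAULT_EXCLUDE], [r"final_layer.*"]))  # True: re-included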

@@ -1,18 +1,17 @@
# LoRA network module for Anima
import math
import ast
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
setup_logging()
import logging
setup_logging()
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
def create_network(
multiplier: float,
@@ -29,68 +28,28 @@ def create_network(
if network_alpha is None:
network_alpha = 1.0
# type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
self_attn_dim = kwargs.get("self_attn_dim", None)
cross_attn_dim = kwargs.get("cross_attn_dim", None)
mlp_dim = kwargs.get("mlp_dim", None)
mod_dim = kwargs.get("mod_dim", None)
llm_adapter_dim = kwargs.get("llm_adapter_dim", None)
if self_attn_dim is not None:
self_attn_dim = int(self_attn_dim)
if cross_attn_dim is not None:
cross_attn_dim = int(cross_attn_dim)
if mlp_dim is not None:
mlp_dim = int(mlp_dim)
if mod_dim is not None:
mod_dim = int(mod_dim)
if llm_adapter_dim is not None:
llm_adapter_dim = int(llm_adapter_dim)
type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims: [x_embedder, t_embedder, final_layer]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")]
assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"
# block selection
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start, end = int(start), int(end)
assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
train_block_indices = parse_block_selection(train_block_indices, num_blocks)
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", False)
if train_llm_adapter is not None:
train_llm_adapter = True if train_llm_adapter == "True" else False
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
# regular expression for module selection: exclude and include
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
@@ -103,7 +62,7 @@ def create_network(
# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False
verbose = True if verbose.lower() == "true" else False
network = LoRANetwork(
text_encoders,
@@ -115,9 +74,8 @@ def create_network(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
verbose=verbose,
)
@@ -137,6 +95,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -173,8 +132,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
class LoRANetwork(torch.nn.Module):
# Target modules: DiT blocks
ANIMA_TARGET_REPLACE_MODULE = ["Block"]
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3)
@@ -197,9 +156,8 @@ class LoRANetwork(torch.nn.Module):
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -210,21 +168,36 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
if self.emb_dims is None:
self.emb_dims = [0] * 3
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expression if specified
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
@@ -232,15 +205,9 @@ class LoRANetwork(torch.nn.Module):
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> Tuple[List[LoRAModule], List[str]]:
prefix = (
self.LORA_PREFIX_ANIMA
if is_unet
else self.LORA_PREFIX_TEXT_ENCODER
)
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
@@ -255,14 +222,16 @@ class LoRANetwork(torch.nn.Module):
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
force_incl_conv2d = False
if filter is not None:
if filter not in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter
# exclude/include filter
excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns)
included = any(pattern.match(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
@@ -276,40 +245,6 @@ class LoRANetwork(torch.nn.Module):
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if is_unet and type_dims is not None:
# type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
# Order matters: check most specific identifiers first to avoid mismatches.
identifier_order = [
(4, ("llm_adapter",)),
(3, ("adaln_modulation",)),
(0, ("self_attn",)),
(1, ("cross_attn",)),
(2, ("mlp",)),
]
for idx, ids in identifier_order:
d = type_dims[idx]
if d is not None and all(id_str in lora_name for id_str in ids):
dim = d # 0 means skip
break
# block index filtering
if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
# Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
parts = lora_name.split("_")
for pi, part in enumerate(parts):
if part == "blocks" and pi + 1 < len(parts):
try:
block_index = int(parts[pi + 1])
if not self.train_block_indices[block_index]:
dim = 0
except (ValueError, IndexError):
pass
break
elif force_incl_conv2d:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
skipped.append(lora_name)
@@ -339,9 +274,7 @@ class LoRANetwork(torch.nn.Module):
if text_encoder is None:
continue
logger.info(f"create LoRA for Text Encoder {i+1}:")
te_loras, te_skipped = create_modules(
False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
@@ -354,19 +287,6 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
# emb_dims: [x_embedder, t_embedder, final_layer]
if self.emb_dims:
for filter_name, in_dim in zip(
["x_embedder", "t_embedder", "final_layer"],
self.emb_dims,
):
loras, _ = create_modules(
True, None, unet, None,
filter=filter_name, default_dim=in_dim,
include_conv2d_if_filter=(filter_name == "x_embedder"),
)
self.unet_loras.extend(loras)
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
@@ -396,6 +316,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -443,10 +364,10 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key]
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
@@ -498,10 +419,7 @@ class LoRANetwork(torch.nn.Module):
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
te1_loras = [
lora for lora in self.text_encoder_loras
if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
]
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)