From c3556d455f4dbf43e6a574552f4d21821a2396ea Mon Sep 17 00:00:00 2001 From: kohya-ss <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 8 Feb 2026 13:45:52 +0900 Subject: [PATCH] feat: simplify target module selection by regular expression patterns --- networks/lora_anima.py | 216 +++++++++++++---------------------------- 1 file changed, 67 insertions(+), 149 deletions(-) diff --git a/networks/lora_anima.py b/networks/lora_anima.py index c375ead7..7f44786e 100644 --- a/networks/lora_anima.py +++ b/networks/lora_anima.py @@ -1,18 +1,17 @@ -# LoRA network module for Anima -import math +# LoRA network module for Anima +import ast import os +import re from typing import Dict, List, Optional, Tuple, Type, Union -import numpy as np import torch from library.utils import setup_logging +from networks.lora_flux import LoRAModule, LoRAInfModule -setup_logging() import logging +setup_logging() logger = logging.getLogger(__name__) -from networks.lora_flux import LoRAModule, LoRAInfModule - def create_network( multiplier: float, @@ -29,68 +28,28 @@ def create_network( if network_alpha is None: network_alpha = 1.0 - # type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim] - self_attn_dim = kwargs.get("self_attn_dim", None) - cross_attn_dim = kwargs.get("cross_attn_dim", None) - mlp_dim = kwargs.get("mlp_dim", None) - mod_dim = kwargs.get("mod_dim", None) - llm_adapter_dim = kwargs.get("llm_adapter_dim", None) - - if self_attn_dim is not None: - self_attn_dim = int(self_attn_dim) - if cross_attn_dim is not None: - cross_attn_dim = int(cross_attn_dim) - if mlp_dim is not None: - mlp_dim = int(mlp_dim) - if mod_dim is not None: - mod_dim = int(mod_dim) - if llm_adapter_dim is not None: - llm_adapter_dim = int(llm_adapter_dim) - - type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim] - if all([d is None for d in type_dims]): - type_dims = None - - # emb_dims: [x_embedder, t_embedder, final_layer] - emb_dims = kwargs.get("emb_dims", None) - if emb_dims is not None: - emb_dims = emb_dims.strip() - if emb_dims.startswith("[") and emb_dims.endswith("]"): - emb_dims = emb_dims[1:-1] - emb_dims = [int(d) for d in emb_dims.split(",")] - assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)" - - # block selection - def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: - if selection == "all": - return [True] * total_blocks - if selection == "none" or selection == "": - return [False] * total_blocks - - selected = [False] * total_blocks - ranges = selection.split(",") - for r in ranges: - if "-" in r: - start, end = map(str.strip, r.split("-")) - start, end = int(start), int(end) - assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end - for i in range(start, end + 1): - selected[i] = True - else: - index = int(r) - assert 0 <= index < total_blocks - selected[index] = True - return selected - - train_block_indices = kwargs.get("train_block_indices", None) - if train_block_indices is not None: - num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999 - train_block_indices = parse_block_selection(train_block_indices, num_blocks) - # train LLM adapter train_llm_adapter = kwargs.get("train_llm_adapter", False) if train_llm_adapter is not None: - train_llm_adapter = True if train_llm_adapter == "True" else False + train_llm_adapter = True if train_llm_adapter.lower() == "true" else False + + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + if not isinstance(exclude_patterns, list): + exclude_patterns = [exclude_patterns] + + # add default exclude patterns + exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*") + + # regular expression for module selection: exclude and include + include_patterns = kwargs.get("include_patterns", None) + if include_patterns is not None: + include_patterns = ast.literal_eval(include_patterns) + if not isinstance(include_patterns, list): + include_patterns = [include_patterns] # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) @@ -103,7 +62,7 @@ def create_network( # verbose verbose = kwargs.get("verbose", False) if verbose is not None: - verbose = True if verbose == "True" else False + verbose = True if verbose.lower() == "true" else False network = LoRANetwork( text_encoders, @@ -115,9 +74,8 @@ def create_network( rank_dropout=rank_dropout, module_dropout=module_dropout, train_llm_adapter=train_llm_adapter, - type_dims=type_dims, - emb_dims=emb_dims, - train_block_indices=train_block_indices, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, verbose=verbose, ) @@ -137,6 +95,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file + weights_sd = load_file(file) else: weights_sd = torch.load(file, map_location="cpu") @@ -173,8 +132,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh class LoRANetwork(torch.nn.Module): - # Target modules: DiT blocks - ANIMA_TARGET_REPLACE_MODULE = ["Block"] + # Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default. + ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"] # Target modules: LLM Adapter blocks ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"] # Target modules for text encoder (Qwen3) @@ -197,9 +156,8 @@ class LoRANetwork(torch.nn.Module): modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_llm_adapter: bool = False, - type_dims: Optional[List[int]] = None, - emb_dims: Optional[List[int]] = None, - train_block_indices: Optional[List[bool]] = None, + exclude_patterns: Optional[List[str]] = None, + include_patterns: Optional[List[str]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -210,21 +168,36 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.train_llm_adapter = train_llm_adapter - self.type_dims = type_dims - self.emb_dims = emb_dims - self.train_block_indices = train_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None if modules_dim is not None: - logger.info(f"create LoRA network from weights") + logger.info("create LoRA network from weights") if self.emb_dims is None: self.emb_dims = [0] * 3 else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + + # compile regular expression if specified + def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]: + re_patterns = [] + if patterns is not None: + for pattern in patterns: + try: + re_pattern = re.compile(pattern) + except re.error as e: + logger.error(f"Invalid pattern '{pattern}': {e}") + continue + re_patterns.append(re_pattern) + return re_patterns + + exclude_re_patterns = str_to_re_patterns(exclude_patterns) + include_re_patterns = str_to_re_patterns(include_patterns) # create module instances def create_modules( @@ -232,15 +205,9 @@ class LoRANetwork(torch.nn.Module): text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str], - filter: Optional[str] = None, default_dim: Optional[int] = None, - include_conv2d_if_filter: bool = False, ) -> Tuple[List[LoRAModule], List[str]]: - prefix = ( - self.LORA_PREFIX_ANIMA - if is_unet - else self.LORA_PREFIX_TEXT_ENCODER - ) + prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER loras = [] skipped = [] @@ -255,14 +222,16 @@ class LoRANetwork(torch.nn.Module): is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: - lora_name = prefix + "." + (name + "." if name else "") + child_name - lora_name = lora_name.replace(".", "_") + original_name = (name + "." if name else "") + child_name + lora_name = f"{prefix}.{original_name}".replace(".", "_") - force_incl_conv2d = False - if filter is not None: - if filter not in lora_name: - continue - force_incl_conv2d = include_conv2d_if_filter + # exclude/include filter + excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns) + included = any(pattern.match(original_name) for pattern in include_re_patterns) + if excluded and not included: + if verbose: + logger.info(f"exclude: {original_name}") + continue dim = None alpha_val = None @@ -276,40 +245,6 @@ class LoRANetwork(torch.nn.Module): dim = default_dim if default_dim is not None else self.lora_dim alpha_val = self.alpha - if is_unet and type_dims is not None: - # type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim] - # Order matters: check most specific identifiers first to avoid mismatches. - identifier_order = [ - (4, ("llm_adapter",)), - (3, ("adaln_modulation",)), - (0, ("self_attn",)), - (1, ("cross_attn",)), - (2, ("mlp",)), - ] - for idx, ids in identifier_order: - d = type_dims[idx] - if d is not None and all(id_str in lora_name for id_str in ids): - dim = d # 0 means skip - break - - # block index filtering - if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name: - # Extract block index from lora_name: "lora_unet_blocks_0_self_attn..." - parts = lora_name.split("_") - for pi, part in enumerate(parts): - if part == "blocks" and pi + 1 < len(parts): - try: - block_index = int(parts[pi + 1]) - if not self.train_block_indices[block_index]: - dim = 0 - except (ValueError, IndexError): - pass - break - - elif force_incl_conv2d: - dim = default_dim if default_dim is not None else self.lora_dim - alpha_val = self.alpha - if dim is None or dim == 0: if is_linear or is_conv2d_1x1: skipped.append(lora_name) @@ -339,9 +274,7 @@ class LoRANetwork(torch.nn.Module): if text_encoder is None: continue logger.info(f"create LoRA for Text Encoder {i+1}:") - te_loras, te_skipped = create_modules( - False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - ) + te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.") self.text_encoder_loras.extend(te_loras) skipped_te += te_skipped @@ -354,19 +287,6 @@ class LoRANetwork(torch.nn.Module): self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - # emb_dims: [x_embedder, t_embedder, final_layer] - if self.emb_dims: - for filter_name, in_dim in zip( - ["x_embedder", "t_embedder", "final_layer"], - self.emb_dims, - ): - loras, _ = create_modules( - True, None, unet, None, - filter=filter_name, default_dim=in_dim, - include_conv2d_if_filter=(filter_name == "x_embedder"), - ) - self.unet_loras.extend(loras) - logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.") if verbose: for lora in self.unet_loras: @@ -396,6 +316,7 @@ class LoRANetwork(torch.nn.Module): def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file + weights_sd = load_file(file) else: weights_sd = torch.load(file, map_location="cpu") @@ -443,10 +364,10 @@ class LoRANetwork(torch.nn.Module): sd_for_lora = {} for key in weights_sd.keys(): if key.startswith(lora.lora_name): - sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key] + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - logger.info(f"weights are merged") + logger.info("weights are merged") def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): self.loraplus_lr_ratio = loraplus_lr_ratio @@ -498,10 +419,7 @@ class LoRANetwork(torch.nn.Module): if self.text_encoder_loras: loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio - te1_loras = [ - lora for lora in self.text_encoder_loras - if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER) - ] + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)] if len(te1_loras) > 0: logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)