mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-15 16:39:42 +00:00)

feat: simplify target module selection by regular expression patterns
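The new selection rule: exclude_patterns and include_patterns are read from the network kwargs, parsed with ast.literal_eval, compiled with re.compile, and matched against each candidate module name; a module is dropped when any exclude pattern matches it and no include pattern does, and a default pattern always excludes modulation/norm/embedder/final-layer modules. Below is a minimal, self-contained sketch of that rule, assuming the usual sd-scripts --network_args string form and hypothetical module names such as "blocks.12.cross_attn.q_proj" (the real names depend on the Anima DiT implementation):

import ast
import re

# pattern strings as they would arrive from e.g. --network_args "exclude_patterns=['.*cross_attn.*']"
exclude_patterns = ast.literal_eval("['.*cross_attn.*']")
include_patterns = ast.literal_eval("['blocks.12.*']")

# default exclusion appended by create_network (see the diff below)
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")

exclude_re = [re.compile(p) for p in exclude_patterns]
include_re = [re.compile(p) for p in include_patterns]

def is_trained(original_name: str) -> bool:
    excluded = any(p.match(original_name) for p in exclude_re)
    included = any(p.match(original_name) for p in include_re)
    return not (excluded and not included)  # an include match overrides any exclude match

print(is_trained("blocks.0.cross_attn.q_proj"))   # False: excluded, not re-included
print(is_trained("blocks.12.cross_attn.q_proj"))  # True: include pattern wins over the exclusion
print(is_trained("final_layer.linear"))           # False: hit by the default exclude pattern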
@@ -1,18 +1,17 @@
# LoRA network module for Anima
import math
# LoRA network module for Anima
import ast
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule

setup_logging()
import logging

setup_logging()
logger = logging.getLogger(__name__)

from networks.lora_flux import LoRAModule, LoRAInfModule


def create_network(
    multiplier: float,
@@ -29,68 +28,28 @@ def create_network(
    if network_alpha is None:
        network_alpha = 1.0

    # type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
    self_attn_dim = kwargs.get("self_attn_dim", None)
    cross_attn_dim = kwargs.get("cross_attn_dim", None)
    mlp_dim = kwargs.get("mlp_dim", None)
    mod_dim = kwargs.get("mod_dim", None)
    llm_adapter_dim = kwargs.get("llm_adapter_dim", None)

    if self_attn_dim is not None:
        self_attn_dim = int(self_attn_dim)
    if cross_attn_dim is not None:
        cross_attn_dim = int(cross_attn_dim)
    if mlp_dim is not None:
        mlp_dim = int(mlp_dim)
    if mod_dim is not None:
        mod_dim = int(mod_dim)
    if llm_adapter_dim is not None:
        llm_adapter_dim = int(llm_adapter_dim)

    type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
    if all([d is None for d in type_dims]):
        type_dims = None

    # emb_dims: [x_embedder, t_embedder, final_layer]
    emb_dims = kwargs.get("emb_dims", None)
    if emb_dims is not None:
        emb_dims = emb_dims.strip()
        if emb_dims.startswith("[") and emb_dims.endswith("]"):
            emb_dims = emb_dims[1:-1]
        emb_dims = [int(d) for d in emb_dims.split(",")]
        assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"

    # block selection
    def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
        if selection == "all":
            return [True] * total_blocks
        if selection == "none" or selection == "":
            return [False] * total_blocks

        selected = [False] * total_blocks
        ranges = selection.split(",")
        for r in ranges:
            if "-" in r:
                start, end = map(str.strip, r.split("-"))
                start, end = int(start), int(end)
                assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
                for i in range(start, end + 1):
                    selected[i] = True
            else:
                index = int(r)
                assert 0 <= index < total_blocks
                selected[index] = True
        return selected
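
    # Worked example: with 10 blocks,
    #   parse_block_selection("all", 10)   -> [True] * 10
    #   parse_block_selection("0-2,7", 10) -> [True, True, True, False, False, False, False, True, False, False]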

    train_block_indices = kwargs.get("train_block_indices", None)
    if train_block_indices is not None:
        num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
        train_block_indices = parse_block_selection(train_block_indices, num_blocks)

    # train LLM adapter
    train_llm_adapter = kwargs.get("train_llm_adapter", False)
    if train_llm_adapter is not None:
        train_llm_adapter = True if train_llm_adapter == "True" else False
        train_llm_adapter = True if train_llm_adapter.lower() == "true" else False

    exclude_patterns = kwargs.get("exclude_patterns", None)
    if exclude_patterns is None:
        exclude_patterns = []
    else:
        exclude_patterns = ast.literal_eval(exclude_patterns)
        if not isinstance(exclude_patterns, list):
            exclude_patterns = [exclude_patterns]

    # add default exclude patterns
    exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")

    # regular expression for module selection: exclude and include
    include_patterns = kwargs.get("include_patterns", None)
    if include_patterns is not None:
        include_patterns = ast.literal_eval(include_patterns)
        if not isinstance(include_patterns, list):
            include_patterns = [include_patterns]
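
    # Worked example: --network_args "exclude_patterns=['.*cross_attn.*']" (the usual sd-scripts way of
    # passing these kwargs) delivers the string "['.*cross_attn.*']" here; ast.literal_eval turns it into
    # the list ['.*cross_attn.*'], and a single bare string would be wrapped into a one-element list.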

    # rank/module dropout
    rank_dropout = kwargs.get("rank_dropout", None)
@@ -103,7 +62,7 @@ def create_network(
    # verbose
    verbose = kwargs.get("verbose", False)
    if verbose is not None:
        verbose = True if verbose == "True" else False
        verbose = True if verbose.lower() == "true" else False

    network = LoRANetwork(
        text_encoders,
@@ -115,9 +74,8 @@ def create_network(
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        train_llm_adapter=train_llm_adapter,
        type_dims=type_dims,
        emb_dims=emb_dims,
        train_block_indices=train_block_indices,
        exclude_patterns=exclude_patterns,
        include_patterns=include_patterns,
        verbose=verbose,
    )

@@ -137,6 +95,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")
@@ -173,8 +132,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh


class LoRANetwork(torch.nn.Module):
    # Target modules: DiT blocks
    ANIMA_TARGET_REPLACE_MODULE = ["Block"]
    # Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
    ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
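    # e.g. with the default exclude pattern added in create_network, modules under the embedders and
    # final layer (say "x_embedder.proj" or "final_layer.linear", hypothetical child names) are skipped
    # unless an include pattern such as ".*final_layer.*" is supplied to re-enable them.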
    # Target modules: LLM Adapter blocks
    ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
    # Target modules for text encoder (Qwen3)
@@ -197,9 +156,8 @@ class LoRANetwork(torch.nn.Module):
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        train_llm_adapter: bool = False,
        type_dims: Optional[List[int]] = None,
        emb_dims: Optional[List[int]] = None,
        train_block_indices: Optional[List[bool]] = None,
        exclude_patterns: Optional[List[str]] = None,
        include_patterns: Optional[List[str]] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        super().__init__()
@@ -210,21 +168,36 @@ class LoRANetwork(torch.nn.Module):
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.train_llm_adapter = train_llm_adapter
        self.type_dims = type_dims
        self.emb_dims = emb_dims
        self.train_block_indices = train_block_indices

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info(f"create LoRA network from weights")
            logger.info("create LoRA network from weights")
            if self.emb_dims is None:
                self.emb_dims = [0] * 3
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
            logger.info(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
            )

        # compile regular expression if specified
        def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
            re_patterns = []
            if patterns is not None:
                for pattern in patterns:
                    try:
                        re_pattern = re.compile(pattern)
                    except re.error as e:
                        logger.error(f"Invalid pattern '{pattern}': {e}")
                        continue
                    re_patterns.append(re_pattern)
            return re_patterns

        exclude_re_patterns = str_to_re_patterns(exclude_patterns)
        include_re_patterns = str_to_re_patterns(include_patterns)

        # create module instances
        def create_modules(
@@ -232,15 +205,9 @@ class LoRANetwork(torch.nn.Module):
            text_encoder_idx: Optional[int],
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            filter: Optional[str] = None,
            default_dim: Optional[int] = None,
            include_conv2d_if_filter: bool = False,
        ) -> Tuple[List[LoRAModule], List[str]]:
            prefix = (
                self.LORA_PREFIX_ANIMA
                if is_unet
                else self.LORA_PREFIX_TEXT_ENCODER
            )
            prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER

            loras = []
            skipped = []
@@ -255,14 +222,16 @@ class LoRANetwork(torch.nn.Module):
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + (name + "." if name else "") + child_name
                            lora_name = lora_name.replace(".", "_")
                            original_name = (name + "." if name else "") + child_name
                            lora_name = f"{prefix}.{original_name}".replace(".", "_")

                            force_incl_conv2d = False
                            if filter is not None:
                                if filter not in lora_name:
                                    continue
                                force_incl_conv2d = include_conv2d_if_filter
                            # exclude/include filter
                            excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns)
                            included = any(pattern.match(original_name) for pattern in include_re_patterns)
                            if excluded and not included:
                                if verbose:
                                    logger.info(f"exclude: {original_name}")
                                continue
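                            # e.g. an original_name of "blocks.0.cross_attn.k_proj" (hypothetical) is skipped when
                            # an exclude pattern such as ".*cross_attn.*" matches it and no include pattern does;
                            # an include pattern like "blocks.0.*" would override the exclusion and keep the module.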

                            dim = None
                            alpha_val = None
@@ -276,40 +245,6 @@ class LoRANetwork(torch.nn.Module):
                                    dim = default_dim if default_dim is not None else self.lora_dim
                                    alpha_val = self.alpha

                                    if is_unet and type_dims is not None:
                                        # type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
                                        # Order matters: check most specific identifiers first to avoid mismatches.
                                        identifier_order = [
                                            (4, ("llm_adapter",)),
                                            (3, ("adaln_modulation",)),
                                            (0, ("self_attn",)),
                                            (1, ("cross_attn",)),
                                            (2, ("mlp",)),
                                        ]
                                        for idx, ids in identifier_order:
                                            d = type_dims[idx]
                                            if d is not None and all(id_str in lora_name for id_str in ids):
                                                dim = d  # 0 means skip
                                                break

                                    # block index filtering
                                    if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
                                        # Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
                                        parts = lora_name.split("_")
                                        for pi, part in enumerate(parts):
                                            if part == "blocks" and pi + 1 < len(parts):
                                                try:
                                                    block_index = int(parts[pi + 1])
                                                    if not self.train_block_indices[block_index]:
                                                        dim = 0
                                                except (ValueError, IndexError):
                                                    pass
                                                break

                                elif force_incl_conv2d:
                                    dim = default_dim if default_dim is not None else self.lora_dim
                                    alpha_val = self.alpha

                            if dim is None or dim == 0:
                                if is_linear or is_conv2d_1x1:
                                    skipped.append(lora_name)
@@ -339,9 +274,7 @@ class LoRANetwork(torch.nn.Module):
            if text_encoder is None:
                continue
            logger.info(f"create LoRA for Text Encoder {i+1}:")
            te_loras, te_skipped = create_modules(
                False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
            )
            te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
            self.text_encoder_loras.extend(te_loras)
            skipped_te += te_skipped
@@ -354,19 +287,6 @@ class LoRANetwork(torch.nn.Module):
        self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)

        # emb_dims: [x_embedder, t_embedder, final_layer]
        if self.emb_dims:
            for filter_name, in_dim in zip(
                ["x_embedder", "t_embedder", "final_layer"],
                self.emb_dims,
            ):
                loras, _ = create_modules(
                    True, None, unet, None,
                    filter=filter_name, default_dim=in_dim,
                    include_conv2d_if_filter=(filter_name == "x_embedder"),
                )
                self.unet_loras.extend(loras)

        logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
        if verbose:
            for lora in self.unet_loras:
@@ -396,6 +316,7 @@ class LoRANetwork(torch.nn.Module):
    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")
@@ -443,10 +364,10 @@ class LoRANetwork(torch.nn.Module):
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key]
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        logger.info(f"weights are merged")
        logger.info("weights are merged")

    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        self.loraplus_lr_ratio = loraplus_lr_ratio
@@ -498,10 +419,7 @@ class LoRANetwork(torch.nn.Module):

        if self.text_encoder_loras:
            loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
            te1_loras = [
                lora for lora in self.text_encoder_loras
                if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
            ]
            te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
            if len(te1_loras) > 0:
                logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
                params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)