fix: improve regex matching for module selection and learning rates in LoRANetwork

This commit is contained in:
Kohya S
2026-02-10 08:25:20 +09:00
parent 3d24736238
commit f3b6e59900

View File

@@ -215,8 +215,6 @@ class LoRANetwork(torch.nn.Module):
if modules_dim is not None:
logger.info("create LoRA network from weights")
if self.emb_dims is None:
self.emb_dims = [0] * 3
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
@@ -265,9 +263,9 @@ class LoRANetwork(torch.nn.Module):
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
# exclude/include filter
excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns)
included = any(pattern.match(original_name) for pattern in include_re_patterns)
# exclude/include filter (fullmatch: pattern must match the entire original_name)
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
@@ -280,17 +278,19 @@ class LoRANetwork(torch.nn.Module):
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
elif self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.search(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
break
else:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.fullmatch(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
break
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
@@ -446,7 +446,7 @@ class LoRANetwork(torch.nn.Module):
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.search(regex_str, lora.original_name):
if re.fullmatch(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break