mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
fix: improve regex matching for module selection and learning rates in LoRANetwork
This commit is contained in:
@@ -215,8 +215,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info("create LoRA network from weights")
|
||||
if self.emb_dims is None:
|
||||
self.emb_dims = [0] * 3
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(
|
||||
@@ -265,9 +263,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
original_name = (name + "." if name else "") + child_name
|
||||
lora_name = f"{prefix}.{original_name}".replace(".", "_")
|
||||
|
||||
# exclude/include filter
|
||||
excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns)
|
||||
included = any(pattern.match(original_name) for pattern in include_re_patterns)
|
||||
# exclude/include filter (fullmatch: pattern must match the entire original_name)
|
||||
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
|
||||
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
|
||||
if excluded and not included:
|
||||
if verbose:
|
||||
logger.info(f"exclude: {original_name}")
|
||||
@@ -280,17 +278,19 @@ class LoRANetwork(torch.nn.Module):
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha_val = modules_alpha[lora_name]
|
||||
elif self.reg_dims is not None:
|
||||
for reg, d in self.reg_dims.items():
|
||||
if re.search(reg, original_name):
|
||||
dim = d
|
||||
alpha_val = self.alpha
|
||||
logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
|
||||
break
|
||||
else:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha_val = self.alpha
|
||||
if self.reg_dims is not None:
|
||||
for reg, d in self.reg_dims.items():
|
||||
if re.fullmatch(reg, original_name):
|
||||
dim = d
|
||||
alpha_val = self.alpha
|
||||
logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
|
||||
break
|
||||
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
|
||||
if dim is None:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha_val = self.alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
@@ -446,7 +446,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in loras:
|
||||
matched_reg_lr = None
|
||||
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
|
||||
if re.search(regex_str, lora.original_name):
|
||||
if re.fullmatch(regex_str, lora.original_name):
|
||||
matched_reg_lr = (i, reg_lr)
|
||||
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user