diff --git a/networks/lora_anima.py b/networks/lora_anima.py
index d674141c..54bd22ca 100644
--- a/networks/lora_anima.py
+++ b/networks/lora_anima.py
@@ -215,8 +215,6 @@ class LoRANetwork(torch.nn.Module):
 
         if modules_dim is not None:
             logger.info("create LoRA network from weights")
-            if self.emb_dims is None:
-                self.emb_dims = [0] * 3
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
             logger.info(
@@ -265,9 +263,9 @@ class LoRANetwork(torch.nn.Module):
                 original_name = (name + "." if name else "") + child_name
                 lora_name = f"{prefix}.{original_name}".replace(".", "_")
 
-                # exclude/include filter
-                excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns)
-                included = any(pattern.match(original_name) for pattern in include_re_patterns)
+                # exclude/include filter (fullmatch: pattern must match the entire original_name)
+                excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
+                included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
                 if excluded and not included:
                     if verbose:
                         logger.info(f"exclude: {original_name}")
@@ -280,17 +278,19 @@ class LoRANetwork(torch.nn.Module):
                 if lora_name in modules_dim:
                     dim = modules_dim[lora_name]
                     alpha_val = modules_alpha[lora_name]
-                elif self.reg_dims is not None:
-                    for reg, d in self.reg_dims.items():
-                        if re.search(reg, original_name):
-                            dim = d
-                            alpha_val = self.alpha
-                            logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
-                            break
                 else:
-                    if is_linear or is_conv2d_1x1:
-                        dim = default_dim if default_dim is not None else self.lora_dim
-                        alpha_val = self.alpha
+                    if self.reg_dims is not None:
+                        for reg, d in self.reg_dims.items():
+                            if re.fullmatch(reg, original_name):
+                                dim = d
+                                alpha_val = self.alpha
+                                logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
+                                break
+                    # fallback to default dim if not matched by reg_dims or reg_dims is not specified
+                    if dim is None:
+                        if is_linear or is_conv2d_1x1:
+                            dim = default_dim if default_dim is not None else self.lora_dim
+                            alpha_val = self.alpha
 
                 if dim is None or dim == 0:
                     if is_linear or is_conv2d_1x1:
@@ -446,7 +446,7 @@ class LoRANetwork(torch.nn.Module):
         for lora in loras:
            matched_reg_lr = None
             for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
-                if re.search(regex_str, lora.original_name):
+                if re.fullmatch(regex_str, lora.original_name):
                     matched_reg_lr = (i, reg_lr)
                     logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
                     break
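
For reference, a minimal sketch of the match/search vs. fullmatch semantics this patch switches to; the pattern and module name below are hypothetical, not taken from the diff:

    import re

    name = "blocks.0.attn.qkv"

    # re.match anchors only at the start, so a bare prefix pattern still hits:
    print(bool(re.match(r"blocks", name)))          # True
    # re.fullmatch must consume the entire name, so a prefix alone no longer matches:
    print(bool(re.fullmatch(r"blocks", name)))      # False
    # existing exclude/include/reg_dims patterns need an explicit wildcard to keep matching:
    print(bool(re.fullmatch(r"blocks\..*", name)))  # True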