feat: add regex-based learning rates and dimensions handling in create_network

This commit is contained in:
Kohya S
2026-02-10 08:17:43 +09:00
parent 58db77a488
commit 3d24736238

View File

@@ -64,6 +64,40 @@ def create_network(
if verbose is not None:
verbose = True if verbose.lower() == "true" else False
# regex-specific learning rates / dimensions
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None
network = LoRANetwork(
text_encoders,
unet,
@@ -76,6 +110,8 @@ def create_network(
train_llm_adapter=train_llm_adapter,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
@@ -158,6 +194,8 @@ class LoRANetwork(torch.nn.Module):
train_llm_adapter: bool = False,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
reg_dims: Optional[Dict[str, int]] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -168,6 +206,8 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -240,6 +280,13 @@ class LoRANetwork(torch.nn.Module):
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
elif self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.search(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
break
else:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
@@ -260,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
@@ -392,8 +440,29 @@ class LoRANetwork(torch.nn.Module):
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
reg_groups = {}
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.search(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
for name, param in lora.named_parameters():
if matched_reg_lr is not None:
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
if loraplus_ratio is not None and "lora_up" in name:
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
else:
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
continue
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
@@ -401,6 +470,23 @@ class LoRANetwork(torch.nn.Module):
params = []
descriptions = []
for group_key, group in reg_groups.items():
reg_lr = group["lr"]
for key in ("lora", "plus"):
param_data = {"params": group[key].values()}
if len(param_data["params"]) == 0:
continue
if key == "plus":
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
descriptions.append(desc + (" plus" if key == "plus" else ""))
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0: