mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
Add scale map to max_norm
This commit is contained in:
@@ -4537,6 +4537,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "scale_weight_norms_map":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
# if value is dict, save all key and value into one dict
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
@@ -1366,7 +1366,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
|
||||
def apply_max_norm_regularization(self, max_norm_value, device):
|
||||
def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
@@ -1381,6 +1381,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
max_norm_value = max_norm
|
||||
for key in scale_map.keys():
|
||||
if key in downkeys[i]:
|
||||
max_norm_value = scale_map[key]
|
||||
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
|
||||
@@ -10,6 +10,8 @@ from multiprocessing import Value
|
||||
from typing import Any, List
|
||||
import toml
|
||||
|
||||
import ast
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
@@ -1260,8 +1262,9 @@ class NetworkTrainer:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
scale_map = args.scale_weight_norms_map if args.scale_weight_norms_map else {}
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
args.scale_weight_norms, accelerator.device, scale_map=scale_map
|
||||
)
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
@@ -1356,6 +1359,14 @@ class NetworkTrainer:
|
||||
|
||||
logger.info("model saved.")
|
||||
|
||||
def parse_dict(input_str):
|
||||
"""Convert string input into a dictionary."""
|
||||
try:
|
||||
# Use ast.literal_eval to safely evaluate the string as a Python literal (dict)
|
||||
return ast.literal_eval(input_str)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -1458,6 +1469,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_weight_norms_map",
|
||||
type=parse_dict,
|
||||
default="{}",
|
||||
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_weights",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user