This commit is contained in:
Dave Lage
2026-04-05 01:14:11 +00:00
committed by GitHub
3 changed files with 31 additions and 3 deletions

View File

@@ -4816,6 +4816,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
ignore_nesting_dict[section_name] = section_dict ignore_nesting_dict[section_name] = section_dict
continue continue
if section_name == "scale_weight_norms_map":
ignore_nesting_dict[section_name] = section_dict
continue
# if value is dict, save all key and value into one dict # if value is dict, save all key and value into one dict
for key, value in section_dict.items(): for key, value in section_dict.items():
ignore_nesting_dict[key] = value ignore_nesting_dict[key] = value

View File

@@ -5,6 +5,7 @@
import math import math
import os import os
from fnmatch import fnmatch
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from transformers import CLIPTextModel from transformers import CLIPTextModel
@@ -1366,7 +1367,8 @@ class LoRANetwork(torch.nn.Module):
org_module._lora_restored = False org_module._lora_restored = False
lora.enabled = False lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device): @torch.no_grad()
def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}):
downkeys = [] downkeys = []
upkeys = [] upkeys = []
alphakeys = [] alphakeys = []
@@ -1381,6 +1383,11 @@ class LoRANetwork(torch.nn.Module):
alphakeys.append(key.replace("lora_down.weight", "alpha")) alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)): for i in range(len(downkeys)):
max_norm_value = max_norm
for key in scale_map.keys():
if fnmatch(downkeys[i], key):
max_norm_value = scale_map[key]
down = state_dict[downkeys[i]].to(device) down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device) up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device) alpha = state_dict[alphakeys[i]].to(device)
@@ -1404,7 +1411,7 @@ class LoRANetwork(torch.nn.Module):
keys_scaled += 1 keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio scalednorm: torch.Tensor = updown.norm() * ratio
norms.append(scalednorm.item()) norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms) return keys_scaled, sum(norms) / len(norms), max(norms)

View File

@@ -12,6 +12,8 @@ import json
from multiprocessing import Value from multiprocessing import Value
import numpy as np import numpy as np
import ast
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -1444,8 +1446,9 @@ class NetworkTrainer:
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms: if args.scale_weight_norms:
scale_map = args.scale_weight_norms_map if args.scale_weight_norms_map else {}
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device args.scale_weight_norms, accelerator.device, scale_map=scale_map
) )
mean_grad_norm = None mean_grad_norm = None
mean_combined_norm = None mean_combined_norm = None
@@ -1713,6 +1716,14 @@ class NetworkTrainer:
logger.info("model saved.") logger.info("model saved.")
def parse_dict(input_str):
    """Convert a Python dict literal given as a string into a dictionary.

    Used as the argparse ``type=`` converter for ``--scale_weight_norms_map``.

    Args:
        input_str: Text of a Python dict literal, e.g. ``"{'pattern*': 0.5}"``.
            An already-parsed ``dict`` is returned unchanged.

    Returns:
        The parsed dictionary.

    Raises:
        argparse.ArgumentTypeError: If the text is not a valid Python literal
            or does not evaluate to a ``dict``.
    """
    # Accept values that were already parsed so the converter is idempotent.
    if isinstance(input_str, dict):
        return input_str

    try:
        # ast.literal_eval only evaluates literals, never arbitrary expressions.
        parsed = ast.literal_eval(input_str)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        # literal_eval reports malformed input with any of these exception types;
        # argparse itself does not convert SyntaxError into a usage error.
        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}") from err

    # A valid literal of another type (list, number, ...) is not usable as a map.
    if not isinstance(parsed, dict):
        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")
    return parsed
def setup_parser() -> argparse.ArgumentParser: def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -1816,6 +1827,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None, default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ1が初期値としては適当", help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ1が初期値としては適当",
) )
parser.add_argument(
"--scale_weight_norms_map",
type=parse_dict,
default="{}",
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ1が初期値としては適当",
)
parser.add_argument( parser.add_argument(
"--base_weights", "--base_weights",
type=str, type=str,