diff --git a/library/train_util.py b/library/train_util.py index 83d04f5e..4db7ba82 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4816,6 +4816,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar ignore_nesting_dict[section_name] = section_dict continue + if section_name == "scale_weight_norms_map": + ignore_nesting_dict[section_name] = section_dict + continue + # if value is dict, save all key and value into one dict for key, value in section_dict.items(): ignore_nesting_dict[key] = value diff --git a/networks/lora.py b/networks/lora.py index 1699a60f..b10ed6dc 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -5,6 +5,7 @@ import math import os +from fnmatch import fnmatch from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL from transformers import CLIPTextModel @@ -1366,7 +1367,8 @@ class LoRANetwork(torch.nn.Module): org_module._lora_restored = False lora.enabled = False - def apply_max_norm_regularization(self, max_norm_value, device): + @torch.no_grad() + def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}): downkeys = [] upkeys = [] alphakeys = [] @@ -1381,6 +1383,11 @@ class LoRANetwork(torch.nn.Module): alphakeys.append(key.replace("lora_down.weight", "alpha")) for i in range(len(downkeys)): + max_norm_value = max_norm + for key in scale_map.keys(): + if fnmatch(downkeys[i], key): + max_norm_value = scale_map[key] + down = state_dict[downkeys[i]].to(device) up = state_dict[upkeys[i]].to(device) alpha = state_dict[alphakeys[i]].to(device) @@ -1404,7 +1411,7 @@ class LoRANetwork(torch.nn.Module): keys_scaled += 1 state_dict[upkeys[i]] *= sqrt_ratio state_dict[downkeys[i]] *= sqrt_ratio - scalednorm = updown.norm() * ratio + scalednorm: torch.Tensor = updown.norm() * ratio norms.append(scalednorm.item()) return keys_scaled, sum(norms) / len(norms), max(norms) diff --git a/train_network.py b/train_network.py index 6b8ed9bd..365684c1 100644 --- a/train_network.py +++ b/train_network.py @@ -12,6 +12,8 @@ import json from multiprocessing import Value import numpy as np +import ast + from tqdm import tqdm import torch @@ -1444,8 +1446,9 @@ class NetworkTrainer: optimizer.zero_grad(set_to_none=True) if args.scale_weight_norms: + scale_map = args.scale_weight_norms_map if args.scale_weight_norms_map else {} keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization( - args.scale_weight_norms, accelerator.device + args.scale_weight_norms, accelerator.device, scale_map=scale_map ) mean_grad_norm = None mean_combined_norm = None @@ -1713,6 +1716,14 @@ class NetworkTrainer: logger.info("model saved.") +def parse_dict(input_str): + """Convert string input into a dictionary.""" + try: + # Use ast.literal_eval to safely evaluate the string as a Python literal (dict) + return ast.literal_eval(input_str) + except ValueError: + raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}") + def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() @@ -1816,6 +1827,12 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)", ) + parser.add_argument( + "--scale_weight_norms_map", + type=parse_dict, + default="{}", + help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)", + ) parser.add_argument( "--base_weights", type=str,