mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 14:34:23 +00:00)
Merge dfe1da4d36 into fa53f71ec0
@@ -4816,6 +4816,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentParser):
             ignore_nesting_dict[section_name] = section_dict
             continue
 
+        if section_name == "scale_weight_norms_map":
+            ignore_nesting_dict[section_name] = section_dict
+            continue
+
         # if value is dict, save all key and value into one dict
         for key, value in section_dict.items():
             ignore_nesting_dict[key] = value
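Aside on this hunk: read_config_from_file (in library/train_util.py) normally flattens every config-file section into top-level argument keys; the new branch exempts scale_weight_norms_map so its pattern-to-value pairs reach argparse as one nested dict. A minimal standalone sketch of that rule, with a hypothetical pattern key, not the repo's actual code:

config_dict = {
    "general": {"seed": 42},
    "scale_weight_norms_map": {"lora_unet_*attn*": 0.5},  # hypothetical pattern
}

ignore_nesting_dict = {}
for section_name, section_dict in config_dict.items():
    # non-dict values and the exempted section are stored whole
    if not isinstance(section_dict, dict) or section_name == "scale_weight_norms_map":
        ignore_nesting_dict[section_name] = section_dict
        continue
    # every other section is flattened into top-level keys
    for key, value in section_dict.items():
        ignore_nesting_dict[key] = value

print(ignore_nesting_dict)
# {'seed': 42, 'scale_weight_norms_map': {'lora_unet_*attn*': 0.5}}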
@@ -5,6 +5,7 @@
 
 import math
 import os
+from fnmatch import fnmatch
 from typing import Dict, List, Optional, Tuple, Type, Union
 from diffusers import AutoencoderKL
 from transformers import CLIPTextModel
@@ -1366,7 +1367,8 @@ class LoRANetwork(torch.nn.Module):
             org_module._lora_restored = False
             lora.enabled = False
 
-    def apply_max_norm_regularization(self, max_norm_value, device):
+    @torch.no_grad()
+    def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float] = {}):
         downkeys = []
         upkeys = []
         alphakeys = []
@@ -1381,6 +1383,11 @@ class LoRANetwork(torch.nn.Module):
                 alphakeys.append(key.replace("lora_down.weight", "alpha"))
 
         for i in range(len(downkeys)):
+            max_norm_value = max_norm
+            for key in scale_map.keys():
+                if fnmatch(downkeys[i], key):
+                    max_norm_value = scale_map[key]
+
             down = state_dict[downkeys[i]].to(device)
             up = state_dict[upkeys[i]].to(device)
             alpha = state_dict[alphakeys[i]].to(device)
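The override lookup added here uses fnmatch, so the keys of scale_weight_norms_map are shell-style glob patterns matched against lora_down state-dict key names; when several patterns match one key, the last pattern iterated wins, because the loop keeps overwriting max_norm_value. A standalone sketch of the lookup, with hypothetical key names and illustrative patterns:

from fnmatch import fnmatch

downkeys = [
    "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight",
    "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
]
scale_map = {"lora_unet_*": 0.5, "lora_te_*": 2.0}  # illustrative values
max_norm = 1.0  # stands in for the global --scale_weight_norms value

for i in range(len(downkeys)):
    max_norm_value = max_norm  # default unless a pattern matches
    for key in scale_map.keys():
        if fnmatch(downkeys[i], key):
            max_norm_value = scale_map[key]
    print(downkeys[i], "->", max_norm_value)
# the unet key resolves to 0.5, the text-encoder key to 2.0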
@@ -1404,7 +1411,7 @@ class LoRANetwork(torch.nn.Module):
                 keys_scaled += 1
                 state_dict[upkeys[i]] *= sqrt_ratio
                 state_dict[downkeys[i]] *= sqrt_ratio
-            scalednorm = updown.norm() * ratio
+            scalednorm: torch.Tensor = updown.norm() * ratio
             norms.append(scalednorm.item())
 
         return keys_scaled, sum(norms) / len(norms), max(norms)
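Only a type annotation changes in this hunk, but the sqrt_ratio in the context lines is the core of the max-norm trick: the effective LoRA delta is up @ down (scaled by alpha/dim), so multiplying both factors by sqrt(ratio) scales the product, and therefore its norm, by exactly ratio. A quick numerical check of that identity, assuming only torch:

import torch

up, down = torch.randn(8, 4), torch.randn(4, 8)
ratio = 0.5
sqrt_ratio = ratio**0.5
# (s*up) @ (s*down) == s^2 * (up @ down) == ratio * (up @ down)
assert torch.allclose((up * sqrt_ratio) @ (down * sqrt_ratio), (up @ down) * ratio, atol=1e-6)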
@@ -12,6 +12,8 @@ import json
 from multiprocessing import Value
 import numpy as np
 
+import ast
+
 from tqdm import tqdm
 
 import torch
@@ -1444,8 +1446,9 @@ class NetworkTrainer:
                     optimizer.zero_grad(set_to_none=True)
 
                 if args.scale_weight_norms:
+                    scale_map = args.scale_weight_norms_map if args.scale_weight_norms_map else {}
                     keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
-                        args.scale_weight_norms, accelerator.device
+                        args.scale_weight_norms, accelerator.device, scale_map=scale_map
                     )
                     mean_grad_norm = None
                     mean_combined_norm = None
@@ -1713,6 +1716,14 @@ class NetworkTrainer:
 
         logger.info("model saved.")
 
+def parse_dict(input_str):
+    """Convert string input into a dictionary."""
+    try:
+        # Use ast.literal_eval to safely evaluate the string as a Python literal (dict)
+        return ast.literal_eval(input_str)
+    except ValueError:
+        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")
+
 
 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
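parse_dict leans on ast.literal_eval, which evaluates only Python literals, so the argument string cannot run arbitrary code. One caveat: literal_eval raises SyntaxError as well as ValueError for malformed input, and a SyntaxError would escape this except clause unhandled. A minimal usage sketch:

import argparse
import ast

def parse_dict(input_str):
    """Convert string input into a dictionary."""
    try:
        return ast.literal_eval(input_str)  # literals only; no arbitrary code
    except ValueError:
        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")

print(parse_dict("{'lora_unet_*': 0.5}"))  # {'lora_unet_*': 0.5}
print(parse_dict("{}"))                    # {} (an empty, falsy dict)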
@@ -1816,6 +1827,12 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
     )
+    parser.add_argument(
+        "--scale_weight_norms_map",
+        type=parse_dict,
+        default="{}",
+        help="Per-key overrides for the max norm value, given as a dict mapping fnmatch patterns of lora_down key names to norm limits, e.g. {'lora_unet_*': 0.5}",
+    )
     parser.add_argument(
         "--base_weights",
         type=str,
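A subtlety of the new flag: argparse applies the type callable to string defaults, so default="{}" is itself run through parse_dict, and args.scale_weight_norms_map comes back as an empty dict rather than the string "{}". That makes the "if args.scale_weight_norms_map else {}" guard in the training loop redundant but harmless. A self-contained sketch (pattern and value are illustrative):

import argparse
import ast

def parse_dict(input_str):
    try:
        return ast.literal_eval(input_str)
    except ValueError:
        raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")

parser = argparse.ArgumentParser()
parser.add_argument("--scale_weight_norms_map", type=parse_dict, default="{}")

args = parser.parse_args([])
print(args.scale_weight_norms_map)  # {} because the string default is parsed too

args = parser.parse_args(["--scale_weight_norms_map", "{'lora_unet_*attn*': 0.75}"])
print(args.scale_weight_norms_map)  # {'lora_unet_*attn*': 0.75}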