Compare commits

..

3 Commits

Author SHA1 Message Date
Dave Lage
85ea2c004b Merge dfe1da4d36 into fa53f71ec0 2026-04-05 01:14:11 +00:00
rockerBOO
dfe1da4d36 Add fnmatch. Make max_norm no_grad 2025-01-23 14:24:57 -05:00
rockerBOO
b0d0d43bfa Add scale map to max_norm 2025-01-11 13:42:20 -05:00
6 changed files with 39 additions and 20 deletions

View File

@@ -50,9 +50,6 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。

View File

@@ -47,9 +47,6 @@ If you find this project helpful, please consider supporting its development via
### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.

View File

@@ -1,7 +1,6 @@
import os
import sys
import torch
from packaging import version
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True
@@ -9,7 +8,7 @@ except Exception:
has_ipex = False
from .hijacks import ipex_hijacks
torch_version = version.parse(torch.__version__)
torch_version = float(torch.__version__[:3])
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -57,6 +56,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,12 +64,14 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < version.parse("2.3"):
if torch_version < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -112,22 +114,17 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
if torch_version < version.parse("2.5"):
if torch_version < 2.5:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < version.parse("2.7"):
if torch_version < 2.7:
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -163,7 +160,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# C
if torch_version < version.parse("2.3"):
if torch_version < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12

View File

@@ -4816,6 +4816,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
ignore_nesting_dict[section_name] = section_dict
continue
if section_name == "scale_weight_norms_map":
ignore_nesting_dict[section_name] = section_dict
continue
# if value is dict, save all key and value into one dict
for key, value in section_dict.items():
ignore_nesting_dict[key] = value

View File

@@ -5,6 +5,7 @@
import math
import os
from fnmatch import fnmatch
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
@@ -1366,7 +1367,8 @@ class LoRANetwork(torch.nn.Module):
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
@torch.no_grad()
def apply_max_norm_regularization(self, max_norm, device, scale_map: dict[str, float]={}):
downkeys = []
upkeys = []
alphakeys = []
@@ -1381,6 +1383,11 @@ class LoRANetwork(torch.nn.Module):
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
max_norm_value = max_norm
for key in scale_map.keys():
if fnmatch(downkeys[i], key):
max_norm_value = scale_map[key]
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
@@ -1404,7 +1411,7 @@ class LoRANetwork(torch.nn.Module):
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
scalednorm: torch.Tensor = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)

View File

@@ -12,6 +12,8 @@ import json
from multiprocessing import Value
import numpy as np
import ast
from tqdm import tqdm
import torch
@@ -1444,8 +1446,9 @@ class NetworkTrainer:
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
scale_map = args.scale_weight_norms_map if args.scale_weight_norms_map else {}
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
args.scale_weight_norms, accelerator.device, scale_map=scale_map
)
mean_grad_norm = None
mean_combined_norm = None
@@ -1713,6 +1716,14 @@ class NetworkTrainer:
logger.info("model saved.")
def parse_dict(input_str):
"""Convert string input into a dictionary."""
try:
# Use ast.literal_eval to safely evaluate the string as a Python literal (dict)
return ast.literal_eval(input_str)
except ValueError:
raise argparse.ArgumentTypeError(f"Invalid dictionary format: {input_str}")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
@@ -1816,6 +1827,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ1が初期値としては適当",
)
parser.add_argument(
"--scale_weight_norms_map",
type=parse_dict,
default="{}",
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ1が初期値としては適当",
)
parser.add_argument(
"--base_weights",
type=str,