Compare commits

..

8 Commits

Author SHA1 Message Date
Kohya S.
308a0cc9fc Merge pull request #2312 from kohya-ss/dev
Merge dev to main
2026-04-07 08:53:13 +09:00
Kohya S
7e60e163c1 Merge branch 'main' into dev 2026-04-07 08:49:58 +09:00
Kohya S.
a8f5c222e0 Merge pull request #2311 from kohya-ss/doc-update-readme-for-next-release
README: Add planned changes for next release (Intel GPU compatibility)
2026-04-07 08:47:37 +09:00
Kohya S
1d588d6cb6 README: Add planned changes for next release and improve Intel GPU compatibility 2026-04-07 08:44:31 +09:00
Kohya S.
a7d35701a0 Merge pull request #2307 from WhitePr/dev
update ipex
2026-04-07 08:41:41 +09:00
WhitePr
8da05a10dc Update IPEX libs 2026-04-04 05:37:18 +09:00
WhitePr
197b129284 Modified the method for getting the Torch version 2026-04-04 04:44:53 +09:00
Kohya S.
51435f1718 Merge pull request #2303 from kohya-ss/sd3
fix: improve numerical stability by conditionally using float32 in Anima with fp16 training
2026-04-02 12:40:48 +09:00
6 changed files with 17 additions and 273 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.

View File

@@ -1,6 +1,7 @@
import os
import sys
import torch
from packaging import version
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True
@@ -8,7 +9,7 @@ except Exception:
has_ipex = False
from .hijacks import ipex_hijacks
torch_version = float(torch.__version__[:3])
torch_version = version.parse(torch.__version__)
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
if torch_version < 2.5:
if torch_version < version.parse("2.5"):
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < 2.7:
if torch_version < version.parse("2.7"):
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# C
if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12

View File

@@ -48,8 +48,6 @@ class LoRAModule(torch.nn.Module):
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
mgpo_rho: float | None = None,
mgpo_beta: float | None = None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -119,25 +117,6 @@ class LoRAModule(torch.nn.Module):
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
self.mgpo_rho = mgpo_rho
self.mgpo_beta = mgpo_beta
# EMA of gradient magnitudes for adaptive normalization
self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)
self.optimizer: torch.optim.Optimizer | None = None
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
@@ -179,18 +158,6 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
# LoRA Momentum-Guided Perturbation Optimization (MGPO)
if (
self.training
and hasattr(self, "mgpo_rho")
and self.mgpo_rho is not None
and hasattr(self, "optimizer")
and self.optimizer is not None
):
mgpo_perturbation_output = self.get_mgpo_output_perturbation(x)
if mgpo_perturbation_output is not None:
return org_forwarded + (self.multiplier * scale * lx) + mgpo_perturbation_output
# LoRA Gradient-Guided Perturbation Optimization
if (
self.training
@@ -337,97 +304,6 @@ class LoRAModule(torch.nn.Module):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
def update_gradient_ema(self):
"""
Update EMA of gradient magnitudes for adaptive perturbation normalization
Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
"""
if self.mgpo_beta is None:
return
# Update EMA for lora_down gradient magnitude
if self.lora_down.weight.grad is not None:
current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
current_grad_norm, alpha=(1 - self.mgpo_beta)
)
# Update EMA for lora_up gradient magnitude
if self.lora_up.weight.grad is not None:
current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
current_grad_norm, alpha=(1 - self.mgpo_beta)
)
def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
"""
Generate MGPO perturbation using both momentum direction and gradient magnitude normalization
Full MGPO Formula: ε = -ρ · (vₜ / ||vₜ||₂) · (ḡₗ⁽ᵗ⁾)⁻¹
Where:
- ε = perturbation vector
- ρ = perturbation radius (mgpo_rho)
- vₜ = momentum vector from optimizer (exp_avg) - provides DIRECTION
- ||vₜ||₂ = L2 norm of momentum for unit direction
- ḡₗ⁽ᵗ⁾ = EMA of gradient magnitude - provides ADAPTIVE SCALING
Two separate EMAs:
1. Momentum EMA (from Adam): vₜ = β₁ * vₜ₋₁ + (1 - β₁) * ∇L(Wₜ)
2. Gradient Magnitude EMA: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇L(Wₜ)||₂
"""
if self.optimizer is None or self.mgpo_rho is None or self.mgpo_beta is None:
return None
total_perturbation_scale = 0.0
valid_params = 0
# Handle both single and split dims cases
if self.split_dims is None:
params_and_emas = [
(self.lora_down.weight, self._grad_magnitude_ema_down),
(self.lora_up.weight, self._grad_magnitude_ema_up),
]
else:
# For split dims, use average EMA (or extend to per-param EMAs)
avg_ema = (self._grad_magnitude_ema_down + self._grad_magnitude_ema_up) / 2
params_and_emas = []
for lora_down in self.lora_down:
params_and_emas.append((lora_down.weight, avg_ema))
for lora_up in self.lora_up:
params_and_emas.append((lora_up.weight, avg_ema))
for param, grad_ema in params_and_emas:
if param in self.optimizer.state and "exp_avg" in self.optimizer.state[param]:
# Get momentum direction: vₜ / ||vₜ||₂
momentum = self.optimizer.state[param]["exp_avg"]
momentum_norm = torch.norm(momentum, p=2)
if momentum_norm > 1e-8 and grad_ema > 1e-8:
# Apply full MGPO formula: ρ · (momentum_direction) · (1/grad_magnitude_ema)
direction_component = momentum_norm # We'll use this for scaling
adaptive_scale = 1.0 / grad_ema # Adaptive normalization
perturbation_scale = self.mgpo_rho * direction_component * adaptive_scale
total_perturbation_scale += perturbation_scale.item()
valid_params += 1
if valid_params == 0:
return None
# Average perturbation scale across all valid parameters
avg_perturbation_scale = total_perturbation_scale / valid_params
with torch.no_grad():
# Generate random perturbation scaled by MGPO formula
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(avg_perturbation_scale)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return perturbation_output
def register_optimizer(self, optimizer):
self.optimizer = optimizer
@property
def device(self):
return next(self.parameters()).device
@@ -698,15 +574,6 @@ def create_network(
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
mgpo_beta = kwargs.get("mgpo_beta", None)
mgpo_rho = kwargs.get("mgpo_rho", None)
if mgpo_beta is not None:
mgpo_beta = float(mgpo_beta)
if mgpo_rho is not None:
mgpo_rho = float(mgpo_rho)
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -775,8 +642,6 @@ def create_network(
reg_dims=reg_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
mgpo_rho=mgpo_rho,
mgpo_beta=mgpo_beta,
reg_lrs=reg_lrs,
verbose=verbose,
)
@@ -880,8 +745,6 @@ class LoRANetwork(torch.nn.Module):
reg_dims: Optional[Dict[str, int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
mgpo_rho: Optional[float] = None,
mgpo_beta: Optional[float] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
@@ -927,8 +790,6 @@ class LoRANetwork(torch.nn.Module):
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if mgpo_beta is not None and mgpo_rho is not None:
logger.info(f"LoRA-MGPO training rho: {mgpo_rho} beta: {mgpo_beta}")
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
@@ -1063,8 +924,6 @@ class LoRANetwork(torch.nn.Module):
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
mgpo_rho=mgpo_rho,
mgpo_beta=mgpo_beta,
)
loras.append(lora)

View File

@@ -1,119 +0,0 @@
import pytest
import torch
import math
from networks.lora_flux import LoRAModule
class MockLinear(torch.nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
self.in_features = in_features
self.out_features = out_features
def forward(self, x):
return torch.matmul(x, self.weight.t())
def state_dict(self):
return {"weight": self.weight}
class MockOptimizer:
def __init__(self, param):
self.state = {param: {"exp_avg": torch.randn_like(param)}}
@pytest.fixture
def lora_module():
org_module = MockLinear(10, 20)
lora_module = LoRAModule(org_module, org_module, multiplier=1.0, lora_dim=4, alpha=1.0, mgpo_rho=0.1, mgpo_beta=0.9)
# Manually set org_module_shape to match the original module's weight
lora_module.org_module_shape = org_module.weight.shape
return lora_module
def test_mgpo_parameter_initialization(lora_module):
"""Test MGPO-specific parameter initialization."""
# Check MGPO-specific attributes
assert hasattr(lora_module, "mgpo_rho")
assert hasattr(lora_module, "mgpo_beta")
assert lora_module.mgpo_rho == 0.1
assert lora_module.mgpo_beta == 0.9
# Check EMA parameters initialization
assert hasattr(lora_module, "_grad_magnitude_ema_down")
assert hasattr(lora_module, "_grad_magnitude_ema_up")
assert isinstance(lora_module._grad_magnitude_ema_down, torch.nn.Parameter)
assert isinstance(lora_module._grad_magnitude_ema_up, torch.nn.Parameter)
assert lora_module._grad_magnitude_ema_down.requires_grad == False
assert lora_module._grad_magnitude_ema_up.requires_grad == False
assert lora_module._grad_magnitude_ema_down.item() == 1.0
assert lora_module._grad_magnitude_ema_up.item() == 1.0
def test_update_gradient_ema(lora_module):
"""Test gradient EMA update method."""
# Ensure method works when mgpo_beta is set
lora_module.lora_down.weight.grad = torch.randn_like(lora_module.lora_down.weight)
lora_module.lora_up.weight.grad = torch.randn_like(lora_module.lora_up.weight)
# Store initial EMA values
initial_down_ema = lora_module._grad_magnitude_ema_down.clone()
initial_up_ema = lora_module._grad_magnitude_ema_up.clone()
# Update gradient EMA
lora_module.update_gradient_ema()
# Check EMA update logic
down_grad_norm = torch.norm(lora_module.lora_down.weight.grad, p=2)
up_grad_norm = torch.norm(lora_module.lora_up.weight.grad, p=2)
# Verify EMA calculation
expected_down_ema = lora_module.mgpo_beta * initial_down_ema + (1 - lora_module.mgpo_beta) * down_grad_norm
expected_up_ema = lora_module.mgpo_beta * initial_up_ema + (1 - lora_module.mgpo_beta) * up_grad_norm
assert torch.allclose(lora_module._grad_magnitude_ema_down, expected_down_ema, rtol=1e-5)
assert torch.allclose(lora_module._grad_magnitude_ema_up, expected_up_ema, rtol=1e-5)
# Test when mgpo_beta is None
lora_module.mgpo_beta = None
lora_module.update_gradient_ema() # Should not raise an exception
def test_get_mgpo_output_perturbation(lora_module):
"""Test MGPO perturbation generation."""
# Create a mock optimizer
mock_optimizer = MockOptimizer(lora_module.lora_down.weight)
lora_module.register_optimizer(mock_optimizer)
# Prepare input
x = torch.randn(5, 10) # batch × input_dim
# Ensure method works with valid conditions
perturbation = lora_module.get_mgpo_output_perturbation(x)
# Verify perturbation characteristics
assert perturbation is not None
assert isinstance(perturbation, torch.Tensor)
assert perturbation.shape == (x.shape[0], lora_module.org_module.out_features)
# Test when conditions are not met
lora_module.optimizer = None
lora_module.mgpo_rho = None
lora_module.mgpo_beta = None
no_perturbation = lora_module.get_mgpo_output_perturbation(x)
assert no_perturbation is None
def test_register_optimizer(lora_module):
"""Test optimizer registration method."""
# Create a mock optimizer
mock_optimizer = MockOptimizer(lora_module.lora_down.weight)
# Register optimizer
lora_module.register_optimizer(mock_optimizer)
# Verify optimizer is correctly registered
assert hasattr(lora_module, "optimizer")
assert lora_module.optimizer == mock_optimizer

View File

@@ -750,9 +750,6 @@ class NetworkTrainer:
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
if hasattr(network, "register_optimizer"):
network.register_optimizer(optimizer)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
@@ -1441,8 +1438,6 @@ class NetworkTrainer:
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
if hasattr(network, "update_gradient_ema"):
network.update_gradient_ema()
optimizer.step()
lr_scheduler.step()