Compare commits

..

9 Commits

Author SHA1 Message Date
Kohya S.
308a0cc9fc Merge pull request #2312 from kohya-ss/dev
Merge dev to main
2026-04-07 08:53:13 +09:00
Kohya S
7e60e163c1 Merge branch 'main' into dev 2026-04-07 08:49:58 +09:00
Kohya S.
a8f5c222e0 Merge pull request #2311 from kohya-ss/doc-update-readme-for-next-release
README: Add planned changes for next release (Intel GPU compatibility)
2026-04-07 08:47:37 +09:00
Kohya S
1d588d6cb6 README: Add planned changes for next release and improve Intel GPU compatibility 2026-04-07 08:44:31 +09:00
Kohya S.
a7d35701a0 Merge pull request #2307 from WhitePr/dev
update ipex
2026-04-07 08:41:41 +09:00
WhitePr
8da05a10dc Update IPEX libs 2026-04-04 05:37:18 +09:00
WhitePr
197b129284 Modifying the method for getting the Torch version 2026-04-04 04:44:53 +09:00
Kohya S.
51435f1718 Merge pull request #2303 from kohya-ss/sd3
fix: improve numerical stability by conditionally using float32 in Anima with fp16 training
2026-04-02 12:40:48 +09:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improved stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
7 changed files with 37 additions and 219 deletions

View File

@@ -50,6 +50,12 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- 次のリリースに含まれる予定の主な変更点は以下の通りです。リリース前の変更点は予告なく変更される可能性があります。
- Intel GPUの互換性を向上しました。[PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307) WhitePr氏に感謝します。
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,12 @@ If you find this project helpful, please consider supporting its development via
### Change History
- The following are the main changes planned for the next release. Please note that these changes may be subject to change without notice before the release.
- Improved compatibility with Intel GPUs. Thanks to WhitePr for [PR #2307](https://github.com/kohya-ss/sd-scripts/pull/2307).
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -1,6 +1,7 @@
import os
import sys
import torch
from packaging import version
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
has_ipex = True
@@ -8,7 +9,7 @@ except Exception:
has_ipex = False
from .hijacks import ipex_hijacks
torch_version = float(torch.__version__[:3])
torch_version = version.parse(torch.__version__)
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -56,7 +57,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
@@ -64,14 +64,12 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
@@ -114,17 +112,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
if torch_version < 2.5:
if torch_version < version.parse("2.5"):
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if torch_version < 2.7:
if torch_version < version.parse("2.7"):
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List
if torch_version < version.parse("2.11"):
torch.cuda._device_t = torch.xpu._device_t
torch.cuda._device = torch.xpu._device
torch.cuda.Union = torch.xpu.Union
# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
@@ -160,7 +163,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# C
if torch_version < 2.3:
if torch_version < version.parse("2.3"):
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12

View File

@@ -1,18 +0,0 @@
import torch
from torch import nn, Tensor
class AID(nn.Module):
def __init__(self, p=0.9):
super().__init__()
self.p = p
def forward(self, x: Tensor):
if self.training:
pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
return x * (pos_mask + neg_mask)
else:
pos_part = (x >= 0) * x * self.p
neg_part = (x < 0) * x * (1 - self.p)
return pos_part + neg_part

View File

@@ -17,7 +17,6 @@ import numpy as np
import torch
from torch import Tensor
import re
from library.model_utils import AID
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -31,21 +30,6 @@ NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
# Normalize the position to 0-1 range
normalized_pos = block_id / total_blocks
# Shift the sine curve to only use the first 3/4 of the cycle
# This gives us: start at 0, peak in the middle, end around 0.7
phase_shift = shift * math.pi
sine_value = math.sin(normalized_pos * phase_shift)
# Scale to our desired peak of 0.9
result = peak * sine_value
return result
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -61,7 +45,6 @@ class LoRAModule(torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
aid_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
@@ -124,13 +107,6 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.aid = (
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
) # AID activation
if aid_dropout is not None:
self.register_buffer("aid_p", torch.tensor(aid_dropout))
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
@@ -182,9 +158,6 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
# Activation by Interval-wise Dropout
lx = self.aid(lx)
# LoRA Gradient-Guided Perturbation Optimization
if (
self.training
@@ -581,9 +554,6 @@ def create_network(
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
aid_dropout = kwargs.get("aid_dropout", None)
if aid_dropout is not None:
aid_dropout = float(aid_dropout)
# single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
@@ -660,7 +630,6 @@ def create_network(
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
aid_dropout=aid_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
train_blocks=train_blocks,
@@ -761,7 +730,6 @@ class LoRANetwork(torch.nn.Module):
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
aid_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
@@ -790,7 +758,6 @@ class LoRANetwork(torch.nn.Module):
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.aid_dropout = aid_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
@@ -867,7 +834,6 @@ class LoRANetwork(torch.nn.Module):
dim = None
alpha = None
aid_dropout_p = None
if modules_dim is not None:
# モジュール指定あり
@@ -903,21 +869,6 @@ class LoRANetwork(torch.nn.Module):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
is_double = False
if "double" in lora_name:
is_double = True
is_single = False
if "single" in lora_name:
is_single = True
block_index = None
if is_flux and dim and (is_double or is_single):
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if block_index is not None and aid_dropout is not None:
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
aid_dropout_p = get_point_on_curve(
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
if (
is_flux
@@ -928,6 +879,8 @@ class LoRANetwork(torch.nn.Module):
)
and ("double" in lora_name or "single" in lora_name)
):
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if (
"double" in lora_name
and self.train_double_block_indices is not None
@@ -968,7 +921,6 @@ class LoRANetwork(torch.nn.Module):
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
@@ -1165,7 +1117,6 @@ class LoRANetwork(torch.nn.Module):
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
alpha = state_dict.pop(f"{lora_name}.alpha")
aid_p = state_dict.pop(f"{lora_name}.aid_p")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
@@ -1181,7 +1132,6 @@ class LoRANetwork(torch.nn.Module):
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
new_state_dict[f"{lora_name}.aid_p"] = aid_p
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"

View File

@@ -1,137 +0,0 @@
import pytest
import torch
from library.model_utils import AID
from torch import nn
import torch.nn.functional as F
@pytest.fixture
def input_tensor():
# Create a tensor with positive and negative values
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
return x
def test_aid_forward_train_mode(input_tensor):
aid = AID(p=0.9)
aid.train()
# Run several forward passes to test stochastic behavior
results = []
for _ in range(10):
output = aid(input_tensor)
results.append(output.detach().clone())
# Test that outputs vary (stochastic behavior)
all_equal = all(torch.allclose(results[0], results[i]) for i in range(1, 10))
assert not all_equal, "All outputs are identical, expected variability in training mode"
# Test shape preservation
assert results[0].shape == input_tensor.shape
def test_aid_forward_eval_mode(input_tensor):
aid = AID(p=0.9)
aid.eval()
output = aid(input_tensor)
# Test deterministic behavior
output2 = aid(input_tensor)
assert torch.allclose(output, output2), "Expected deterministic behavior in eval mode"
# Test correct transformation
expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1
assert torch.allclose(output, expected), "Incorrect evaluation mode transformation"
def test_aid_gradient_flow(input_tensor):
aid = AID(p=0.9)
aid.train()
# Forward pass
output = aid(input_tensor)
# Check gradient flow
assert output.requires_grad, "Output lost gradient tracking"
# Compute loss and backpropagate
loss = output.sum()
loss.backward()
# Verify gradients were computed
assert input_tensor.grad is not None, "No gradients were recorded for input tensor"
assert torch.any(input_tensor.grad != 0), "Gradients are all zeros"
def test_aid_extreme_p_values():
# Test with p=1.0 (only positive values pass through)
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
aid = AID(p=1.0)
aid.eval()
output = aid(x)
expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
assert torch.allclose(output, expected), "Failed with p=1.0"
# Test with p=0.0 (only negative values pass through)
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
aid = AID(p=0.0)
aid.eval()
output = aid(x)
expected = torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0])
assert torch.allclose(output, expected), "Failed with p=0.0"
def test_aid_with_all_positive_values():
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
aid = AID(p=0.9)
aid.train()
# Run forward passes and check that only positive values are affected
output = aid(x)
# Backprop should work
loss = output.sum()
loss.backward()
assert x.grad is not None, "No gradients were recorded for all-positive input"
def test_aid_with_all_negative_values():
x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True)
aid = AID(p=0.9)
aid.train()
# Run forward passes and check that only negative values are affected
output = aid(x)
# Backprop should work
loss = output.sum()
loss.backward()
assert x.grad is not None, "No gradients were recorded for all-negative input"
def test_aid_with_zero_values():
x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True)
aid = AID(p=0.9)
# Test training mode
aid.train()
output = aid(x)
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
# Test eval mode
aid.eval()
output = aid(x)
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
def test_aid_integration_with_linear_layer():
# Test AID's compatibility with a linear layer
linear = nn.Linear(5, 2)
aid = AID(p=0.9)
model = nn.Sequential(linear, aid)
model.train()
x = torch.randn(3, 5, requires_grad=True)
output = model(x)
# Check that gradients flow through the whole model
loss = output.sum()
loss.backward()
assert linear.weight.grad is not None, "No gradients for linear layer weights"
assert x.grad is not None, "No gradients for input tensor"