Compare commits

...

11 Commits

Author SHA1 Message Date
Dave Lage
88a4442b7e Merge ac120e68ef into fa53f71ec0 2026-04-02 06:31:26 +00:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
rockerBOO
ac120e68ef Add aid_p buffer to save AID 2025-04-18 02:02:47 -04:00
rockerBOO
aefea026a7 Add AID tests 2025-04-15 15:38:57 -04:00
rockerBOO
c2f75f43a4 Do not convert to float 2025-04-15 04:41:59 -04:00
rockerBOO
4d005cdf3d Remove old code 2025-04-14 15:52:25 -04:00
rockerBOO
9bc392001c Simplify AID implementation. AID(B*A) instead of AID(B)*A. 2025-04-14 15:50:54 -04:00
rockerBOO
584ea4ee34 Improve performance. Add curve for AID probabilities 2025-04-14 01:56:37 -04:00
rockerBOO
956275f295 Add pythonpath to pytest.ini 2025-04-13 21:31:35 -04:00
rockerBOO
61af45ef3c Add AID_GELU. Add dropout curve for AID 2025-04-13 21:30:35 -04:00
rockerBOO
dfae3a486c Add AID activation interval-wise dropout 2025-04-13 13:57:52 -04:00
6 changed files with 225 additions and 6 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

18
library/model_utils.py Normal file
View File

@@ -0,0 +1,18 @@
import torch
from torch import nn, Tensor
class AID(nn.Module):
def __init__(self, p=0.9):
super().__init__()
self.p = p
def forward(self, x: Tensor):
if self.training:
pos_mask = (x >= 0) * torch.bernoulli(torch.ones_like(x) * self.p)
neg_mask = (x < 0) * torch.bernoulli(torch.ones_like(x) * (1 - self.p))
return x * (pos_mask + neg_mask)
else:
pos_part = (x >= 0) * x * self.p
neg_part = (x < 0) * x * (1 - self.p)
return pos_part + neg_part

View File

@@ -17,6 +17,7 @@ import numpy as np
import torch
from torch import Tensor
import re
from library.model_utils import AID
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -30,6 +31,21 @@ NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
def get_point_on_curve(block_id, total_blocks=38, peak=0.9, shift=0.75):
# Normalize the position to 0-1 range
normalized_pos = block_id / total_blocks
# Shift the sine curve to only use the first 3/4 of the cycle
# This gives us: start at 0, peak in the middle, end around 0.7
phase_shift = shift * math.pi
sine_value = math.sin(normalized_pos * phase_shift)
# Scale to our desired peak of 0.9
result = peak * sine_value
return result
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -45,6 +61,7 @@ class LoRAModule(torch.nn.Module):
dropout=None,
rank_dropout=None,
module_dropout=None,
aid_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
@@ -107,6 +124,13 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.aid = (
AID(aid_dropout) if aid_dropout is not None else torch.nn.Identity()
) # AID activation
if aid_dropout is not None:
self.register_buffer("aid_p", torch.tensor(aid_dropout))
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
@@ -158,6 +182,9 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
# Activation by Interval-wise Dropout
lx = self.aid(lx)
# LoRA Gradient-Guided Perturbation Optimization
if (
self.training
@@ -554,6 +581,9 @@ def create_network(
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
aid_dropout = kwargs.get("aid_dropout", None)
if aid_dropout is not None:
aid_dropout = float(aid_dropout)
# single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
@@ -630,6 +660,7 @@ def create_network(
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
aid_dropout=aid_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
train_blocks=train_blocks,
@@ -730,6 +761,7 @@ class LoRANetwork(torch.nn.Module):
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
aid_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
@@ -758,6 +790,7 @@ class LoRANetwork(torch.nn.Module):
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.aid_dropout = aid_dropout
self.train_blocks = train_blocks if train_blocks is not None else "all"
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
@@ -834,6 +867,7 @@ class LoRANetwork(torch.nn.Module):
dim = None
alpha = None
aid_dropout_p = None
if modules_dim is not None:
# モジュール指定あり
@@ -869,6 +903,21 @@ class LoRANetwork(torch.nn.Module):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
is_double = False
if "double" in lora_name:
is_double = True
is_single = False
if "single" in lora_name:
is_single = True
block_index = None
if is_flux and dim and (is_double or is_single):
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if block_index is not None and aid_dropout is not None:
all_block_index = block_index if is_double else block_index + NUM_DOUBLE_BLOCKS
aid_dropout_p = get_point_on_curve(
all_block_index, NUM_DOUBLE_BLOCKS + NUM_SINGLE_BLOCKS, peak=aid_dropout)
if (
is_flux
@@ -879,8 +928,6 @@ class LoRANetwork(torch.nn.Module):
)
and ("double" in lora_name or "single" in lora_name)
):
# "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if (
"double" in lora_name
and self.train_double_block_indices is not None
@@ -921,6 +968,7 @@ class LoRANetwork(torch.nn.Module):
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
aid_dropout=aid_dropout_p if aid_dropout_p is not None else aid_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
@@ -1117,6 +1165,7 @@ class LoRANetwork(torch.nn.Module):
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
alpha = state_dict.pop(f"{lora_name}.alpha")
aid_p = state_dict.pop(f"{lora_name}.aid_p")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
@@ -1132,6 +1181,7 @@ class LoRANetwork(torch.nn.Module):
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
new_state_dict[f"{lora_name}.aid_p"] = aid_p
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"

View File

@@ -0,0 +1,137 @@
import pytest
import torch
from library.model_utils import AID
from torch import nn
import torch.nn.functional as F
@pytest.fixture
def input_tensor():
# Create a tensor with positive and negative values
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
return x
def test_aid_forward_train_mode(input_tensor):
aid = AID(p=0.9)
aid.train()
# Run several forward passes to test stochastic behavior
results = []
for _ in range(10):
output = aid(input_tensor)
results.append(output.detach().clone())
# Test that outputs vary (stochastic behavior)
all_equal = all(torch.allclose(results[0], results[i]) for i in range(1, 10))
assert not all_equal, "All outputs are identical, expected variability in training mode"
# Test shape preservation
assert results[0].shape == input_tensor.shape
def test_aid_forward_eval_mode(input_tensor):
aid = AID(p=0.9)
aid.eval()
output = aid(input_tensor)
# Test deterministic behavior
output2 = aid(input_tensor)
assert torch.allclose(output, output2), "Expected deterministic behavior in eval mode"
# Test correct transformation
expected = 0.9 * F.relu(input_tensor) + 0.1 * F.relu(-input_tensor) * -1
assert torch.allclose(output, expected), "Incorrect evaluation mode transformation"
def test_aid_gradient_flow(input_tensor):
aid = AID(p=0.9)
aid.train()
# Forward pass
output = aid(input_tensor)
# Check gradient flow
assert output.requires_grad, "Output lost gradient tracking"
# Compute loss and backpropagate
loss = output.sum()
loss.backward()
# Verify gradients were computed
assert input_tensor.grad is not None, "No gradients were recorded for input tensor"
assert torch.any(input_tensor.grad != 0), "Gradients are all zeros"
def test_aid_extreme_p_values():
# Test with p=1.0 (only positive values pass through)
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
aid = AID(p=1.0)
aid.eval()
output = aid(x)
expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
assert torch.allclose(output, expected), "Failed with p=1.0"
# Test with p=0.0 (only negative values pass through)
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True)
aid = AID(p=0.0)
aid.eval()
output = aid(x)
expected = torch.tensor([-2.0, -1.0, 0.0, 0.0, 0.0])
assert torch.allclose(output, expected), "Failed with p=0.0"
def test_aid_with_all_positive_values():
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], requires_grad=True)
aid = AID(p=0.9)
aid.train()
# Run forward passes and check that only positive values are affected
output = aid(x)
# Backprop should work
loss = output.sum()
loss.backward()
assert x.grad is not None, "No gradients were recorded for all-positive input"
def test_aid_with_all_negative_values():
x = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0], requires_grad=True)
aid = AID(p=0.9)
aid.train()
# Run forward passes and check that only negative values are affected
output = aid(x)
# Backprop should work
loss = output.sum()
loss.backward()
assert x.grad is not None, "No gradients were recorded for all-negative input"
def test_aid_with_zero_values():
x = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], requires_grad=True)
aid = AID(p=0.9)
# Test training mode
aid.train()
output = aid(x)
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
# Test eval mode
aid.eval()
output = aid(x)
assert torch.allclose(output, torch.zeros_like(output)), "Expected zeros out for zero input"
def test_aid_integration_with_linear_layer():
# Test AID's compatibility with a linear layer
linear = nn.Linear(5, 2)
aid = AID(p=0.9)
model = nn.Sequential(linear, aid)
model.train()
x = torch.randn(3, 5, requires_grad=True)
output = model(x)
# Check that gradients flow through the whole model
loss = output.sum()
loss.backward()
assert linear.weight.grad is not None, "No gradients for linear layer weights"
assert x.grad is not None, "No gradients for input tensor"