Compare commits

...

7 Commits

Author SHA1 Message Date
Urle Sistiana
e721d21b17 Merge 3bbfa9b258 into fa53f71ec0 2026-04-05 00:38:08 +00:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
urlesistiana
3bbfa9b258 fixed FeedForward 2025-10-01 18:00:16 +08:00
urlesistiana
3420a6f7d1 check activation_memory_budget value range and accept value 0 2025-10-01 17:55:24 +08:00
urlesistiana
45cab086cc make torch.compile happy
don't compile funcs with complex ops

simplify FeedForward to avoid "cache line invalidated" error
2025-10-01 17:40:12 +08:00
urlesistiana
f25cb8abd1 add --activation_memory_budget to training arguments
remove SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET env
2025-09-30 22:03:30 +08:00
urlesistiana
c15e6b4f3b feat: add selective torch compile and activation memory budget to Lumina 2 2025-09-30 15:25:13 +08:00
5 changed files with 50 additions and 11 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -20,6 +20,7 @@
# --------------------------------------------------------
import math
import os
from typing import List, Optional, Tuple
from dataclasses import dataclass
@@ -31,6 +32,10 @@ import torch.nn.functional as F
from library import custom_offloading_utils
disable_selective_torch_compile = (
os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
)
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -549,7 +554,7 @@ class JointAttention(nn.Module):
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
)
@torch.compiler.disable(reason="complex ops inside")
def apply_rope(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
@@ -625,13 +630,10 @@ class FeedForward(nn.Module):
bias=False,
)
nn.init.xavier_uniform_(self.w3.weight)
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
return self.w2(F.silu(self.w1(x))*self.w3(x))
class JointTransformerBlock(GradientCheckpointMixin):
@@ -697,6 +699,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
@torch.compile(disable=disable_selective_torch_compile)
def _forward(
self,
x: torch.Tensor,
@@ -788,6 +791,7 @@ class FinalLayer(GradientCheckpointMixin):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x, c):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
@@ -808,6 +812,7 @@ class RopeEmbedder:
self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
@torch.compiler.disable(reason="complex ops inside")
def __call__(self, ids: torch.Tensor):
device = ids.device
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
@@ -1219,6 +1224,7 @@ class NextDiT(nn.Module):
return output
@staticmethod
@torch.compiler.disable(reason="complex ops inside")
def precompute_freqs_cis(
dim: List[int],
end: List[int],

View File

@@ -3992,6 +3992,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument(
"--activation_memory_budget",
type=float,
default=None,
help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定01の値。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
"--sdpa",
@@ -5539,6 +5545,19 @@ def prepare_accelerator(args: argparse.Namespace):
if args.torch_compile:
dynamo_backend = args.dynamo_backend
if args.activation_memory_budget is not None: # Note: 0 is a valid value.
if 0 <= args.activation_memory_budget <= 1:
logger.info(
f"set torch compile activation memory budget to {args.activation_memory_budget}"
)
torch._functorch.config.activation_memory_budget = ( # type: ignore
args.activation_memory_budget
)
else:
raise ValueError(
"activation_memory_budget must be between 0 and 1 (inclusive)"
)
kwargs_handlers = [
(
InitProcessGroupKwargs(