From fa53f71ec08bd9dd6ccae3935247b702a8217493 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:36:29 +0900 Subject: [PATCH] fix: improve numerical stability by conditionally using float32 in Anima (#2302) * fix: improve numerical stability by conditionally using float32 in block computations * doc: update README for improvement stability for fp16 training on Anima in version 0.10.3 --- README-ja.md | 3 +++ README.md | 3 +++ library/anima_models.py | 16 ++++++++++++---- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/README-ja.md b/README-ja.md index f4f912a2..ff05468d 100644 --- a/README-ja.md +++ b/README-ja.md @@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像 ### 更新履歴 +- **Version 0.10.3 (2026-04-02):** + - Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。 + - **Version 0.10.2 (2026-03-30):** - SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。 - 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。 diff --git a/README.md b/README.md index fc041db3..0fb415d0 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via ### Change History +- **Version 0.10.3 (2026-04-02):** + - Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue. + - **Version 0.10.2 (2026-03-30):** - LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294). - Please refer to the [documentation](./docs/train_leco.md) for details. diff --git a/library/anima_models.py b/library/anima_models.py index 00e9c6c6..ad34662f 100644 --- a/library/anima_models.py +++ b/library/anima_models.py @@ -738,9 +738,9 @@ class FinalLayer(nn.Module): x_B_T_H_W_D: torch.Tensor, emb_B_T_D: torch.Tensor, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, + use_fp32: bool = False, ): # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers) - use_fp32 = x_B_T_H_W_D.dtype == torch.float16 with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32): if self.use_adaln_lora: assert adaln_lora_B_T_3D is not None @@ -863,11 +863,11 @@ class Block(nn.Module): emb_B_T_D: torch.Tensor, crossattn_emb: torch.Tensor, attn_params: attention.AttentionParams, + use_fp32: bool = False, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - use_fp32 = x_B_T_H_W_D.dtype == torch.float16 if use_fp32: # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context. x_B_T_H_W_D = x_B_T_H_W_D.float() @@ -959,6 +959,7 @@ class Block(nn.Module): emb_B_T_D: torch.Tensor, crossattn_emb: torch.Tensor, attn_params: attention.AttentionParams, + use_fp32: bool = False, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, @@ -972,6 +973,7 @@ class Block(nn.Module): emb_B_T_D, crossattn_emb, attn_params, + use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, @@ -994,6 +996,7 @@ class Block(nn.Module): emb_B_T_D, crossattn_emb, attn_params, + use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, @@ -1007,6 +1010,7 @@ class Block(nn.Module): emb_B_T_D, crossattn_emb, attn_params, + use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, @@ -1018,6 +1022,7 @@ class Block(nn.Module): emb_B_T_D, crossattn_emb, attn_params, + use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, @@ -1338,16 +1343,19 @@ class Anima(nn.Module): attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn) + # Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16) + use_fp32 = x_B_T_H_W_D.dtype == torch.float16 + for block_idx, block in enumerate(self.blocks): if self.blocks_to_swap: self.offloader.wait_for_block(block_idx) - x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs) + x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs) if self.blocks_to_swap: self.offloader.submit_move_blocks(self.blocks, block_idx) - x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32) x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) return x_B_C_Tt_Hp_Wp