Compare commits

...

4 Commits

Author SHA1 Message Date
Dave Lage
12ea9b2ec5 Merge 90d14b9eb0 into fa53f71ec0 2026-04-05 00:39:06 +00:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
rockerBOO
90d14b9eb0 Remove priority 2025-04-12 04:09:39 -04:00
rockerBOO
2ab5bc69e6 Add Flash, cuDNN, Efficient attention for Flux 2025-04-11 23:14:41 -04:00
4 changed files with 22 additions and 5 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴 ### 更新履歴
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):** - **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。 - SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。 - 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History ### Change History
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):** - **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294). - LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details. - Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor, x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
): ):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers) # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32): with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora: if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor, crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams, attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32: if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context. # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float() x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor, crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams, attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D, emb_B_T_D,
crossattn_emb, crossattn_emb,
attn_params, attn_params,
use_fp32,
rope_emb_L_1_1_D, rope_emb_L_1_1_D,
adaln_lora_B_T_3D, adaln_lora_B_T_3D,
extra_per_block_pos_emb, extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn) attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks): for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap: if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx) self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs) x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap: if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx) self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp return x_B_C_Tt_Hp_Wp

View File

@@ -18,6 +18,7 @@ import torch
from einops import rearrange from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from torch.nn.attention import SDPBackend, sdpa_kernel
from library import custom_offloading_utils from library import custom_offloading_utils
@@ -445,10 +446,12 @@ configs = {
# region math # region math
kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
q, k = apply_rope(q, k, pe) q, k = apply_rope(q, k, pe)
with sdpa_kernel(kernels):
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "B H L D -> B L (H D)") x = rearrange(x, "B H L D -> B L (H D)")