Merge pull request #2277 from kohya-ss/feat-stability-with-fp16-for-anima

feat: Stability with fp16 for anima
This commit is contained in:
Kohya S.
2026-02-23 21:15:49 +09:00
committed by GitHub

View File

@@ -864,6 +864,10 @@ class Block(nn.Module):
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if x_B_T_H_W_D.dtype == torch.float16:
        # Cast to float32 for better numerical stability in residual connections. Each module casts back to float16 via its enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb