From 892f8be78fc01989ab27c01bfd02173676d43bd3 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:12:57 +0900 Subject: [PATCH] fix: cast input tensor to float32 for improved numerical stability in residual connections --- library/anima_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/anima_models.py b/library/anima_models.py index 6828e598..037ffd77 100644 --- a/library/anima_models.py +++ b/library/anima_models.py @@ -864,6 +864,10 @@ class Block(nn.Module): adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if x_B_T_H_W_D.dtype == torch.float16: + # Cast to float32 for better numerical stability in residual connections. The enclosing autocast context casts back to float16 inside each module. + x_B_T_H_W_D = x_B_T_H_W_D.float() + if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb