diff --git a/library/anima_models.py b/library/anima_models.py index 6828e598..037ffd77 100644 --- a/library/anima_models.py +++ b/library/anima_models.py @@ -864,6 +864,10 @@ class Block(nn.Module): adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if x_B_T_H_W_D.dtype == torch.float16: + # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context. + x_B_T_H_W_D = x_B_T_H_W_D.float() + if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb