fix: cast input tensor to float32 for improved numerical stability in residual connections

This commit is contained in:
Kohya S
2026-02-23 21:12:57 +09:00
parent 50694df3cf
commit 892f8be78f

View File

@@ -864,6 +864,10 @@ class Block(nn.Module):
     adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
     extra_per_block_pos_emb: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    if x_B_T_H_W_D.dtype == torch.float16:
+        # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
+        x_B_T_H_W_D = x_B_T_H_W_D.float()
     if extra_per_block_pos_emb is not None:
         x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb