mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
fix: cast input tensor to float32 for improved numerical stability in residual connections
This commit is contained in:
@@ -864,6 +864,10 @@ class Block(nn.Module):
|
|||||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if x_B_T_H_W_D.dtype == torch.float16:
|
||||||
|
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 via the enclosing autocast context.
|
||||||
|
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||||
|
|
||||||
if extra_per_block_pos_emb is not None:
|
if extra_per_block_pos_emb is not None:
|
||||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user