diff --git a/library/anima_models.py b/library/anima_models.py
index 037ffd77..00e9c6c6 100644
--- a/library/anima_models.py
+++ b/library/anima_models.py
@@ -739,13 +739,16 @@ class FinalLayer(nn.Module):
         emb_B_T_D: torch.Tensor,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
     ):
-        if self.use_adaln_lora:
-            assert adaln_lora_B_T_3D is not None
-            shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
-                2, dim=-1
-            )
-        else:
-            shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
+        # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
+        use_fp32 = x_B_T_H_W_D.dtype == torch.float16
+        with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
+            if self.use_adaln_lora:
+                assert adaln_lora_B_T_3D is not None
+                shift_B_T_D, scale_B_T_D = (
+                    self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
+                ).chunk(2, dim=-1)
+            else:
+                shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
 
         shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d")
         scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d")
@@ -864,32 +867,34 @@ class Block(nn.Module):
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if x_B_T_H_W_D.dtype == torch.float16:
+        use_fp32 = x_B_T_H_W_D.dtype == torch.float16
+        if use_fp32:
             # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
             x_B_T_H_W_D = x_B_T_H_W_D.float()
 
         if extra_per_block_pos_emb is not None:
             x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
 
-        # Compute AdaLN modulation parameters
-        if self.use_adaln_lora:
-            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
-                self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
-            ).chunk(3, dim=-1)
-            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
-                self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
-            ).chunk(3, dim=-1)
-            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
-                3, dim=-1
-            )
-        else:
-            shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
-                3, dim=-1
-            )
-            shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
-                emb_B_T_D
-            ).chunk(3, dim=-1)
-            shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
+        # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
+        with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
+            if self.use_adaln_lora:
+                shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = (
+                    self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D
+                ).chunk(3, dim=-1)
+                shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
+                    self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
+                ).chunk(3, dim=-1)
+                shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
+                    3, dim=-1
+                )
+            else:
+                shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
+                    emb_B_T_D
+                ).chunk(3, dim=-1)
+                shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
+                    emb_B_T_D
+                ).chunk(3, dim=-1)
+                shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1)
 
         # Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D)
         shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d")
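
Note on the pattern applied in both hunks: the AdaLN Linear projections (and the LoRA addition) are wrapped in torch.autocast(..., dtype=torch.float32, enabled=use_fp32), so that when the model runs in float16 those projections execute in float32 while everything outside the context stays fp16. A minimal standalone sketch of this pattern follows; the module and tensor names are illustrative, not taken from library/anima_models.py, and it assumes a PyTorch 2.x build whose CUDA autocast accepts dtype=torch.float32 (as the patch itself relies on):

import torch
import torch.nn as nn


class AdaLNHeadSketch(nn.Module):
    # Hypothetical stand-in for one adaln_modulation head: a single Linear
    # producing shift/scale from the conditioning embedding.
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 2 * hidden_size)

    def forward(self, emb_B_T_D: torch.Tensor):
        # Force fp32 only when the embedding arrived as fp16; with
        # enabled=False the context is a no-op for bf16/fp32 inputs.
        use_fp32 = emb_B_T_D.dtype == torch.float16
        with torch.autocast(device_type=emb_B_T_D.device.type, dtype=torch.float32, enabled=use_fp32):
            # Under autocast, the Linear's fp16 weights and input are cast to
            # fp32, so the matmul and the chunked outputs are computed in fp32.
            shift_B_T_D, scale_B_T_D = self.proj(emb_B_T_D).chunk(2, dim=-1)
        return shift_B_T_D, scale_B_T_D


if __name__ == "__main__" and torch.cuda.is_available():
    head = AdaLNHeadSketch(64).half().cuda()  # fp16 weights, as in an fp16 run
    emb = torch.randn(2, 8, 64, dtype=torch.float16, device="cuda")
    shift, scale = head(emb)
    print(shift.dtype)  # torch.float32: computed inside the fp32 island

The enabled=use_fp32 flag is what keeps the change cheap for bf16/fp32 runs: the context manager is entered either way, but with enabled=False it changes nothing, so only fp16 execution pays the extra casts to float32.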