feat: optimize RMSNorm forward method and remove unused torch_attention_op

This commit is contained in:
Kohya S
2026-02-11 22:06:47 +09:00
parent 4b2283491e
commit 59267d19f3

View File

@@ -233,10 +233,10 @@ class RMSNorm(torch.nn.Module):
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.amp.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
with torch.autocast(device_type=x.device.type, dtype=torch.float32):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class GPT2FeedForward(nn.Module):
@@ -269,20 +269,6 @@ class GPT2FeedForward(nn.Module):
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native scaled_dot_product_attention.
Input/output format: (batch, seq_len, n_heads, head_dim)
"""
in_q_shape = q_B_S_H_D.shape
in_k_shape = k_B_S_H_D.shape
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)")
return result_B_S_HD
# Attention module for DiT
class Attention(nn.Module):
"""Multi-head attention supporting both self-attention and cross-attention.
@@ -323,8 +309,6 @@ class Attention(nn.Module):
self.output_proj = nn.Linear(inner_dim, query_dim, bias=False)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
self.attn_op = torch_attention_op
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
@@ -640,29 +624,30 @@ class TimestepEmbedding(nn.Module):
return emb_B_T_D, adaln_lora_B_T_3D
class FourierFeatures(nn.Module):
"""Fourier feature transform: [B] -> [B, D]."""
# Commented out Fourier Features (not used in Anima). Kept for reference.
# class FourierFeatures(nn.Module):
# """Fourier feature transform: [B] -> [B, D]."""
def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
super().__init__()
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
self.gain = np.sqrt(2) if normalize else 1
self.bandwidth = bandwidth
self.num_channels = num_channels
self.reset_parameters()
# def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
# super().__init__()
# self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
# self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
# self.gain = np.sqrt(2) if normalize else 1
# self.bandwidth = bandwidth
# self.num_channels = num_channels
# self.reset_parameters()
def reset_parameters(self) -> None:
generator = torch.Generator()
generator.manual_seed(0)
self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
# def reset_parameters(self) -> None:
# generator = torch.Generator()
# generator.manual_seed(0)
# self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
# self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
in_dtype = x.dtype
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
x = x.cos().mul(self.gain * gain).to(in_dtype)
return x
# def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
# in_dtype = x.dtype
# x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
# x = x.cos().mul(self.gain * gain).to(in_dtype)
# return x
# Patch Embedding
@@ -1352,13 +1337,7 @@ class Anima(nn.Module):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
attn_params,
**block_kwargs,
)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)