Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-16 17:02:45 +00:00
feat: optimize RMSNorm forward method and remove unused torch_attention_op
@@ -233,10 +233,10 @@ class RMSNorm(torch.nn.Module):
     def _norm(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    @torch.amp.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
+            output = self._norm(x.float()).type_as(x)
+            return output * self.weight
 
 
 class GPT2FeedForward(nn.Module):
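For context, a minimal self-contained sketch of the new pattern (the class below is illustrative; the dim/eps arguments are assumptions, not taken from the repository). Deriving device_type from the input keeps the float32 upcast of the removed decorator without hard-coding CUDA:

    import torch

    class RMSNormSketch(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = torch.nn.Parameter(torch.ones(dim))

        def _norm(self, x: torch.Tensor) -> torch.Tensor:
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # device_type follows the input, so the same module works under
            # autocast on any backend, not just "cuda" as the old decorator did.
            with torch.autocast(device_type=x.device.type, dtype=torch.float32):
                output = self._norm(x.float()).type_as(x)
                return output * self.weight

    norm = RMSNormSketch(8).to(torch.bfloat16)
    x = torch.randn(2, 8, dtype=torch.bfloat16)
    print(norm(x).dtype)  # torch.bfloat16

On CPU, PyTorch warns that float32 autocast is unsupported and disables the region; that is harmless here, since the explicit x.float() already forces the normalization into float32.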
@@ -269,20 +269,6 @@ class GPT2FeedForward(nn.Module):
         return x
 
 
-def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
-    """Computes multi-head attention using PyTorch's native scaled_dot_product_attention.
-
-    Input/output format: (batch, seq_len, n_heads, head_dim)
-    """
-    in_q_shape = q_B_S_H_D.shape
-    in_k_shape = k_B_S_H_D.shape
-    q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
-    k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
-    v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
-    result_B_S_HD = rearrange(F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)")
-    return result_B_S_HD
-
-
 # Attention module for DiT
 class Attention(nn.Module):
     """Multi-head attention supporting both self-attention and cross-attention.
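The deleted helper only moved the head axis around F.scaled_dot_product_attention. A hedged sketch of the same computation with plain transposes (the function name and the fixed 4-D shapes are illustrative; the original also handled extra spatial dims via einops.rearrange):

    import torch
    import torch.nn.functional as F

    def sdpa_b_s_h_d(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # In/out layout (batch, seq, heads, head_dim); SDPA expects (batch, heads, seq, head_dim).
        out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2).reshape(q.shape[0], q.shape[1], -1)  # heads merged: (B, S, H*D)

    q = k = v = torch.randn(2, 16, 4, 32)
    print(sdpa_b_s_h_d(q, k, v).shape)  # torch.Size([2, 16, 128])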
@@ -323,8 +309,6 @@ class Attention(nn.Module):
         self.output_proj = nn.Linear(inner_dim, query_dim, bias=False)
         self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
 
-        self.attn_op = torch_attention_op
-
         self._query_dim = query_dim
         self._context_dim = context_dim
         self._inner_dim = inner_dim
@@ -640,29 +624,30 @@ class TimestepEmbedding(nn.Module):
         return emb_B_T_D, adaln_lora_B_T_3D
 
 
-class FourierFeatures(nn.Module):
-    """Fourier feature transform: [B] -> [B, D]."""
+# Commented out Fourier Features (not used in Anima). Kept for reference.
+# class FourierFeatures(nn.Module):
+#     """Fourier feature transform: [B] -> [B, D]."""
 
-    def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
-        super().__init__()
-        self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
-        self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
-        self.gain = np.sqrt(2) if normalize else 1
-        self.bandwidth = bandwidth
-        self.num_channels = num_channels
-        self.reset_parameters()
+#     def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
+#         super().__init__()
+#         self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
+#         self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
+#         self.gain = np.sqrt(2) if normalize else 1
+#         self.bandwidth = bandwidth
+#         self.num_channels = num_channels
+#         self.reset_parameters()
 
-    def reset_parameters(self) -> None:
-        generator = torch.Generator()
-        generator.manual_seed(0)
-        self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
-        self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
+#     def reset_parameters(self) -> None:
+#         generator = torch.Generator()
+#         generator.manual_seed(0)
+#         self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
+#         self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
 
-    def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
-        in_dtype = x.dtype
-        x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
-        x = x.cos().mul(self.gain * gain).to(in_dtype)
-        return x
+#     def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
+#         in_dtype = x.dtype
+#         x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
+#         x = x.cos().mul(self.gain * gain).to(in_dtype)
+#         return x
 
 
 # Patch Embedding
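Although the class is now commented out, its mapping is simple: x.ger(freqs) is the outer product x_i * f_j, so each scalar input becomes D cosine features cos(x * f + phase). A tiny hedged demo of that [B] -> [B, D] transform (standalone, not repository code):

    import numpy as np
    import torch

    B, D = 4, 8
    freqs = 2 * np.pi * torch.randn(D)  # random frequencies, as in the commented-out buffers
    phases = 2 * np.pi * torch.rand(D)  # random phases

    t = torch.rand(B)                                  # e.g. one timestep per batch element
    feats = torch.cos(torch.outer(t, freqs) + phases)  # [B] -> [B, D]
    print(feats.shape)  # torch.Size([4, 8])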
@@ -1352,13 +1337,7 @@ class Anima(nn.Module):
             if self.blocks_to_swap:
                 self.offloader.wait_for_block(block_idx)
 
-            x_B_T_H_W_D = block(
-                x_B_T_H_W_D,
-                t_embedding_B_T_D,
-                crossattn_emb,
-                attn_params,
-                **block_kwargs,
-            )
+            x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
 
             if self.blocks_to_swap:
                 self.offloader.submit_move_blocks(self.blocks, block_idx)