# Anima Model Architecture # Original code: NVIDIA CORPORATION & AFFILIATES, licensed under Apache-2.0 import math from typing import Any, Optional, Tuple, Union import numpy as np import torch from einops import rearrange, repeat from einops.layers.torch import Rearrange from torch import nn import torch.nn.functional as F from torch.utils.checkpoint import checkpoint as torch_checkpoint from library import custom_offloading_utils, attention def to_device(x, device): if isinstance(x, torch.Tensor): return x.to(device) elif isinstance(x, (list, tuple)): return type(x)(to_device(elem, device) for elem in x) elif isinstance(x, dict): return {k: to_device(v, device) for k, v in x.items()} else: return x def to_cpu(x): if isinstance(x, torch.Tensor): return x.cpu() elif isinstance(x, (list, tuple)): return [to_cpu(elem) for elem in x] elif isinstance(x, dict): return {k: to_cpu(v) for k, v in x.items()} else: return x # Unsloth Offloaded Gradient Checkpointing # Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team try: from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable except ImportError: def detach_variable(inputs, device=None): """Detach tensors from computation graph, optionally moving to a device. Reimplementation of deepspeed.runtime.activation_checkpointing.checkpointing.detach_variable for environments without DeepSpeed installed. """ if isinstance(inputs, tuple): out = [] for inp in inputs: if not isinstance(inp, torch.Tensor): out.append(inp) continue requires_grad = inp.requires_grad if device is not None: x = inp.to(device=device) else: x = inp x = x.detach() x.requires_grad = requires_grad out.append(x) return tuple(out) else: raise RuntimeError( "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__, ) class UnslothOffloadedGradientCheckpointer(torch.autograd.Function): """Saves VRAM by offloading activations to CPU RAM using non-blocking transfers. Compared to standard cpu_offload_checkpointing which uses blocking transfers, this uses non_blocking=True to hide CPU<->GPU transfer latency behind compute. """ @staticmethod @torch.amp.custom_fwd(device_type="cuda") def forward(ctx, forward_function, hidden_states, *args): # Remember the original device for backward pass (multi-GPU support) ctx.input_device = hidden_states.device saved_hidden_states = hidden_states.to("cpu", non_blocking=True) with torch.no_grad(): output = forward_function(hidden_states, *args) ctx.save_for_backward(saved_hidden_states) ctx.forward_function = forward_function # NOTE: args stored directly on ctx (not via save_for_backward) because # the training loop holds references to these tensors, preventing GC. # Using save_for_backward for all args would add complexity for no benefit. 
ctx.args = args return output @staticmethod @torch.amp.custom_bwd(device_type="cuda") def backward(ctx, *grads): (hidden_states,) = ctx.saved_tensors hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach() hidden_states.requires_grad_(True) args = detach_variable(ctx.args) inputs = (hidden_states,) + args with torch.enable_grad(): outputs = ctx.forward_function(*inputs) output_tensors = [] grad_tensors = [] for out, grad in zip( outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,) ): if isinstance(out, torch.Tensor) and out.requires_grad: output_tensors.append(out) grad_tensors.append(grad) torch.autograd.backward(output_tensors, grad_tensors) return (None,) + tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs) @torch._disable_dynamo def unsloth_checkpoint(function, *args): """Wrapper for UnslothOffloadedGradientCheckpointer.""" return UnslothOffloadedGradientCheckpointer.apply(function, *args) from .utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) # Utility functions: RoPE for DiT def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: if not interleaved: x1, x2 = torch.chunk(x, 2, dim=-1) return torch.cat((-x2, x1), dim=-1) x1 = x[:, :, :, ::2] x2 = x[:, :, :, 1::2] x_new = torch.stack((-x2, x1), dim=-1) return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1) def _apply_rotary_pos_emb_base( t: torch.Tensor, freqs: torch.Tensor, start_positions: torch.Tensor = None, tensor_format: str = "sbhd", interleaved: bool = False, ) -> torch.Tensor: max_seq_len = freqs.shape[0] cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] if start_positions is not None: max_offset = torch.max(start_positions) assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" freqs = freqs[:cur_seq_len] if tensor_format == "bshd": freqs = freqs.transpose(0, 1) cos_ = torch.cos(freqs).to(t.dtype) sin_ = torch.sin(freqs).to(t.dtype) rot_dim = freqs.shape[-1] t, t_pass = t[..., :rot_dim], t[..., rot_dim:] t = (t * cos_) + (_rotate_half(t, interleaved) * sin_) return torch.cat((t, t_pass), dim=-1) def apply_rotary_pos_emb( t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd", start_positions: Union[torch.Tensor, None] = None, interleaved: bool = False, fused: bool = False, cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, ) -> torch.Tensor: assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!" assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'." 
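    # Layout conventions (TransformerEngine-style): "sbhd" = (seq, batch, head, dim),
    # "bshd" = (batch, seq, head, dim), and "thd" = packed variable-length sequences
    # described by cu_seqlens. Only the unfused (pure PyTorch) path is implemented here,
    # which the assert on `fused` below enforces.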
assert fused == False if tensor_format == "thd": cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return torch.cat( [ _apply_rotary_pos_emb_base( x.unsqueeze(1), freqs, start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None), interleaved=interleaved, ) for idx, x in enumerate(torch.split(t, seqlens)) ] ).squeeze(1) if tensor_format == "sbhd": seqlen = t.size(0) elif tensor_format == "bshd": seqlen = t.size(1) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") return _apply_rotary_pos_emb_base( t, freqs, start_positions, tensor_format, interleaved=interleaved, ) # Basic building blocks class RMSNorm(torch.nn.Module): """RMS Normalization for DiT blocks.""" def __init__(self, dim: int, eps: float = 1e-5) -> None: super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def reset_parameters(self) -> None: torch.nn.init.ones_(self.weight) def _norm(self, x: torch.Tensor) -> torch.Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.autocast(device_type=x.device.type, dtype=torch.float32): output = self._norm(x.float()).type_as(x) return output * self.weight class GPT2FeedForward(nn.Module): """GELU feedforward network.""" def __init__(self, d_model: int, d_ff: int) -> None: super().__init__() self.activation = nn.GELU() self.layer1 = nn.Linear(d_model, d_ff, bias=False) self.layer2 = nn.Linear(d_ff, d_model, bias=False) self._layer_id = None self._dim = d_model self._hidden_dim = d_ff self.init_weights() def init_weights(self) -> None: std = 1.0 / math.sqrt(self._dim) torch.nn.init.trunc_normal_(self.layer1.weight, std=std, a=-3 * std, b=3 * std) std = 1.0 / math.sqrt(self._hidden_dim) if self._layer_id is not None: std = std / math.sqrt(2 * (self._layer_id + 1)) torch.nn.init.trunc_normal_(self.layer2.weight, std=std, a=-3 * std, b=3 * std) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.layer1(x) x = self.activation(x) x = self.layer2(x) return x # Attention module for DiT class Attention(nn.Module): """Multi-head attention supporting both self-attention and cross-attention. Uses QK-norm (RMSNorm on q/k) and optional RoPE (only for self-attention). 
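    Expected shapes with the default qkv_format="bshd" (a sketch of the call contract, not
    enforced here): x is (B, S, query_dim); context is (B, S_ctx, context_dim), or None for
    self-attention; the output is (B, S, query_dim).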
""" def __init__( self, query_dim: int, context_dim: Optional[int] = None, n_heads: int = 8, head_dim: int = 64, dropout: float = 0.0, qkv_format: str = "bshd", ) -> None: super().__init__() self.is_selfattn = context_dim is None context_dim = query_dim if context_dim is None else context_dim inner_dim = head_dim * n_heads self.n_heads = n_heads self.head_dim = head_dim self.qkv_format = qkv_format self.query_dim = query_dim self.context_dim = context_dim self.q_proj = nn.Linear(query_dim, inner_dim, bias=False) self.q_norm = RMSNorm(self.head_dim, eps=1e-6) self.k_proj = nn.Linear(context_dim, inner_dim, bias=False) self.k_norm = RMSNorm(self.head_dim, eps=1e-6) self.v_proj = nn.Linear(context_dim, inner_dim, bias=False) self.v_norm = nn.Identity() self.output_proj = nn.Linear(inner_dim, query_dim, bias=False) self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity() self._query_dim = query_dim self._context_dim = context_dim self._inner_dim = inner_dim self.init_weights() def init_weights(self) -> None: std = 1.0 / math.sqrt(self._query_dim) torch.nn.init.trunc_normal_(self.q_proj.weight, std=std, a=-3 * std, b=3 * std) std = 1.0 / math.sqrt(self._context_dim) torch.nn.init.trunc_normal_(self.k_proj.weight, std=std, a=-3 * std, b=3 * std) torch.nn.init.trunc_normal_(self.v_proj.weight, std=std, a=-3 * std, b=3 * std) std = 1.0 / math.sqrt(self._inner_dim) torch.nn.init.trunc_normal_(self.output_proj.weight, std=std, a=-3 * std, b=3 * std) for layer in self.q_norm, self.k_norm, self.v_norm: if hasattr(layer, "reset_parameters"): layer.reset_parameters() def compute_qkv( self, x: torch.Tensor, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, ) -> tuple: q = self.q_proj(x) context = x if context is None else context k = self.k_proj(context) v = self.v_proj(context) q, k, v = map( lambda t: rearrange(t, "b ... (h d) -> b ... h d", h=self.n_heads, d=self.head_dim), (q, k, v), ) q = self.q_norm(q) k = self.k_norm(k) v = self.v_norm(v) if self.is_selfattn and rope_emb is not None: q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=False) k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=False) return q, k, v def forward( self, x: torch.Tensor, attn_params: attention.AttentionParams, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) if q.dtype != v.dtype: if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled(): # FlashAttention requires fp16/bf16, xformers require same dtype; only cast when autocast is active. 
target_dtype = v.dtype # v has fp16/bf16 dtype q = q.to(target_dtype) k = k.to(target_dtype) # return self.compute_attention(q, k, v) qkv = [q, k, v] del q, k, v result = attention.attention(qkv, attn_params=attn_params) return self.output_dropout(self.output_proj(result)) # Positional Embeddings class VideoPositionEmb(nn.Module): def __init__(self) -> None: super().__init__() @property def seq_dim(self) -> int: return 1 def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor]) -> torch.Tensor: B_T_H_W_C = x_B_T_H_W_C.shape embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps) return embeddings def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor]) -> Any: raise NotImplementedError class VideoRopePosition3DEmb(VideoPositionEmb): """3D Rotary Position Embedding for video (T, H, W) dimensions.""" def __init__( self, *, head_dim: int, len_h: int, len_w: int, len_t: int, base_fps: int = 24, h_extrapolation_ratio: float = 1.0, w_extrapolation_ratio: float = 1.0, t_extrapolation_ratio: float = 1.0, enable_fps_modulation: bool = True, **kwargs, ): del kwargs super().__init__() self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float)) self.base_fps = base_fps self.max_h = len_h self.max_w = len_w self.max_t = len_t self.enable_fps_modulation = enable_fps_modulation dim = head_dim dim_h = dim // 6 * 2 dim_w = dim_h dim_t = dim - 2 * dim_h assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" self.register_buffer( "dim_spatial_range", torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h, persistent=True, ) self.register_buffer( "dim_temporal_range", torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t, persistent=True, ) self._dim_h = dim_h self._dim_t = dim_t self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) self.reset_parameters() def reset_parameters(self) -> None: dim_h = self._dim_h dim_t = self._dim_t self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device) self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t def generate_embeddings( self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, h_ntk_factor: Optional[float] = None, w_ntk_factor: Optional[float] = None, t_ntk_factor: Optional[float] = None, ) -> torch.Tensor: h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor h_theta = 10000.0 * h_ntk_factor w_theta = 10000.0 * w_ntk_factor t_theta = 10000.0 * t_ntk_factor h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) B, T, H, W, _ = B_T_H_W_C assert ( H <= self.max_h and W <= self.max_w ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs) half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs) if self.enable_fps_modulation: uniform_fps = (fps is None) or (fps.min() == fps.max()) 
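            # FPS modulation rescales temporal positions by base_fps / fps so clips recorded at
            # different frame rates share a common time scale; the assert below restricts this to
            # a uniform fps across the batch (or B == 1, or an image batch with T == 1).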
assert ( uniform_fps or B == 1 or T == 1 ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" if fps is None: assert T == 1, "T should be 1 for image batch." half_emb_t = torch.outer(self.seq[:T], temporal_freqs) else: half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs) else: half_emb_t = torch.outer(self.seq[:T], temporal_freqs) em_T_H_W_D = torch.cat( [ repeat(half_emb_t, "t d -> t h w d", h=H, w=W), repeat(half_emb_h, "h d -> t h w d", t=T, w=W), repeat(half_emb_w, "w d -> t h w d", t=T, h=H), ] * 2, dim=-1, ) return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float() @property def seq_dim(self) -> int: return 0 class LearnablePosEmbAxis(VideoPositionEmb): """Learnable axis-decomposed positional embeddings.""" def __init__( self, *, interpolation: str, model_channels: int, len_h: int, len_w: int, len_t: int, **kwargs, ): del kwargs super().__init__() self.interpolation = interpolation assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" self.model_channels = model_channels self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels)) self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels)) self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels)) self.reset_parameters() def reset_parameters(self) -> None: std = 1.0 / math.sqrt(self.model_channels) torch.nn.init.trunc_normal_(self.pos_emb_h, std=std, a=-3 * std, b=3 * std) torch.nn.init.trunc_normal_(self.pos_emb_w, std=std, a=-3 * std, b=3 * std) torch.nn.init.trunc_normal_(self.pos_emb_t, std=std, a=-3 * std, b=3 * std) def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor]) -> torch.Tensor: B, T, H, W, _ = B_T_H_W_C if self.interpolation == "crop": emb_h_H = self.pos_emb_h[:H] emb_w_W = self.pos_emb_w[:W] emb_t_T = self.pos_emb_t[:T] emb = ( repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) ) assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" else: raise ValueError(f"Unknown interpolation method {self.interpolation}") norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32) norm = torch.add(1e-6, norm, alpha=np.sqrt(norm.numel() / emb.numel())) return emb / norm.to(emb.dtype) # Timestep Embedding class Timesteps(nn.Module): """Sinusoidal timestep features.""" def __init__(self, num_channels: int): super().__init__() self.num_channels = num_channels def forward(self, timesteps_B_T: torch.Tensor) -> torch.Tensor: assert timesteps_B_T.ndim == 2, f"Expected 2D input, got {timesteps_B_T.ndim}" in_dtype = timesteps_B_T.dtype timesteps = timesteps_B_T.flatten().float() half_dim = self.num_channels // 2 exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - 0.0) emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] sin_emb = torch.sin(emb) cos_emb = torch.cos(emb) emb = torch.cat([cos_emb, sin_emb], dim=-1) return rearrange(emb.to(dtype=in_dtype), "(b t) d -> b t d", b=timesteps_B_T.shape[0], t=timesteps_B_T.shape[1]) class TimestepEmbedding(nn.Module): """Projects timestep features to model dimension, with optional AdaLN-LoRA.""" def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False): super().__init__() self.in_dim = in_features self.out_dim = out_features self.linear_1 = 
nn.Linear(in_features, out_features, bias=not use_adaln_lora) self.activation = nn.SiLU() self.use_adaln_lora = use_adaln_lora if use_adaln_lora: self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False) else: self.linear_2 = nn.Linear(out_features, out_features, bias=False) self.init_weights() def init_weights(self) -> None: std = 1.0 / math.sqrt(self.in_dim) torch.nn.init.trunc_normal_(self.linear_1.weight, std=std, a=-3 * std, b=3 * std) std = 1.0 / math.sqrt(self.out_dim) torch.nn.init.trunc_normal_(self.linear_2.weight, std=std, a=-3 * std, b=3 * std) def forward(self, sample: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: emb = self.linear_1(sample) emb = self.activation(emb) emb = self.linear_2(emb) if self.use_adaln_lora: adaln_lora_B_T_3D = emb emb_B_T_D = sample else: adaln_lora_B_T_3D = None emb_B_T_D = emb return emb_B_T_D, adaln_lora_B_T_3D # Commented out Fourier Features (not used in Anima). Kept for reference. # class FourierFeatures(nn.Module): # """Fourier feature transform: [B] -> [B, D].""" # def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False): # super().__init__() # self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True) # self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True) # self.gain = np.sqrt(2) if normalize else 1 # self.bandwidth = bandwidth # self.num_channels = num_channels # self.reset_parameters() # def reset_parameters(self) -> None: # generator = torch.Generator() # generator.manual_seed(0) # self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device) # self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device) # def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor: # in_dtype = x.dtype # x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32)) # x = x.cos().mul(self.gain * gain).to(in_dtype) # return x # Patch Embedding class PatchEmbed(nn.Module): """Patch embedding: (B, C, T, H, W) -> (B, T', H', W', D)""" def __init__( self, spatial_patch_size: int, temporal_patch_size: int, in_channels: int = 3, out_channels: int = 768, ): super().__init__() self.spatial_patch_size = spatial_patch_size self.temporal_patch_size = temporal_patch_size self.proj = nn.Sequential( Rearrange( "b c (t r) (h m) (w n) -> b t h w (c r m n)", r=temporal_patch_size, m=spatial_patch_size, n=spatial_patch_size, ), nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False), ) self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size self.init_weights() def init_weights(self) -> None: std = 1.0 / math.sqrt(self.dim) torch.nn.init.trunc_normal_(self.proj[1].weight, std=std, a=-3 * std, b=3 * std) def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.dim() == 5 _, _, T, H, W = x.shape assert ( H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 ), f"H,W {(H, W)} should be divisible by spatial_patch_size {self.spatial_patch_size}" assert T % self.temporal_patch_size == 0 x = self.proj(x) return x # Final Layer class FinalLayer(nn.Module): """Final layer with AdaLN modulation + unpatchify.""" def __init__( self, hidden_size: int, spatial_patch_size: int, temporal_patch_size: int, out_channels: int, use_adaln_lora: bool = False, adaln_lora_dim: int = 256, ): super().__init__() self.layer_norm = 
nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear( hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False ) self.hidden_size = hidden_size self.n_adaln_chunks = 2 self.use_adaln_lora = use_adaln_lora self.adaln_lora_dim = adaln_lora_dim if use_adaln_lora: self.adaln_modulation = nn.Sequential( nn.SiLU(), nn.Linear(hidden_size, adaln_lora_dim, bias=False), nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), ) else: self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)) self.init_weights() def init_weights(self) -> None: std = 1.0 / math.sqrt(self.hidden_size) torch.nn.init.trunc_normal_(self.linear.weight, std=std, a=-3 * std, b=3 * std) if self.use_adaln_lora: torch.nn.init.trunc_normal_(self.adaln_modulation[1].weight, std=std, a=-3 * std, b=3 * std) torch.nn.init.zeros_(self.adaln_modulation[2].weight) else: torch.nn.init.zeros_(self.adaln_modulation[1].weight) self.layer_norm.reset_parameters() def forward( self, x_B_T_H_W_D: torch.Tensor, emb_B_T_D: torch.Tensor, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, use_fp32: bool = False, ): # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers) with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32): if self.use_adaln_lora: assert adaln_lora_B_T_3D is not None shift_B_T_D, scale_B_T_D = ( self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] ).chunk(2, dim=-1) else: shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d") scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d") x_B_T_H_W_D = self.layer_norm(x_B_T_H_W_D) * (1 + scale_B_T_1_1_D) + shift_B_T_1_1_D x_B_T_H_W_O = self.linear(x_B_T_H_W_D) return x_B_T_H_W_O # Transformer Block (DiT Block) class Block(nn.Module): """Transformer block with self-attention + cross-attention + MLP, each modulated by AdaLN. 
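    With use_adaln_lora, each sublayer's modulation MLP is factored through a low-rank
    bottleneck (adaln_lora_dim) and its output is summed with the shared adaln_lora_B_T_3D
    term produced by the timestep embedder.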
Each sublayer: x = x + gate * sublayer(norm(x) * (1 + scale) + shift) """ def __init__( self, x_dim: int, context_dim: int, num_heads: int, mlp_ratio: float = 4.0, use_adaln_lora: bool = False, adaln_lora_dim: int = 256, ): super().__init__() self.x_dim = x_dim self.layer_norm_self_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) self.self_attn = Attention( x_dim, None, num_heads, x_dim // num_heads, qkv_format="bshd", ) self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) self.cross_attn = Attention( x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd", ) self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) self.mlp = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio)) self.use_adaln_lora = use_adaln_lora if self.use_adaln_lora: self.adaln_modulation_self_attn = nn.Sequential( nn.SiLU(), nn.Linear(x_dim, adaln_lora_dim, bias=False), nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False), ) self.adaln_modulation_cross_attn = nn.Sequential( nn.SiLU(), nn.Linear(x_dim, adaln_lora_dim, bias=False), nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False), ) self.adaln_modulation_mlp = nn.Sequential( nn.SiLU(), nn.Linear(x_dim, adaln_lora_dim, bias=False), nn.Linear(adaln_lora_dim, 3 * x_dim, bias=False), ) else: self.adaln_modulation_self_attn = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False)) self.adaln_modulation_cross_attn = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False)) self.adaln_modulation_mlp = nn.Sequential(nn.SiLU(), nn.Linear(x_dim, 3 * x_dim, bias=False)) self.gradient_checkpointing = False self.cpu_offload_checkpointing = False self.unsloth_offload_checkpointing = False def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False): self.gradient_checkpointing = True self.cpu_offload_checkpointing = cpu_offload if not unsloth_offload else False self.unsloth_offload_checkpointing = unsloth_offload def disable_gradient_checkpointing(self): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False self.unsloth_offload_checkpointing = False def reset_parameters(self) -> None: self.layer_norm_self_attn.reset_parameters() self.layer_norm_cross_attn.reset_parameters() self.layer_norm_mlp.reset_parameters() if self.use_adaln_lora: std = 1.0 / math.sqrt(self.x_dim) torch.nn.init.trunc_normal_(self.adaln_modulation_self_attn[1].weight, std=std, a=-3 * std, b=3 * std) torch.nn.init.trunc_normal_(self.adaln_modulation_cross_attn[1].weight, std=std, a=-3 * std, b=3 * std) torch.nn.init.trunc_normal_(self.adaln_modulation_mlp[1].weight, std=std, a=-3 * std, b=3 * std) torch.nn.init.zeros_(self.adaln_modulation_self_attn[2].weight) torch.nn.init.zeros_(self.adaln_modulation_cross_attn[2].weight) torch.nn.init.zeros_(self.adaln_modulation_mlp[2].weight) else: torch.nn.init.zeros_(self.adaln_modulation_self_attn[1].weight) torch.nn.init.zeros_(self.adaln_modulation_cross_attn[1].weight) torch.nn.init.zeros_(self.adaln_modulation_mlp[1].weight) def init_weights(self) -> None: self.reset_parameters() self.self_attn.init_weights() self.cross_attn.init_weights() self.mlp.init_weights() def _forward( self, x_B_T_H_W_D: torch.Tensor, emb_B_T_D: torch.Tensor, crossattn_emb: torch.Tensor, attn_params: attention.AttentionParams, use_fp32: bool = False, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: if use_fp32: # Cast to float32 
for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context. x_B_T_H_W_D = x_B_T_H_W_D.float() if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers) with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32): if self.use_adaln_lora: shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D ).chunk(3, dim=-1) shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D ).chunk(3, dim=-1) shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk( 3, dim=-1 ) else: shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( emb_B_T_D ).chunk(3, dim=-1) shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( emb_B_T_D ).chunk(3, dim=-1) shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) # Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D) shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") scale_self_attn_B_T_1_1_D = rearrange(scale_self_attn_B_T_D, "b t d -> b t 1 1 d") gate_self_attn_B_T_1_1_D = rearrange(gate_self_attn_B_T_D, "b t d -> b t 1 1 d") shift_cross_attn_B_T_1_1_D = rearrange(shift_cross_attn_B_T_D, "b t d -> b t 1 1 d") scale_cross_attn_B_T_1_1_D = rearrange(scale_cross_attn_B_T_D, "b t d -> b t 1 1 d") gate_cross_attn_B_T_1_1_D = rearrange(gate_cross_attn_B_T_D, "b t d -> b t 1 1 d") shift_mlp_B_T_1_1_D = rearrange(shift_mlp_B_T_D, "b t d -> b t 1 1 d") scale_mlp_B_T_1_1_D = rearrange(scale_mlp_B_T_D, "b t d -> b t 1 1 d") gate_mlp_B_T_1_1_D = rearrange(gate_mlp_B_T_D, "b t d -> b t 1 1 d") B, T, H, W, D = x_B_T_H_W_D.shape def _adaln_fn(_x, _norm_layer, _scale, _shift): return _norm_layer(_x) * (1 + _scale) + _shift # 1. Self-attention normalized_x = _adaln_fn(x_B_T_H_W_D, self.layer_norm_self_attn, scale_self_attn_B_T_1_1_D, shift_self_attn_B_T_1_1_D) result = rearrange( self.self_attn( rearrange(normalized_x, "b t h w d -> b (t h w) d"), attn_params, None, rope_emb=rope_emb_L_1_1_D, ), "b (t h w) d -> b t h w d", t=T, h=H, w=W, ) x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result # 2. Cross-attention normalized_x = _adaln_fn(x_B_T_H_W_D, self.layer_norm_cross_attn, scale_cross_attn_B_T_1_1_D, shift_cross_attn_B_T_1_1_D) result = rearrange( self.cross_attn( rearrange(normalized_x, "b t h w d -> b (t h w) d"), attn_params, crossattn_emb, rope_emb=rope_emb_L_1_1_D, ), "b (t h w) d -> b t h w d", t=T, h=H, w=W, ) x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D # 3. 
MLP normalized_x = _adaln_fn(x_B_T_H_W_D, self.layer_norm_mlp, scale_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D) result = self.mlp(normalized_x) x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result return x_B_T_H_W_D def forward( self, x_B_T_H_W_D: torch.Tensor, emb_B_T_D: torch.Tensor, crossattn_emb: torch.Tensor, attn_params: attention.AttentionParams, use_fp32: bool = False, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.training and self.gradient_checkpointing: if self.unsloth_offload_checkpointing: # Unsloth: async non-blocking CPU RAM offload (fastest offload method) return unsloth_checkpoint( self._forward, x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, ) elif self.cpu_offload_checkpointing: # Standard cpu offload: blocking transfers def create_custom_forward(func): def custom_forward(*inputs): # Determine original device from first tensor input device = next(t.device for t in inputs if isinstance(t, torch.Tensor)) device_inputs = to_device(inputs, device) outputs = func(*device_inputs) return to_cpu(outputs) return custom_forward return torch_checkpoint( create_custom_forward(self._forward), x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, use_reentrant=False, ) else: # Standard gradient checkpointing (no offload) return torch_checkpoint( self._forward, x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, use_reentrant=False, ) else: return self._forward( x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, use_fp32, rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, ) # Main DiT Model: MiniTrainDIT (renamed to Anima) class Anima(nn.Module): """Cosmos-Predict2 DiT model for image/video generation. 28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter. 
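    The block count follows num_blocks (28 corresponds to the 2048-channel configuration in
    the commented-out get_dit_config at the end of this file).

    Forward contract (a sketch of the expected shapes):
        x:            (B, in_channels, T, H, W) noisy latents
        timesteps:    (B,) or (B, T)
        context:      (B, N, crossattn_emb_channels) text embeddings
        padding_mask: (B, 1, H, W); required when concat_padding_mask=True (the default)
        returns:      (B, out_channels, T, H, W) after unpatchify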
""" LATENT_CHANNELS = 16 def __init__( self, max_img_h: int, max_img_w: int, max_frames: int, in_channels: int, out_channels: int, patch_spatial: int, patch_temporal: int, concat_padding_mask: bool = True, model_channels: int = 768, num_blocks: int = 10, num_heads: int = 16, mlp_ratio: float = 4.0, crossattn_emb_channels: int = 1024, pos_emb_cls: str = "sincos", pos_emb_learnable: bool = False, pos_emb_interpolation: str = "crop", min_fps: int = 1, max_fps: int = 30, use_adaln_lora: bool = False, adaln_lora_dim: int = 256, rope_h_extrapolation_ratio: float = 1.0, rope_w_extrapolation_ratio: float = 1.0, rope_t_extrapolation_ratio: float = 1.0, extra_per_block_abs_pos_emb: bool = False, extra_h_extrapolation_ratio: float = 1.0, extra_w_extrapolation_ratio: float = 1.0, extra_t_extrapolation_ratio: float = 1.0, rope_enable_fps_modulation: bool = True, use_llm_adapter: bool = False, attn_mode: str = "torch", split_attn: bool = False, ) -> None: super().__init__() self.max_img_h = max_img_h self.max_img_w = max_img_w self.max_frames = max_frames self.in_channels = in_channels self.out_channels = out_channels self.patch_spatial = patch_spatial self.patch_temporal = patch_temporal self.num_heads = num_heads self.num_blocks = num_blocks self.model_channels = model_channels self.concat_padding_mask = concat_padding_mask self.pos_emb_cls = pos_emb_cls self.pos_emb_learnable = pos_emb_learnable self.pos_emb_interpolation = pos_emb_interpolation self.min_fps = min_fps self.max_fps = max_fps self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio self.rope_enable_fps_modulation = rope_enable_fps_modulation self.use_llm_adapter = use_llm_adapter self.attn_mode = attn_mode self.split_attn = split_attn # Block swap support self.blocks_to_swap = None self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None self.build_patch_embed() self.build_pos_embed() self.use_adaln_lora = use_adaln_lora self.adaln_lora_dim = adaln_lora_dim self.t_embedder = nn.Sequential( Timesteps(model_channels), TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora), ) if self.use_llm_adapter: self.llm_adapter = LLMAdapter( source_dim=1024, target_dim=1024, model_dim=1024, num_layers=6, self_attn=True, ) self.blocks = nn.ModuleList( [ Block( x_dim=model_channels, context_dim=crossattn_emb_channels, num_heads=num_heads, mlp_ratio=mlp_ratio, use_adaln_lora=use_adaln_lora, adaln_lora_dim=adaln_lora_dim, ) for _ in range(num_blocks) ] ) self.final_layer = FinalLayer( hidden_size=self.model_channels, spatial_patch_size=self.patch_spatial, temporal_patch_size=self.patch_temporal, out_channels=self.out_channels, use_adaln_lora=self.use_adaln_lora, adaln_lora_dim=self.adaln_lora_dim, ) self.t_embedding_norm = RMSNorm(model_channels, eps=1e-6) self.init_weights() def init_weights(self) -> None: self.x_embedder.init_weights() self.pos_embedder.reset_parameters() if self.extra_per_block_abs_pos_emb: self.extra_pos_embedder.reset_parameters() self.t_embedder[1].init_weights() for block in self.blocks: block.init_weights() self.final_layer.init_weights() self.t_embedding_norm.reset_parameters() def enable_gradient_checkpointing(self, 
cpu_offload: bool = False, unsloth_offload: bool = False): for block in self.blocks: block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload) def disable_gradient_checkpointing(self): for block in self.blocks: block.disable_gradient_checkpointing() @property def device(self): return next(self.parameters()).device @property def dtype(self): return next(self.parameters()).dtype def build_patch_embed(self) -> None: in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels self.x_embedder = PatchEmbed( spatial_patch_size=self.patch_spatial, temporal_patch_size=self.patch_temporal, in_channels=in_channels, out_channels=self.model_channels, ) def build_pos_embed(self) -> None: if self.pos_emb_cls == "rope3d": cls_type = VideoRopePosition3DEmb else: raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") kwargs = dict( model_channels=self.model_channels, len_h=self.max_img_h // self.patch_spatial, len_w=self.max_img_w // self.patch_spatial, len_t=self.max_frames // self.patch_temporal, max_fps=self.max_fps, min_fps=self.min_fps, is_learnable=self.pos_emb_learnable, interpolation=self.pos_emb_interpolation, head_dim=self.model_channels // self.num_heads, h_extrapolation_ratio=self.rope_h_extrapolation_ratio, w_extrapolation_ratio=self.rope_w_extrapolation_ratio, t_extrapolation_ratio=self.rope_t_extrapolation_ratio, enable_fps_modulation=self.rope_enable_fps_modulation, ) self.pos_embedder = cls_type(**kwargs) if self.extra_per_block_abs_pos_emb: kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio self.extra_pos_embedder = LearnablePosEmbAxis(**kwargs) def prepare_embedded_sequence( self, x_B_C_T_H_W: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: from torchvision import transforms if self.concat_padding_mask: padding_mask = transforms.functional.resize( padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST ) x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1) x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) if self.extra_per_block_abs_pos_emb: extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps) else: extra_pos_emb = None if "rope" in self.pos_emb_cls.lower(): return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) return x_B_T_H_W_D, None, extra_pos_emb def unpatchify(self, x_B_T_H_W_M: torch.Tensor) -> torch.Tensor: x_B_C_Tt_Hp_Wp = rearrange( x_B_T_H_W_M, "B T H W (p1 p2 t C) -> B C (T t) (H p1) (W p2)", p1=self.patch_spatial, p2=self.patch_spatial, t=self.patch_temporal, ) return x_B_C_Tt_Hp_Wp def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks assert ( self.blocks_to_swap <= self.num_blocks - 2 ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device) logger.info(f"Anima: Block swap enabled. 
Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") def move_to_device_except_swap_blocks(self, device: torch.device): # Move all modules to device except blocks (which are managed by offloader) if self.blocks_to_swap: save_blocks = self.blocks self.blocks = None # Use None to skip .to() on blocks (consistent with flux_models.py) self.to(device) if self.blocks_to_swap: self.blocks = save_blocks def switch_block_swap_for_inference(self): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return self.offloader.set_forward_only(True) self.prepare_block_swap_before_forward() print(f"Anima: Block swap set to forward only.") def switch_block_swap_for_training(self): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return self.offloader.set_forward_only(False) self.prepare_block_swap_before_forward() print(f"Anima: Block swap set to forward and backward.") def prepare_block_swap_before_forward(self): if self.blocks_to_swap is None or self.blocks_to_swap == 0: return self.offloader.prepare_block_devices_before_forward(self.blocks) def forward_mini_train_dit( self, x_B_C_T_H_W: torch.Tensor, timesteps_B_T: torch.Tensor, crossattn_emb: torch.Tensor, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, source_attention_mask: Optional[torch.Tensor] = None, t5_input_ids: Optional[torch.Tensor] = None, t5_attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: x_B_C_T_H_W: (B, C, T, H, W) noisy latents timesteps_B_T: (B,) or (B, T) timesteps crossattn_emb: (B, N, D) cross-attention embeddings (or raw Qwen3 prompt_embeds if t5_input_ids provided) fps: Optional frames per second padding_mask: Optional padding mask source_attention_mask: Optional attention mask for Qwen3 embeddings (used with LLM adapter) t5_input_ids: Optional T5 token IDs (triggers LLM adapter when provided) t5_attn_mask: Optional T5 attention mask """ # Run LLM adapter inside forward for correct DDP gradient synchronization if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"): crossattn_emb = self.llm_adapter( source_hidden_states=crossattn_emb, target_input_ids=t5_input_ids, target_attention_mask=t5_attn_mask, source_attention_mask=source_attention_mask, ) if t5_attn_mask is not None: crossattn_emb[~t5_attn_mask.bool()] = 0 x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb = self.prepare_embedded_sequence( x_B_C_T_H_W, fps=fps, padding_mask=padding_mask, ) if timesteps_B_T.ndim == 1: timesteps_B_T = timesteps_B_T.unsqueeze(1) t_embedding_B_T_D, adaln_lora_B_T_3D = self.t_embedder(timesteps_B_T) t_embedding_B_T_D = self.t_embedding_norm(t_embedding_B_T_D) block_kwargs = { "rope_emb_L_1_1_D": rope_emb_L_1_1_D, "adaln_lora_B_T_3D": adaln_lora_B_T_3D, "extra_per_block_pos_emb": extra_pos_emb, } attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn) # Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16) use_fp32 = x_B_T_H_W_D.dtype == torch.float16 for block_idx, block in enumerate(self.blocks): if self.blocks_to_swap: self.offloader.wait_for_block(block_idx) x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs) if self.blocks_to_swap: self.offloader.submit_move_blocks(self.blocks, block_idx) x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32) x_B_C_Tt_Hp_Wp = 
self.unpatchify(x_B_T_H_W_O) return x_B_C_Tt_Hp_Wp def forward( self, x: torch.Tensor, timesteps: torch.Tensor, context: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, padding_mask: Optional[torch.Tensor] = None, target_input_ids: Optional[torch.Tensor] = None, target_attention_mask: Optional[torch.Tensor] = None, source_attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask) return self.forward_mini_train_dit(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs) def _preprocess_text_embeds( self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None ): if target_input_ids is not None: context = self.llm_adapter( source_hidden_states, target_input_ids, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, ) context[~target_attention_mask.bool()] = 0 # zero out padding tokens return context else: return source_hidden_states # LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space class LLMAdapterRMSNorm(nn.Module): """RMSNorm specifically for the LLM Adapter (T5-style, no mean subtraction).""" def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) return self.weight * hidden_states def _adapter_rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def _adapter_apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) x_embed = (x * cos) + (_adapter_rotate_half(x) * sin) return x_embed class AdapterRotaryEmbedding(nn.Module): """Rotary embedding for LLM Adapter.""" def __init__(self, head_dim): super().__init__() self.rope_theta = 10000 inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) @torch.no_grad() def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) class LLMAdapterAttention(nn.Module): """Attention module for LLM Adapter with QK-norm and separate RoPE for query/key.""" def __init__(self, query_dim, context_dim, n_heads, head_dim): super().__init__() inner_dim = head_dim * n_heads self.n_heads = n_heads self.head_dim = head_dim self.query_dim = query_dim self.context_dim = context_dim self.q_proj = nn.Linear(query_dim, inner_dim, bias=False) self.q_norm = LLMAdapterRMSNorm(self.head_dim) self.k_proj = nn.Linear(context_dim, inner_dim, bias=False) self.k_norm = LLMAdapterRMSNorm(self.head_dim) self.v_proj = 
nn.Linear(context_dim, inner_dim, bias=False) self.o_proj = nn.Linear(inner_dim, query_dim, bias=False) def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): context = x if context is None else context input_shape = x.shape[:-1] q_shape = (*input_shape, self.n_heads, self.head_dim) context_shape = context.shape[:-1] kv_shape = (*context_shape, self.n_heads, self.head_dim) query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) if position_embeddings is not None: assert position_embeddings_context is not None cos, sin = position_embeddings query_states = _adapter_apply_rotary_pos_emb(query_states, cos, sin) cos, sin = position_embeddings_context key_states = _adapter_apply_rotary_pos_emb(key_states, cos, sin) attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output class LLMAdapterTransformerBlock(nn.Module): """Transformer block for LLM Adapter: optional self-attn + cross-attn + MLP.""" def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, self_attn=False, layer_norm=False): super().__init__() self.has_self_attn = self_attn if self.has_self_attn: self.norm_self_attn = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim) self.self_attn = LLMAdapterAttention( query_dim=model_dim, context_dim=model_dim, n_heads=num_heads, head_dim=model_dim // num_heads, ) self.norm_cross_attn = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim) self.cross_attn = LLMAdapterAttention( query_dim=model_dim, context_dim=source_dim, n_heads=num_heads, head_dim=model_dim // num_heads, ) self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim) self.mlp = nn.Sequential( nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim) ) def forward( self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None, ): if self.has_self_attn: # Self-attention: target_attention_mask is not expected to be all zeros normed = self.norm_self_attn(x) attn_out = self.self_attn( normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings, ) x = x + attn_out normed = self.norm_cross_attn(x) attn_out = self.cross_attn( normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context, ) x = x + attn_out x = x + self.mlp(self.norm_mlp(x)) return x def init_weights(self): torch.nn.init.zeros_(self.mlp[2].weight) class LLMAdapter(nn.Module): """Bridge module: Qwen3 embeddings (source) → T5-compatible space (target). Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states. 
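    Shapes (a sketch, following the defaults used in this file):
        source_hidden_states: (B, S_src, source_dim) Qwen3 hidden states
        target_input_ids:     (B, S_tgt) T5 token IDs (default vocab size 32128)
        returns:              (B, S_tgt, target_dim) T5-compatible embeddings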
""" def __init__( self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False ): super().__init__() if embed is not None: self.embed = nn.Embedding.from_pretrained(embed.weight) else: self.embed = nn.Embedding(32128, target_dim) if model_dim != target_dim: self.in_proj = nn.Linear(target_dim, model_dim) else: self.in_proj = nn.Identity() self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads) self.blocks = nn.ModuleList( [ LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm) for _ in range(num_layers) ] ) self.out_proj = nn.Linear(model_dim, target_dim) self.norm = LLMAdapterRMSNorm(target_dim) def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): if target_attention_mask is not None: target_attention_mask = target_attention_mask.to(torch.bool) if target_attention_mask.ndim == 2: target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) if source_attention_mask is not None: source_attention_mask = source_attention_mask.to(torch.bool) if source_attention_mask.ndim == 2: source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) x = self.in_proj(self.embed(target_input_ids)) context = source_hidden_states position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) position_embeddings = self.rotary_emb(x, position_ids) position_embeddings_context = self.rotary_emb(x, position_ids_context) for block in self.blocks: x = block( x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context, ) return self.norm(self.out_proj(x)) # Not used currently, but kept for reference # def get_dit_config(state_dict, key_prefix=""): # """Derive DiT configuration from state_dict weight shapes.""" # dit_config = {} # dit_config["max_img_h"] = 512 # dit_config["max_img_w"] = 512 # dit_config["max_frames"] = 128 # concat_padding_mask = True # dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int( # concat_padding_mask # ) # dit_config["out_channels"] = 16 # dit_config["patch_spatial"] = 2 # dit_config["patch_temporal"] = 1 # dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0] # dit_config["concat_padding_mask"] = concat_padding_mask # dit_config["crossattn_emb_channels"] = 1024 # dit_config["pos_emb_cls"] = "rope3d" # dit_config["pos_emb_learnable"] = True # dit_config["pos_emb_interpolation"] = "crop" # dit_config["min_fps"] = 1 # dit_config["max_fps"] = 30 # dit_config["use_adaln_lora"] = True # dit_config["adaln_lora_dim"] = 256 # if dit_config["model_channels"] == 2048: # dit_config["num_blocks"] = 28 # dit_config["num_heads"] = 16 # elif dit_config["model_channels"] == 5120: # dit_config["num_blocks"] = 36 # dit_config["num_heads"] = 40 # elif dit_config["model_channels"] == 1280: # dit_config["num_blocks"] = 20 # dit_config["num_heads"] = 20 # if dit_config["in_channels"] == 16: # dit_config["extra_per_block_abs_pos_emb"] = False # dit_config["rope_h_extrapolation_ratio"] = 4.0 # dit_config["rope_w_extrapolation_ratio"] = 4.0 # dit_config["rope_t_extrapolation_ratio"] = 1.0 # elif dit_config["in_channels"] == 17: # dit_config["extra_per_block_abs_pos_emb"] = False # 
dit_config["rope_h_extrapolation_ratio"] = 3.0 # dit_config["rope_w_extrapolation_ratio"] = 3.0 # dit_config["rope_t_extrapolation_ratio"] = 1.0 # dit_config["extra_h_extrapolation_ratio"] = 1.0 # dit_config["extra_w_extrapolation_ratio"] = 1.0 # dit_config["extra_t_extrapolation_ratio"] = 1.0 # dit_config["rope_enable_fps_modulation"] = False # return dit_config