# Anima model loading/saving utilities import os from typing import Dict, List, Optional, Union import torch import torch.nn as nn from safetensors.torch import load_file, save_file from accelerate.utils import set_module_tensor_to_device # kept for potential future use from accelerate import init_empty_weights from library.fp8_optimization_utils import apply_fp8_monkey_patch from library.lora_utils import load_safetensors_with_lora_and_fp8 from .utils import setup_logging setup_logging() import logging logger = logging.getLogger(__name__) from library import anima_models # Keys that should stay in high precision (float32/bfloat16, not quantized) KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"] def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]: """Load a safetensors file and optionally cast to dtype.""" sd = load_file(path, device=device) if dtype is not None: sd = {k: v.to(dtype) for k, v in sd.items()} return sd def load_anima_dit( dit_path: str, dtype: torch.dtype, device: Union[str, torch.device] = "cpu", transformer_dtype: Optional[torch.dtype] = None, llm_adapter_path: Optional[str] = None, disable_mmap: bool = False, ) -> anima_models.MiniTrainDIT: """Load the MiniTrainDIT model from safetensors. Args: dit_path: Path to DiT safetensors file dtype: Base dtype for model parameters device: Device to load to transformer_dtype: Optional separate dtype for transformer blocks (lower precision) llm_adapter_path: Optional separate path for LLM adapter weights disable_mmap: If True, disable memory-mapped loading (reduces peak memory) """ if transformer_dtype is None: transformer_dtype = dtype logger.info(f"Loading Anima DiT from {dit_path}") if disable_mmap: from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True) else: state_dict = load_file(dit_path, device="cpu") # Remove 'net.' prefix if present new_state_dict = {} for k, v in state_dict.items(): if k.startswith("net."): k = k[len("net.") :] new_state_dict[k] = v state_dict = new_state_dict # Derive config from state_dict dit_config = anima_models.get_dit_config(state_dict) # Detect LLM adapter if llm_adapter_path is not None: use_llm_adapter = True dit_config["use_llm_adapter"] = True llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu") elif "llm_adapter.out_proj.weight" in state_dict: use_llm_adapter = True dit_config["use_llm_adapter"] = True llm_adapter_state_dict = None # Loaded as part of DiT else: use_llm_adapter = False llm_adapter_state_dict = None logger.info( f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, " f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}" ) # Build model normally on CPU — buffers get proper values from __init__ dit = anima_models.MiniTrainDIT(**dit_config) # Merge LLM adapter weights into state_dict if loaded separately if use_llm_adapter and llm_adapter_state_dict is not None: for k, v in llm_adapter_state_dict.items(): state_dict[f"llm_adapter.{k}"] = v # Load checkpoint: strict=False keeps buffers not in checkpoint (e.g. pos_embedder.seq) missing, unexpected = dit.load_state_dict(state_dict, strict=False) if missing: # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint) unexpected_missing = [ k for k in missing if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq")) ] if unexpected_missing: logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}") if unexpected: logger.info(f"Unexpected keys in checkpoint (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") # Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest) for name, p in dit.named_parameters(): dtype_to_use = dtype if (any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1) else transformer_dtype p.data = p.data.to(dtype=dtype_to_use) dit.to(device) logger.info(f"Loaded Anima DiT successfully. Parameters: {sum(p.numel() for p in dit.parameters()):,}") return dit FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""] FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer"] def load_anima_model( device: Union[str, torch.device], dit_path: str, attn_mode: str, split_attn: bool, loading_device: Union[str, torch.device], dit_weight_dtype: Optional[torch.dtype], fp8_scaled: bool = False, lora_weights_list: Optional[Dict[str, torch.Tensor]] = None, lora_multipliers: Optional[list[float]] = None, ) -> anima_models.Anima: """ Load a HunyuanImage model from the specified checkpoint. Args: device (Union[str, torch.device]): Device for optimization or merging dit_path (str): Path to the DiT model checkpoint. attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc. split_attn (bool): Whether to use split attention. loading_device (Union[str, torch.device]): Device to load the model weights on. dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights. If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype. fp8_scaled (bool): Whether to use fp8 scaling for the model weights. lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any. lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any. """ # dit_weight_dtype is None for fp8_scaled assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None) device = torch.device(device) loading_device = torch.device(loading_device) # We currently support fixed DiT config for Anima models dit_config = { "max_img_h": 512, "max_img_w": 512, "max_frames": 128, "in_channels": 16, "out_channels": 16, "patch_spatial": 2, "patch_temporal": 1, "model_channels": 2048, "concat_padding_mask": True, "crossattn_emb_channels": 1024, "pos_emb_cls": "rope3d", "pos_emb_learnable": True, "pos_emb_interpolation": "crop", "min_fps": 1, "max_fps": 30, "use_adaln_lora": True, "adaln_lora_dim": 256, "num_blocks": 28, "num_heads": 16, "extra_per_block_abs_pos_emb": False, "rope_h_extrapolation_ratio": 4.0, "rope_w_extrapolation_ratio": 4.0, "rope_t_extrapolation_ratio": 1.0, "extra_h_extrapolation_ratio": 1.0, "extra_w_extrapolation_ratio": 1.0, "extra_t_extrapolation_ratio": 1.0, "rope_enable_fps_modulation": False, "use_llm_adapter": True, "attn_mode": attn_mode, "split_attn": split_attn, } # model = create_model(attn_mode, split_attn, dit_weight_dtype) with init_empty_weights(): model = anima_models.Anima(dit_config) if dit_weight_dtype is not None: model.to(dit_weight_dtype) # load model weights with dynamic fp8 optimization and LoRA merging if needed logger.info(f"Loading DiT model from {dit_path}, device={loading_device}") sd = load_safetensors_with_lora_and_fp8( model_files=dit_path, lora_weights_list=lora_weights_list, lora_multipliers=lora_multipliers, fp8_optimization=fp8_scaled, calc_device=device, move_to_device=(loading_device == device), dit_weight_dtype=dit_weight_dtype, target_keys=FP8_OPTIMIZATION_TARGET_KEYS, exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS, ) if fp8_scaled: apply_fp8_monkey_patch(model, sd, use_scaled_mm=False) if loading_device.type != "cpu": # make sure all the model weights are on the loading_device logger.info(f"Moving weights to {loading_device}") for key in sd.keys(): sd[key] = sd[key].to(loading_device) missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) if missing: # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint) unexpected_missing = [ k for k in missing if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq")) ] if unexpected_missing: # Raise error to avoid silent failures raise RuntimeError( f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}" ) missing = {} # all missing keys were expected if unexpected: # Raise error to avoid silent failures raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}") return model def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"): """Load WanVAE from a safetensors/pth file. Returns (vae_model, mean_tensor, std_tensor, scale). """ from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD logger.info(f"Loading Anima VAE from {vae_path}") # VAE config (fixed for WanVAE) vae_config = dict( dim=96, z_dim=16, dim_mult=[1, 2, 4, 4], num_res_blocks=2, attn_scales=[], temperal_downsample=[False, True, True], dropout=0.0, ) from library.anima_vae import WanVAE_ # Build model with torch.device("meta"): vae = WanVAE_(**vae_config) # Load state dict if vae_path.endswith(".safetensors"): vae_sd = load_file(vae_path, device="cpu") else: vae_sd = torch.load(vae_path, map_location="cpu", weights_only=True) vae.load_state_dict(vae_sd, assign=True) vae = vae.eval().requires_grad_(False).to(device, dtype=dtype) # Create normalization tensors mean = torch.tensor(ANIMA_VAE_MEAN, dtype=dtype, device=device) std = torch.tensor(ANIMA_VAE_STD, dtype=dtype, device=device) scale = [mean, 1.0 / std] logger.info(f"Loaded Anima VAE successfully.") return vae, mean, std, scale def load_qwen3_tokenizer(qwen3_path: str): """Load Qwen3 tokenizer only (without the text encoder model). Args: qwen3_path: Path to either a directory with model files or a safetensors file. If a directory, loads tokenizer from it directly. If a file, uses configs/qwen3_06b/ for tokenizer config. Returns: tokenizer """ from transformers import AutoTokenizer if os.path.isdir(qwen3_path): tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) else: config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b") if not os.path.exists(config_dir): raise FileNotFoundError( f"Qwen3 config directory not found at {config_dir}. " "Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. " "You can download these from the Qwen3-0.6B HuggingFace repository." ) tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return tokenizer def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16, device: str = "cpu"): """Load Qwen3-0.6B text encoder. Args: qwen3_path: Path to either a directory with model files or a safetensors file dtype: Model dtype device: Device to load to Returns: (text_encoder_model, tokenizer) """ import transformers from transformers import AutoTokenizer logger.info(f"Loading Qwen3 text encoder from {qwen3_path}") if os.path.isdir(qwen3_path): # Directory with full model tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model else: # Single safetensors file - use configs/qwen3_06b/ for config config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b") if not os.path.exists(config_dir): raise FileNotFoundError( f"Qwen3 config directory not found at {config_dir}. " "Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. " "You can download these from the Qwen3-0.6B HuggingFace repository." ) tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True) qwen3_config = transformers.Qwen3Config.from_pretrained(config_dir, local_files_only=True) model = transformers.Qwen3ForCausalLM(qwen3_config).model # Load weights if qwen3_path.endswith(".safetensors"): state_dict = load_file(qwen3_path, device="cpu") else: state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True) # Remove 'model.' prefix if present new_sd = {} for k, v in state_dict.items(): if k.startswith("model."): new_sd[k[len("model.") :]] = v else: new_sd[k] = v info = model.load_state_dict(new_sd, strict=False) logger.info(f"Loaded Qwen3 state dict: {info}") if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model.config.use_cache = False model = model.requires_grad_(False).to(device, dtype=dtype) logger.info(f"Loaded Qwen3 text encoder. Parameters: {sum(p.numel() for p in model.parameters()):,}") return model, tokenizer def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None): """Load T5 tokenizer for LLM Adapter target tokens. Args: t5_tokenizer_path: Optional path to T5 tokenizer directory. If None, uses default configs. """ from transformers import T5TokenizerFast if t5_tokenizer_path is not None: return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True) # Use bundled config config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old") if os.path.exists(config_dir): return T5TokenizerFast( vocab_file=os.path.join(config_dir, "spiece.model"), tokenizer_file=os.path.join(config_dir, "tokenizer.json"), ) raise FileNotFoundError( f"T5 tokenizer config directory not found at {config_dir}. " "Expected configs/t5_old/ with spiece.model and tokenizer.json. " "You can download these from the google/t5-v1_1-xxl HuggingFace repository." ) def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None): """Save Anima DiT model with 'net.' prefix for ComfyUI compatibility. Args: save_path: Output path (.safetensors) dit_state_dict: State dict from dit.state_dict() dtype: Optional dtype to cast to before saving """ prefixed_sd = {} for k, v in dit_state_dict.items(): if dtype is not None: v = v.to(dtype) prefixed_sd["net." + k] = v.contiguous() save_file(prefixed_sd, save_path, metadata={"format": "pt"}) logger.info(f"Saved Anima model to {save_path}") def vae_encode(tensor: torch.Tensor, vae, scale): """Encode tensor through WanVAE with normalization. Args: tensor: Input tensor (B, C, T, H, W) in [-1, 1] range vae: WanVAE_ model scale: [mean, 1/std] list Returns: Normalized latents """ return vae.encode(tensor, scale) def vae_decode(latents: torch.Tensor, vae, scale): """Decode latents through WanVAE with denormalization. Args: latents: Normalized latents vae: WanVAE_ model scale: [mean, 1/std] list Returns: Decoded tensor in [-1, 1] range """ return vae.decode(latents, scale)