diff --git a/anima_train.py b/anima_train.py index a86c30c3..1916dfa6 100644 --- a/anima_train.py +++ b/anima_train.py @@ -49,35 +49,32 @@ def train(args): args.skip_cache_check = args.skip_latents_validity_check if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled" - ) + logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled") args.cache_text_encoder_outputs = True if args.cpu_offload_checkpointing and not args.gradient_checkpointing: logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True - if getattr(args, 'unsloth_offload_checkpointing', False): + if getattr(args, "unsloth_offload_checkpointing", False): if not args.gradient_checkpointing: logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True - assert not args.cpu_offload_checkpointing, \ - "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" + assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" - assert ( - args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not getattr(args, 'unsloth_offload_checkpointing', False), \ - "blocks_to_swap is not supported with unsloth_offload_checkpointing" + assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr( + args, "unsloth_offload_checkpointing", False + ), "blocks_to_swap is not supported with unsloth_offload_checkpointing" # Flash attention: validate availability - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): try: import flash_attn # noqa: F401 + logger.info("Flash Attention enabled for DiT blocks") except ImportError: logger.warning("flash_attn package not installed, falling back to PyTorch SDPA") @@ -104,9 +101,7 @@ def train(args): user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0}".format(", ".join(ignored)) - ) + logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored))) else: if use_dreambooth_method: logger.info("Using DreamBooth method.") @@ -150,7 +145,7 @@ def train(args): # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of # dataset-level caption dropout, so we save the rate and zero out subset-level # caption_dropout_rate to allow text encoder output caching. - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) if caption_dropout_rate > 0: logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}") for dataset in train_dataset_group.datasets: @@ -175,9 +170,7 @@ def train(args): return if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used" + assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used" if args.cache_text_encoder_outputs: assert ( @@ -193,7 +186,7 @@ def train(args): # parse transformer_dtype transformer_dtype = None - if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None: + if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None: transformer_dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, @@ -203,12 +196,8 @@ def train(args): # Load tokenizers and set strategies logger.info("Loading tokenizers...") - qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder( - args.qwen3_path, dtype=weight_dtype, device="cpu" - ) - t5_tokenizer = anima_utils.load_t5_tokenizer( - getattr(args, 't5_tokenizer_path', None) - ) + qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu") + t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None)) # Set tokenize strategy tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( @@ -220,7 +209,7 @@ def train(args): strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) # Set text encoding strategy - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy( dropout_rate=caption_dropout_rate, ) @@ -266,7 +255,7 @@ def train(args): ) # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) if caption_dropout_rate > 0.0: with accelerator.autocast(): text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder]) @@ -299,17 +288,17 @@ def train(args): dtype=weight_dtype, device="cpu", transformer_dtype=transformer_dtype, - llm_adapter_path=getattr(args, 'llm_adapter_path', None), - disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False), + llm_adapter_path=getattr(args, "llm_adapter_path", None), + disable_mmap=getattr(args, "disable_mmap_load_safetensors", False), ) if args.gradient_checkpointing: dit.enable_gradient_checkpointing( cpu_offload=args.cpu_offload_checkpointing, - unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False), + unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False), ) - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): dit.set_flash_attn(True) train_dit = args.learning_rate != 0 @@ -335,11 +324,11 @@ def train(args): param_groups = anima_train_utils.get_anima_param_groups( dit, base_lr=args.learning_rate, - self_attn_lr=getattr(args, 'self_attn_lr', None), - cross_attn_lr=getattr(args, 'cross_attn_lr', None), - mlp_lr=getattr(args, 'mlp_lr', None), - mod_lr=getattr(args, 'mod_lr', None), - llm_adapter_lr=getattr(args, 'llm_adapter_lr', None), + self_attn_lr=getattr(args, "self_attn_lr", None), + cross_attn_lr=getattr(args, "cross_attn_lr", None), + mlp_lr=getattr(args, "mlp_lr", None), + mod_lr=getattr(args, "mod_lr", None), + llm_adapter_lr=getattr(args, "llm_adapter_lr", None), ) else: param_groups = [] @@ -366,8 +355,8 @@ def train(args): # Build param_id → lr mapping from param_groups to propagate per-component LRs param_lr_map = {} for group in param_groups: - for p in group['params']: - param_lr_map[id(p)] = group['lr'] + for p in group["params"]: + param_lr_map[id(p)] = group["lr"] grouped_params = [] param_group = {} @@ -557,9 +546,7 @@ def train(args): accelerator.print(f" num examples: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch: {len(train_dataloader)}") accelerator.print(f" num epochs: {num_train_epochs}") - accelerator.print( - f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) + accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps: {args.max_train_steps}") @@ -580,6 +567,7 @@ def train(args): if "wandb" in [tracker.name for tracker in accelerator.trackers]: import wandb + wandb.define_metric("epoch") wandb.define_metric("loss/epoch", step_metric="epoch") @@ -589,8 +577,16 @@ def train(args): # For --sample_at_first optimizer_eval_fn() anima_train_utils.sample_images( - accelerator, args, 0, global_step, dit, vae, vae_scale, - qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, + accelerator, + args, + 0, + global_step, + dit, + vae, + vae_scale, + qwen3_text_encoder, + tokenize_strategy, + text_encoding_strategy, sample_prompts_te_outputs, ) optimizer_train_fn() @@ -600,7 +596,9 @@ def train(args): # Show model info unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None if unwrapped_dit is not None: - logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}") + logger.info( + f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}" + ) if qwen3_text_encoder is not None: logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}") if vae is not None: @@ -640,9 +638,7 @@ def train(args): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Cached outputs - text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs( - *text_encoder_outputs_list - ) + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list else: # Encode on-the-fly @@ -678,10 +674,7 @@ def train(args): bs = latents.shape[0] h_latent = latents.shape[-2] w_latent = latents.shape[-1] - padding_mask = torch.zeros( - bs, 1, h_latent, w_latent, - dtype=weight_dtype, device=accelerator.device - ) + padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device) # DiT forward (LLM adapter runs inside forward for DDP gradient sync) if is_swapping_blocks: @@ -708,9 +701,7 @@ def train(args): # Loss huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None) - loss = train_util.conditional_loss( - model_pred.float(), target.float(), args.loss_type, "none", huber_c - ) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,) @@ -748,8 +739,16 @@ def train(args): optimizer_eval_fn() anima_train_utils.sample_images( - accelerator, args, None, global_step, dit, vae, vae_scale, - qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, + accelerator, + args, + None, + global_step, + dit, + vae, + vae_scale, + qwen3_text_encoder, + tokenize_strategy, + text_encoding_strategy, sample_prompts_te_outputs, ) @@ -773,8 +772,10 @@ def train(args): if len(accelerator.trackers) > 0: logs = {"loss": current_loss} train_util.append_lr_to_logs_with_names( - logs, lr_scheduler, args.optimizer_type, - ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [] + logs, + lr_scheduler, + args.optimizer_type, + ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [], ) accelerator.log(logs, step=global_step) @@ -807,8 +808,16 @@ def train(args): ) anima_train_utils.sample_images( - accelerator, args, epoch + 1, global_step, dit, vae, vae_scale, - qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, + accelerator, + args, + epoch + 1, + global_step, + dit, + vae, + vae_scale, + qwen3_text_encoder, + tokenize_strategy, + text_encoding_strategy, sample_prompts_te_outputs, ) diff --git a/anima_train_network.py b/anima_train_network.py index 57ad1681..c19c52cc 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -39,17 +39,15 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): val_dataset_group: Optional[train_util.DatasetGroup], ): if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled" - ) + logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled") args.cache_text_encoder_outputs = True # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of # dataset-level caption dropout, so zero out subset-level rates to allow caching. - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) if caption_dropout_rate > 0: logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}") - if hasattr(train_dataset_group, 'datasets'): + if hasattr(train_dataset_group, "datasets"): for dataset in train_dataset_group.datasets: for subset in dataset.subsets: subset.caption_dropout_rate = 0.0 @@ -63,26 +61,28 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): args.blocks_to_swap is None or args.blocks_to_swap == 0 ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" - if getattr(args, 'unsloth_offload_checkpointing', False): + if getattr(args, "unsloth_offload_checkpointing", False): if not args.gradient_checkpointing: logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True - assert not args.cpu_offload_checkpointing, \ - "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" + assert ( + not args.cpu_offload_checkpointing + ), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 ), "blocks_to_swap is not supported with unsloth_offload_checkpointing" # Flash attention: validate availability - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): try: import flash_attn # noqa: F401 + logger.info("Flash Attention enabled for DiT blocks") except ImportError: logger.warning("flash_attn package not installed, falling back to PyTorch SDPA") args.flash_attn = False - if getattr(args, 'blockwise_fused_optimizers', False): + if getattr(args, "blockwise_fused_optimizers", False): raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer") train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8 @@ -92,14 +92,12 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): def load_target_model(self, args, weight_dtype, accelerator): # Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy) logger.info("Loading Qwen3 text encoder...") - self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder( - args.qwen3_path, dtype=weight_dtype, device="cpu" - ) + self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu") self.qwen3_text_encoder.eval() # Parse transformer_dtype transformer_dtype = None - if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None: + if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None: transformer_dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, @@ -114,18 +112,18 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): dtype=weight_dtype, device="cpu", transformer_dtype=transformer_dtype, - llm_adapter_path=getattr(args, 'llm_adapter_path', None), - disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False), + llm_adapter_path=getattr(args, "llm_adapter_path", None), + disable_mmap=getattr(args, "disable_mmap_load_safetensors", False), ) # Flash attention - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): dit.set_flash_attn(True) # Store unsloth preference so that when the base NetworkTrainer calls # dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth. # The base trainer only passes cpu_offload, so we store the flag on the model. - self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False) + self._use_unsloth_offload_checkpointing = getattr(args, "unsloth_offload_checkpointing", False) # Block swap self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 @@ -135,9 +133,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): # Load VAE logger.info("Loading Anima VAE...") - self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae( - args.vae_path, dtype=weight_dtype, device="cpu" - ) + self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu") # Return format: (model_type, text_encoders, vae, unet) return "anima", [self.qwen3_text_encoder], self.vae, dit @@ -146,7 +142,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): # Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet) self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( qwen3_path=args.qwen3_path, - t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None), + t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None), qwen3_max_length=args.qwen3_max_token_length, t5_max_length=args.t5_max_token_length, ) @@ -159,12 +155,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): return [tokenize_strategy.qwen3_tokenizer] def get_latents_caching_strategy(self, args): - return strategy_anima.AnimaLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check - ) + return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check) def get_text_encoding_strategy(self, args): - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy( dropout_rate=caption_dropout_rate, ) @@ -237,7 +231,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): self.sample_prompts_te_outputs = sample_prompts_te_outputs # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy() if caption_dropout_rate > 0.0: tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy() @@ -264,8 +258,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): qwen3_te = te[0] if te is not None else None anima_train_utils.sample_images( - accelerator, args, epoch, global_step, unet, vae, self.vae_scale, - qwen3_te, self.tokenize_strategy, self.text_encoding_strategy, + accelerator, + args, + epoch, + global_step, + unet, + vae, + self.vae_scale, + qwen3_te, + self.tokenize_strategy, + self.text_encoding_strategy, self.sample_prompts_te_outputs, ) @@ -329,10 +331,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): bs = latents.shape[0] h_latent = latents.shape[-2] w_latent = latents.shape[-1] - padding_mask = torch.zeros( - bs, 1, h_latent, w_latent, - dtype=weight_dtype, device=accelerator.device - ) + padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device) # Prepare block swap if self.is_swapping_blocks: @@ -354,9 +353,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): target = noise - latents # Loss weighting - weighting = anima_train_utils.compute_loss_weighting_for_anima( - weighting_scheme=args.weighting_scheme, sigmas=sigmas - ) + weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # Differential output preservation if "custom_attributes" in batch: @@ -386,10 +383,22 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): return model_pred, target, timesteps, weighting def process_batch( - self, batch, text_encoders, unet, network, vae, noise_scheduler, - vae_dtype, weight_dtype, accelerator, args, - text_encoding_strategy, tokenize_strategy, - is_train=True, train_text_encoder=True, train_unet=True, + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=True, + train_unet=True, ) -> torch.Tensor: """Override base process_batch for 5D video latents (B, C, T, H, W). @@ -446,8 +455,17 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): text_encoder_conds[i] = encoded_text_encoder_conds[i] noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( - args, accelerator, noise_scheduler, latents, batch, - text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) @@ -479,8 +497,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): def update_metadata(self, metadata, args): metadata["ss_weighting_scheme"] = args.weighting_scheme metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift - metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal') - metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0) + metadata["ss_timestep_sample_method"] = getattr(args, "timestep_sample_method", "logit_normal") + metadata["ss_sigmoid_scale"] = getattr(args, "sigmoid_scale", 1.0) def is_text_encoder_not_needed_for_training(self, args): return args.cache_text_encoder_outputs diff --git a/library/anima_models.py b/library/anima_models.py index 79f7962f..d3adff5f 100644 --- a/library/anima_models.py +++ b/library/anima_models.py @@ -17,7 +17,6 @@ from library import custom_offloading_utils, attention from library.device_utils import clean_memory_on_device - def to_device(x, device): if isinstance(x, torch.Tensor): return x.to(device) @@ -39,11 +38,13 @@ def to_cpu(x): else: return x + # Unsloth Offloaded Gradient Checkpointing # Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team try: from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable except ImportError: + def detach_variable(inputs, device=None): """Detach tensors from computation graph, optionally moving to a device. @@ -80,11 +81,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function): """ @staticmethod - @torch.amp.custom_fwd(device_type='cuda') + @torch.amp.custom_fwd(device_type="cuda") def forward(ctx, forward_function, hidden_states, *args): # Remember the original device for backward pass (multi-GPU support) ctx.input_device = hidden_states.device - saved_hidden_states = hidden_states.to('cpu', non_blocking=True) + saved_hidden_states = hidden_states.to("cpu", non_blocking=True) with torch.no_grad(): output = forward_function(hidden_states, *args) ctx.save_for_backward(saved_hidden_states) @@ -96,7 +97,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function): return output @staticmethod - @torch.amp.custom_bwd(device_type='cuda') + @torch.amp.custom_bwd(device_type="cuda") def backward(ctx, *grads): (hidden_states,) = ctx.saved_tensors hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach() @@ -108,8 +109,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function): output_tensors = [] grad_tensors = [] - for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,), - grads if isinstance(grads, tuple) else (grads,)): + for out, grad in zip( + outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,) + ): if isinstance(out, torch.Tensor) and out.requires_grad: output_tensors.append(out) grad_tensors.append(grad) @@ -174,14 +176,10 @@ def _apply_rotary_pos_emb_base( if start_positions is not None: max_offset = torch.max(start_positions) - assert ( - max_offset + cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" freqs = freqs[:cur_seq_len] if tensor_format == "bshd": @@ -205,13 +203,9 @@ def apply_rotary_pos_emb( cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, ) -> torch.Tensor: - assert not ( - cp_size > 1 and start_positions is not None - ), "start_positions != None with CP SIZE > 1 is not supported!" + assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!" - assert ( - tensor_format != "thd" or cu_seqlens is not None - ), "cu_seqlens must not be None when tensor_format is 'thd'." + assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'." assert fused == False @@ -223,9 +217,7 @@ def apply_rotary_pos_emb( _apply_rotary_pos_emb_base( x.unsqueeze(1), freqs, - start_positions=( - start_positions[idx : idx + 1] if start_positions is not None else None - ), + start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None), interleaved=interleaved, ) for idx, x in enumerate(torch.split(t, seqlens)) @@ -262,7 +254,7 @@ class RMSNorm(torch.nn.Module): def _norm(self, x: torch.Tensor) -> torch.Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - @torch.amp.autocast(device_type='cuda', dtype=torch.float32) + @torch.amp.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x: torch.Tensor) -> torch.Tensor: output = self._norm(x.float()).type_as(x) return output * self.weight @@ -308,9 +300,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - result_B_S_HD = rearrange( - F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)" - ) + result_B_S_HD = rearrange(F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)") return result_B_S_HD @@ -412,7 +402,7 @@ class Attention(nn.Module): ) -> torch.Tensor: q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) # return self.compute_attention(q, k, v) - qkv = [q,k,v] + qkv = [q, k, v] del q, k, v result = attention.attention(qkv, attn_params=attn_params) return self.output_dropout(self.output_proj(result)) @@ -489,12 +479,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb): dim_t = self._dim_t self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device) - self.dim_spatial_range = ( - torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h - ) - self.dim_temporal_range = ( - torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t - ) + self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h + self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t def generate_embeddings( self, @@ -684,9 +670,7 @@ class FourierFeatures(nn.Module): def reset_parameters(self) -> None: generator = torch.Generator() generator.manual_seed(0) - self.freqs = ( - 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device) - ) + self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device) self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device) def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor: @@ -718,9 +702,7 @@ class PatchEmbed(nn.Module): m=spatial_patch_size, n=spatial_patch_size, ), - nn.Linear( - in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False - ), + nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False), ) self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size @@ -770,9 +752,7 @@ class FinalLayer(nn.Module): nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), ) else: - self.adaln_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) - ) + self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)) self.init_weights() @@ -795,9 +775,9 @@ class FinalLayer(nn.Module): ): if self.use_adaln_lora: assert adaln_lora_B_T_3D is not None - shift_B_T_D, scale_B_T_D = ( - self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] - ).chunk(2, dim=-1) + shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk( + 2, dim=-1 + ) else: shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) @@ -838,7 +818,11 @@ class Block(nn.Module): self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) self.cross_attn = Attention( - x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd", + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_format="bshd", ) self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) @@ -925,13 +909,13 @@ class Block(nn.Module): shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D ).chunk(3, dim=-1) - shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( - self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D - ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk( + 3, dim=-1 + ) else: - shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( - emb_B_T_D - ).chunk(3, dim=-1) + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk( + 3, dim=-1 + ) shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( emb_B_T_D ).chunk(3, dim=-1) @@ -965,7 +949,9 @@ class Block(nn.Module): rope_emb=rope_emb_L_1_1_D, ), "b (t h w) d -> b t h w d", - t=T, h=H, w=W, + t=T, + h=H, + w=W, ) x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result @@ -979,7 +965,9 @@ class Block(nn.Module): rope_emb=rope_emb_L_1_1_D, ), "b (t h w) d -> b t h w d", - t=T, h=H, w=W, + t=T, + h=H, + w=W, ) x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D @@ -1005,8 +993,13 @@ class Block(nn.Module): # Unsloth: async non-blocking CPU RAM offload (fastest offload method) return unsloth_checkpoint( self._forward, - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + extra_per_block_pos_emb, ) elif self.cpu_offload_checkpointing: # Standard cpu offload: blocking transfers @@ -1017,26 +1010,42 @@ class Block(nn.Module): device_inputs = to_device(inputs, device) outputs = func(*device_inputs) return to_cpu(outputs) + return custom_forward return torch_checkpoint( create_custom_forward(self._forward), - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + extra_per_block_pos_emb, use_reentrant=False, ) else: # Standard gradient checkpointing (no offload) return torch_checkpoint( self._forward, - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + extra_per_block_pos_emb, use_reentrant=False, ) else: return self._forward( - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + extra_per_block_pos_emb, ) @@ -1078,7 +1087,7 @@ class MiniTrainDIT(nn.Module): extra_t_extrapolation_ratio: float = 1.0, rope_enable_fps_modulation: bool = True, use_llm_adapter: bool = False, - attn_mode: str = "torch", + attn_mode: str = "torch", split_attn: bool = False, ) -> None: super().__init__() @@ -1170,7 +1179,6 @@ class MiniTrainDIT(nn.Module): self.final_layer.init_weights() self.t_embedding_norm.reset_parameters() - def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False): for block in self.blocks: block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload) @@ -1183,7 +1191,6 @@ class MiniTrainDIT(nn.Module): def device(self): return next(self.parameters()).device - # def set_flash_attn(self, use_flash_attn: bool): # """Toggle flash attention for all DiT blocks (self-attn + cross-attn). @@ -1246,9 +1253,7 @@ class MiniTrainDIT(nn.Module): padding_mask = transforms.functional.resize( padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1) x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) if self.extra_per_block_abs_pos_emb: @@ -1272,7 +1277,6 @@ class MiniTrainDIT(nn.Module): ) return x_B_C_Tt_Hp_Wp - def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks @@ -1280,9 +1284,7 @@ class MiniTrainDIT(nn.Module): self.blocks_to_swap <= self.num_blocks - 2 ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." - self.offloader = custom_offloading_utils.ModelOffloader( - self.blocks, self.blocks_to_swap, device - ) + self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device) logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") def move_to_device_except_swap_blocks(self, device: torch.device): @@ -1324,7 +1326,7 @@ class MiniTrainDIT(nn.Module): t5_attn_mask: Optional T5 attention mask """ # Run LLM adapter inside forward for correct DDP gradient synchronization - if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'): + if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"): crossattn_emb = self.llm_adapter( source_hidden_states=crossattn_emb, target_input_ids=t5_input_ids, @@ -1351,7 +1353,7 @@ class MiniTrainDIT(nn.Module): "extra_per_block_pos_emb": extra_pos_emb, } - attn_params= attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn) + attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn) for block_idx, block in enumerate(self.blocks): if self.blocks_to_swap: @@ -1502,24 +1504,36 @@ class LLMAdapterTransformerBlock(nn.Module): self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim) self.mlp = nn.Sequential( - nn.Linear(model_dim, int(model_dim * mlp_ratio)), - nn.GELU(), - nn.Linear(int(model_dim * mlp_ratio), model_dim) + nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim) ) - def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, - position_embeddings=None, position_embeddings_context=None): + def forward( + self, + x, + context, + target_attention_mask=None, + source_attention_mask=None, + position_embeddings=None, + position_embeddings_context=None, + ): if self.has_self_attn: normed = self.norm_self_attn(x) - attn_out = self.self_attn(normed, mask=target_attention_mask, - position_embeddings=position_embeddings, - position_embeddings_context=position_embeddings) + attn_out = self.self_attn( + normed, + mask=target_attention_mask, + position_embeddings=position_embeddings, + position_embeddings_context=position_embeddings, + ) x = x + attn_out normed = self.norm_cross_attn(x) - attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, - position_embeddings=position_embeddings, - position_embeddings_context=position_embeddings_context) + attn_out = self.cross_attn( + normed, + mask=source_attention_mask, + context=context, + position_embeddings=position_embeddings, + position_embeddings_context=position_embeddings_context, + ) x = x + attn_out x = x + self.mlp(self.norm_mlp(x)) @@ -1535,8 +1549,9 @@ class LLMAdapter(nn.Module): Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states. """ - def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, - embed=None, self_attn=False, layer_norm=False): + def __init__( + self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False + ): super().__init__() if embed is not None: self.embed = nn.Embedding.from_pretrained(embed.weight) @@ -1547,11 +1562,12 @@ class LLMAdapter(nn.Module): else: self.in_proj = nn.Identity() self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads) - self.blocks = nn.ModuleList([ - LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, - self_attn=self_attn, layer_norm=layer_norm) - for _ in range(num_layers) - ]) + self.blocks = nn.ModuleList( + [ + LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm) + for _ in range(num_layers) + ] + ) self.out_proj = nn.Linear(model_dim, target_dim) self.norm = LLMAdapterRMSNorm(target_dim) @@ -1573,10 +1589,14 @@ class LLMAdapter(nn.Module): position_embeddings = self.rotary_emb(x, position_ids) position_embeddings_context = self.rotary_emb(x, position_ids_context) for block in self.blocks: - x = block(x, context, target_attention_mask=target_attention_mask, - source_attention_mask=source_attention_mask, - position_embeddings=position_embeddings, - position_embeddings_context=position_embeddings_context) + x = block( + x, + context, + target_attention_mask=target_attention_mask, + source_attention_mask=source_attention_mask, + position_embeddings=position_embeddings, + position_embeddings_context=position_embeddings_context, + ) return self.norm(self.out_proj(x)) @@ -1584,6 +1604,7 @@ class Anima(nn.Module): """ Wrapper class for the MiniTrainDIT and LLM Adapter. """ + LATENT_CHANNELS = 16 def __init__(self, dit_config: dict): @@ -1593,7 +1614,7 @@ class Anima(nn.Module): @property def device(self): return self.net.device - + @property def dtype(self): return self.net.dtype @@ -1609,41 +1630,78 @@ class Anima(nn.Module): ) -> torch.Tensor: return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs) - def preprocess_text_embeds(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + def preprocess_text_embeds( + self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None + ): if target_input_ids is not None: - return self.net.llm_adapter(source_hidden_states, target_input_ids, target_attention_mask=target_attention_mask, - source_attention_mask=source_attention_mask) + return self.net.llm_adapter( + source_hidden_states, + target_input_ids, + target_attention_mask=target_attention_mask, + source_attention_mask=source_attention_mask, + ) else: return source_hidden_states + # VAE Wrapper # VAE normalization constants ANIMA_VAE_MEAN = [ - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, ] ANIMA_VAE_STD = [ - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, ] # DiT config detection from state_dict -KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer'] +KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"] -def get_dit_config(state_dict, key_prefix=''): +def get_dit_config(state_dict, key_prefix=""): """Derive DiT configuration from state_dict weight shapes.""" dit_config = {} dit_config["max_img_h"] = 512 dit_config["max_img_w"] = 512 dit_config["max_frames"] = 128 concat_padding_mask = True - dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int( + concat_padding_mask + ) dit_config["out_channels"] = 16 dit_config["patch_spatial"] = 2 dit_config["patch_temporal"] = 1 - dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0] dit_config["concat_padding_mask"] = concat_padding_mask dit_config["crossattn_emb_channels"] = 1024 dit_config["pos_emb_cls"] = "rope3d" diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py index ef0016b5..edac2fb7 100644 --- a/library/anima_train_utils.py +++ b/library/anima_train_utils.py @@ -32,6 +32,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas # Anima-specific training arguments + def add_anima_training_arguments(parser: argparse.ArgumentParser): """Add Anima-specific training arguments to the parser.""" parser.add_argument( @@ -169,20 +170,20 @@ def get_noisy_model_input_and_timesteps( """ bs = latents.shape[0] - timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal') - sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0) - shift = getattr(args, 'discrete_flow_shift', 1.0) + timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal") + sigmoid_scale = getattr(args, "sigmoid_scale", 1.0) + shift = getattr(args, "discrete_flow_shift", 1.0) - if timestep_sample_method == 'logit_normal': + if timestep_sample_method == "logit_normal": dist = torch.distributions.normal.Normal(0, 1) - elif timestep_sample_method == 'uniform': + elif timestep_sample_method == "uniform": dist = torch.distributions.uniform.Uniform(0, 1) else: raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}") t = dist.sample((bs,)).to(device) - if timestep_sample_method == 'logit_normal': + if timestep_sample_method == "logit_normal": t = t * sigmoid_scale t = torch.sigmoid(t) @@ -196,10 +197,10 @@ def get_noisy_model_input_and_timesteps( # Create noisy input: (1 - t) * latents + t * noise t_expanded = t.view(-1, *([1] * (latents.ndim - 1))) - ip_noise_gamma = getattr(args, 'ip_noise_gamma', None) + ip_noise_gamma = getattr(args, "ip_noise_gamma", None) if ip_noise_gamma: xi = torch.randn_like(latents, device=latents.device, dtype=dtype) - if getattr(args, 'ip_noise_gamma_random_strength', False): + if getattr(args, "ip_noise_gamma_random_strength", False): ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi) else: @@ -213,6 +214,7 @@ def get_noisy_model_input_and_timesteps( # Loss weighting + def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor: """Compute loss weighting for Anima training. @@ -276,15 +278,15 @@ def get_anima_param_groups( # Store original name for debugging p.original_name = name - if 'llm_adapter' in name: + if "llm_adapter" in name: llm_adapter_params.append(p) - elif '.self_attn' in name: + elif ".self_attn" in name: self_attn_params.append(p) - elif '.cross_attn' in name: + elif ".cross_attn" in name: cross_attn_params.append(p) - elif '.mlp' in name: + elif ".mlp" in name: mlp_params.append(p) - elif '.adaln_modulation' in name: + elif ".adaln_modulation" in name: mod_params.append(p) else: base_params.append(p) @@ -311,9 +313,9 @@ def get_anima_param_groups( p.requires_grad_(False) logger.info(f" Frozen {name} params ({len(params)} parameters)") elif len(params) > 0: - param_groups.append({'params': params, 'lr': lr}) + param_groups.append({"params": params, "lr": lr}) - total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad) + total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad) logger.info(f"Total trainable parameters: {total_trainable:,}") return param_groups @@ -328,10 +330,9 @@ def save_anima_model_on_train_end( dit: anima_models.MiniTrainDIT, ): """Save Anima model at the end of training.""" + def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec( - None, args, False, False, False, is_stable_diffusion_ckpt=True - ) + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) dit_sd = dit.state_dict() # Save with 'net.' prefix for ComfyUI compatibility anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) @@ -350,10 +351,9 @@ def save_anima_model_on_epoch_end_or_stepwise( dit: anima_models.MiniTrainDIT, ): """Save Anima model at epoch end or specific steps.""" + def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec( - None, args, False, False, False, is_stable_diffusion_ckpt=True - ) + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) dit_sd = dit.state_dict() anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) @@ -410,9 +410,7 @@ def do_sample( generator = torch.manual_seed(seed) else: generator = None - noise = torch.randn( - latent.size(), dtype=torch.float32, generator=generator, device="cpu" - ).to(dtype).to(device) + noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device) # Timestep schedule: linear from 1.0 to 0.0 sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype) @@ -512,10 +510,20 @@ def sample_images( with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: _sample_image_inference( - accelerator, args, dit, text_encoder, vae, vae_scale, - tokenize_strategy, text_encoding_strategy, - save_dir, prompt_dict, epoch, steps, - sample_prompts_te_outputs, prompt_replacement, + accelerator, + args, + dit, + text_encoder, + vae, + vae_scale, + tokenize_strategy, + text_encoding_strategy, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, ) # Restore RNG state @@ -527,10 +535,20 @@ def sample_images( def _sample_image_inference( - accelerator, args, dit, text_encoder, vae, vae_scale, - tokenize_strategy, text_encoding_strategy, - save_dir, prompt_dict, epoch, steps, - sample_prompts_te_outputs, prompt_replacement, + accelerator, + args, + dit, + text_encoder, + vae, + vae_scale, + tokenize_strategy, + text_encoding_strategy, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, ): """Generate a single sample image.""" prompt = prompt_dict.get("prompt", "") @@ -585,7 +603,7 @@ def _sample_image_inference( t5_attn_mask = t5_attn_mask.to(accelerator.device) # Process through LLM adapter if available - if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'): + if dit.use_llm_adapter and hasattr(dit, "llm_adapter"): crossattn_emb = dit.llm_adapter( source_hidden_states=prompt_embeds, target_input_ids=t5_input_ids, @@ -613,7 +631,7 @@ def _sample_image_inference( neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long) neg_t5_am = neg_t5_am.to(accelerator.device) - if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'): + if dit.use_llm_adapter and hasattr(dit, "llm_adapter"): neg_crossattn_emb = dit.llm_adapter( source_hidden_states=neg_pe, target_input_ids=neg_t5_ids, @@ -627,9 +645,16 @@ def _sample_image_inference( # Generate sample clean_memory_on_device(accelerator.device) latents = do_sample( - height, width, seed, dit, crossattn_emb, - sample_steps, dit.t_embedding_norm.weight.dtype, - accelerator.device, scale, neg_crossattn_emb, + height, + width, + seed, + dit, + crossattn_emb, + sample_steps, + dit.t_embedding_norm.weight.dtype, + accelerator.device, + scale, + neg_crossattn_emb, ) # Decode latents @@ -662,4 +687,5 @@ def _sample_image_inference( if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") import wandb + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) diff --git a/library/anima_utils.py b/library/anima_utils.py index d8fb58ba..430c20c2 100644 --- a/library/anima_utils.py +++ b/library/anima_utils.py @@ -21,7 +21,7 @@ from library import anima_models # Keys that should stay in high precision (float32/bfloat16, not quantized) -KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer'] +KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"] def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]: @@ -56,6 +56,7 @@ def load_anima_dit( logger.info(f"Loading Anima DiT from {dit_path}") if disable_mmap: from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap + state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True) else: state_dict = load_file(dit_path, device="cpu") @@ -63,8 +64,8 @@ def load_anima_dit( # Remove 'net.' prefix if present new_state_dict = {} for k, v in state_dict.items(): - if k.startswith('net.'): - k = k[len('net.'):] + if k.startswith("net."): + k = k[len("net.") :] new_state_dict[k] = v state_dict = new_state_dict @@ -74,18 +75,20 @@ def load_anima_dit( # Detect LLM adapter if llm_adapter_path is not None: use_llm_adapter = True - dit_config['use_llm_adapter'] = True + dit_config["use_llm_adapter"] = True llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu") - elif 'llm_adapter.out_proj.weight' in state_dict: + elif "llm_adapter.out_proj.weight" in state_dict: use_llm_adapter = True - dit_config['use_llm_adapter'] = True + dit_config["use_llm_adapter"] = True llm_adapter_state_dict = None # Loaded as part of DiT else: use_llm_adapter = False llm_adapter_state_dict = None - logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, " - f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}") + logger.info( + f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, " + f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}" + ) # Build model normally on CPU — buffers get proper values from __init__ dit = anima_models.MiniTrainDIT(**dit_config) @@ -99,9 +102,11 @@ def load_anima_dit( missing, unexpected = dit.load_state_dict(state_dict, strict=False) if missing: # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint) - unexpected_missing = [k for k in missing if not any( - buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq') - )] + unexpected_missing = [ + k + for k in missing + if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq")) + ] if unexpected_missing: logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}") if unexpected: @@ -109,9 +114,7 @@ def load_anima_dit( # Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest) for name, p in dit.named_parameters(): - dtype_to_use = dtype if ( - any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1 - ) else transformer_dtype + dtype_to_use = dtype if (any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1) else transformer_dtype p.data = p.data.to(dtype=dtype_to_use) dit.to(device) @@ -156,7 +159,38 @@ def load_anima_model( loading_device = torch.device(loading_device) # We currently support fixed DiT config for Anima models - dit_config={'max_img_h': 512, 'max_img_w': 512, 'max_frames': 128, 'in_channels': 16, 'out_channels': 16, 'patch_spatial': 2, 'patch_temporal': 1, 'model_channels': 2048, 'concat_padding_mask': True, 'crossattn_emb_channels': 1024, 'pos_emb_cls': 'rope3d', 'pos_emb_learnable': True, 'pos_emb_interpolation': 'crop', 'min_fps': 1, 'max_fps': 30, 'use_adaln_lora': True, 'adaln_lora_dim': 256, 'num_blocks': 28, 'num_heads': 16, 'extra_per_block_abs_pos_emb': False, 'rope_h_extrapolation_ratio': 4.0, 'rope_w_extrapolation_ratio': 4.0, 'rope_t_extrapolation_ratio': 1.0, 'extra_h_extrapolation_ratio': 1.0, 'extra_w_extrapolation_ratio': 1.0, 'extra_t_extrapolation_ratio': 1.0, 'rope_enable_fps_modulation': False, 'use_llm_adapter': True, 'attn_mode': attn_mode, 'split_attn': split_attn} + dit_config = { + "max_img_h": 512, + "max_img_w": 512, + "max_frames": 128, + "in_channels": 16, + "out_channels": 16, + "patch_spatial": 2, + "patch_temporal": 1, + "model_channels": 2048, + "concat_padding_mask": True, + "crossattn_emb_channels": 1024, + "pos_emb_cls": "rope3d", + "pos_emb_learnable": True, + "pos_emb_interpolation": "crop", + "min_fps": 1, + "max_fps": 30, + "use_adaln_lora": True, + "adaln_lora_dim": 256, + "num_blocks": 28, + "num_heads": 16, + "extra_per_block_abs_pos_emb": False, + "rope_h_extrapolation_ratio": 4.0, + "rope_w_extrapolation_ratio": 4.0, + "rope_t_extrapolation_ratio": 1.0, + "extra_h_extrapolation_ratio": 1.0, + "extra_w_extrapolation_ratio": 1.0, + "extra_t_extrapolation_ratio": 1.0, + "rope_enable_fps_modulation": False, + "use_llm_adapter": True, + "attn_mode": attn_mode, + "split_attn": split_attn, + } # model = create_model(attn_mode, split_attn, dit_weight_dtype) with init_empty_weights(): model = anima_models.Anima(dit_config) @@ -190,12 +224,16 @@ def load_anima_model( missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) if missing: # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint) - unexpected_missing = [k for k in missing if not any( - buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq') - )] + unexpected_missing = [ + k + for k in missing + if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq")) + ] if unexpected_missing: # Raise error to avoid silent failures - raise RuntimeError(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}") + raise RuntimeError( + f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}" + ) missing = {} # all missing keys were expected if unexpected: # Raise error to avoid silent failures @@ -205,7 +243,6 @@ def load_anima_model( return model - def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"): """Load WanVAE from a safetensors/pth file. @@ -229,14 +266,14 @@ def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: st from library.anima_vae import WanVAE_ # Build model - with torch.device('meta'): + with torch.device("meta"): vae = WanVAE_(**vae_config) # Load state dict - if vae_path.endswith('.safetensors'): - vae_sd = load_file(vae_path, device='cpu') + if vae_path.endswith(".safetensors"): + vae_sd = load_file(vae_path, device="cpu") else: - vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True) + vae_sd = torch.load(vae_path, map_location="cpu", weights_only=True) vae.load_state_dict(vae_sd, assign=True) vae = vae.eval().requires_grad_(False).to(device, dtype=dtype) @@ -265,7 +302,7 @@ def load_qwen3_tokenizer(qwen3_path: str): if os.path.isdir(qwen3_path): tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) else: - config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b') + config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b") if not os.path.exists(config_dir): raise FileNotFoundError( f"Qwen3 config directory not found at {config_dir}. " @@ -299,12 +336,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16 if os.path.isdir(qwen3_path): # Directory with full model tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) - model = transformers.AutoModelForCausalLM.from_pretrained( - qwen3_path, torch_dtype=dtype, local_files_only=True - ).model + model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model else: # Single safetensors file - use configs/qwen3_06b/ for config - config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b') + config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b") if not os.path.exists(config_dir): raise FileNotFoundError( f"Qwen3 config directory not found at {config_dir}. " @@ -317,16 +352,16 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16 model = transformers.Qwen3ForCausalLM(qwen3_config).model # Load weights - if qwen3_path.endswith('.safetensors'): - state_dict = load_file(qwen3_path, device='cpu') + if qwen3_path.endswith(".safetensors"): + state_dict = load_file(qwen3_path, device="cpu") else: - state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True) + state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True) # Remove 'model.' prefix if present new_sd = {} for k, v in state_dict.items(): - if k.startswith('model.'): - new_sd[k[len('model.'):]] = v + if k.startswith("model."): + new_sd[k[len("model.") :]] = v else: new_sd[k] = v @@ -355,11 +390,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None): return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True) # Use bundled config - config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old') + config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old") if os.path.exists(config_dir): return T5TokenizerFast( - vocab_file=os.path.join(config_dir, 'spiece.model'), - tokenizer_file=os.path.join(config_dir, 'tokenizer.json'), + vocab_file=os.path.join(config_dir, "spiece.model"), + tokenizer_file=os.path.join(config_dir, "tokenizer.json"), ) raise FileNotFoundError( @@ -381,9 +416,9 @@ def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dt for k, v in dit_state_dict.items(): if dtype is not None: v = v.to(dtype) - prefixed_sd['net.' + k] = v.contiguous() + prefixed_sd["net." + k] = v.contiguous() - save_file(prefixed_sd, save_path, metadata={'format': 'pt'}) + save_file(prefixed_sd, save_path, metadata={"format": "pt"}) logger.info(f"Saved Anima model to {save_path}") diff --git a/library/anima_vae.py b/library/anima_vae.py index 872bdfa2..3f6c7d1b 100644 --- a/library/anima_vae.py +++ b/library/anima_vae.py @@ -16,8 +16,7 @@ class CausalConv3d(nn.Conv3d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) def forward(self, x, cache_x=None): @@ -41,12 +40,10 @@ class RMS_norm(nn.Module): self.channel_first = channel_first self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias class Upsample(nn.Upsample): @@ -61,65 +58,48 @@ class Upsample(nn.Upsample): class Resample(nn.Module): def __init__(self, dim, mode): - assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', - 'downsample3d') + assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") super().__init__() self.dim = dim self.mode = mode # layers - if mode == 'upsample2d': + if mode == "upsample2d": self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - elif mode == 'upsample3d': + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - self.time_conv = CausalConv3d( - dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - elif mode == 'downsample2d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == 'downsample3d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d( - dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: self.resample = nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): b, c, t, h, w = x.size() - if self.mode == 'upsample3d': + if self.mode == "upsample3d": if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = 'Rep' + feat_cache[idx] = "Rep" feat_idx[0] += 1 else: cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) - if feat_cache[idx] == 'Rep': + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": x = self.time_conv(x) else: x = self.time_conv(x, feat_cache[idx]) @@ -127,15 +107,14 @@ class Resample(nn.Module): feat_idx[0] += 1 x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), - 3) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) x = x.reshape(b, c, t * 2, h, w) t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.resample(x) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - if self.mode == 'downsample3d': + if self.mode == "downsample3d": if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: @@ -144,8 +123,7 @@ class Resample(nn.Module): else: cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv( - torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x feat_idx[0] += 1 return x @@ -166,8 +144,8 @@ class Resample(nn.Module): nn.init.zeros_(conv_weight) c1, c2, t, h, w = conv_weight.size() init_matrix = torch.eye(c1 // 2, c2) - conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix conv.weight.data.copy_(conv_weight) nn.init.zeros_(conv.bias.data) @@ -181,12 +159,15 @@ class ResidualBlock(nn.Module): # layers self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), nn.SiLU(), + RMS_norm(in_dim, images=False), + nn.SiLU(), CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1)) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ - if in_dim != out_dim else nn.Identity() + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): h = self.shortcut(x) @@ -196,11 +177,7 @@ class ResidualBlock(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -229,13 +206,10 @@ class AttentionBlock(nn.Module): def forward(self, x): identity = x b, c, t, h, w = x.size() - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.norm(x) # compute query, key, value - q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, - -1).permute(0, 1, 3, - 2).contiguous().chunk( - 3, dim=-1) + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) # apply attention x = F.scaled_dot_product_attention( @@ -247,20 +221,22 @@ class AttentionBlock(nn.Module): # output x = self.proj(x) - x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) return x + identity class Encoder3d(nn.Module): - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -288,21 +264,18 @@ class Encoder3d(nn.Module): # downsample block if i != len(dim_mult) - 1: - mode = 'downsample3d' if temperal_downsample[ - i] else 'downsample2d' + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" downsamples.append(Resample(out_dim, mode=mode)) scale /= 2.0 self.downsamples = nn.Sequential(*downsamples) # middle blocks self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), - ResidualBlock(out_dim, out_dim, dropout)) + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout) + ) # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, z_dim, 3, padding=1)) + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): if feat_cache is not None: @@ -310,11 +283,7 @@ class Encoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -342,11 +311,7 @@ class Encoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -357,14 +322,16 @@ class Encoder3d(nn.Module): class Decoder3d(nn.Module): - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -375,15 +342,15 @@ class Decoder3d(nn.Module): # dimensions dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2**(len(dim_mult) - 2) + scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) # middle blocks self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout)) + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout) + ) # upsample blocks upsamples = [] @@ -399,15 +366,13 @@ class Decoder3d(nn.Module): # upsample block if i != len(dim_mult) - 1: - mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + mode = "upsample3d" if temperal_upsample[i] else "upsample2d" upsamples.append(Resample(out_dim, mode=mode)) scale *= 2.0 self.upsamples = nn.Sequential(*upsamples) # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 @@ -416,11 +381,7 @@ class Decoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -448,11 +409,7 @@ class Decoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -471,14 +428,16 @@ def count_conv3d(model): class WanVAE_(nn.Module): - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -489,12 +448,10 @@ class WanVAE_(nn.Module): self.temperal_upsample = temperal_downsample[::-1] # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, - attn_scales, self.temperal_downsample, dropout) + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, - attn_scales, self.temperal_upsample, dropout) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def forward(self, x): mu, log_var = self.encode(x) @@ -510,20 +467,15 @@ class WanVAE_(nn.Module): for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder( - x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx + ) out = torch.cat([out, out_], 2) mu, log_var = self.conv1(out).chunk(2, dim=1) if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) else: mu = (mu - scale[0]) * scale[1] self.clear_cache() @@ -533,8 +485,7 @@ class WanVAE_(nn.Module): self.clear_cache() # z: [b,c,t,h,w] if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( - 1, self.z_dim, 1, 1, 1) + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) else: z = z / scale[1] + scale[0] iter_ = z.shape[2] @@ -542,15 +493,9 @@ class WanVAE_(nn.Module): for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: - out_ = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) self.clear_cache() return out @@ -571,7 +516,7 @@ class WanVAE_(nn.Module): self._conv_num = count_conv3d(self.decoder) self._conv_idx = [0] self._feat_map = [None] * self._conv_num - #cache encode + # cache encode self._enc_conv_num = count_conv3d(self.encoder) self._enc_conv_idx = [0] self._enc_feat_map = [None] * self._enc_conv_num diff --git a/library/strategy_anima.py b/library/strategy_anima.py index 398193ba..bee315f3 100644 --- a/library/strategy_anima.py +++ b/library/strategy_anima.py @@ -92,9 +92,9 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): # Cached unconditional embeddings (from encoding empty caption "") # Must be initialized via cache_uncond_embeddings() before text encoder is deleted self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden) - self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len) - self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len) - self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len) + self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len) + self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len) + self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len) def cache_uncond_embeddings( self, @@ -182,8 +182,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): ) seq_len = qwen3_input_ids.shape[1] - hidden_size = (nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]) - dtype = (nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype) + hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1] + dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype) attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype) @@ -203,7 +203,6 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] - def drop_cached_text_encoder_outputs( self, prompt_embeds: torch.Tensor, @@ -367,18 +366,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy): return self.ANIMA_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + self.ANIMA_LATENTS_NPZ_SUFFIX - ) + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX - def is_disk_cached_latents_expected( - self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool - ): - return self._default_is_disk_cached_latents_expected( - 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True - ) + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int]