diff --git a/anima_train.py b/anima_train.py index bce9c77e..10fa3fbf 100644 --- a/anima_train.py +++ b/anima_train.py @@ -1,4 +1,5 @@ # Anima full finetune training script +# Reference pattern: sd3_train.py import argparse from concurrent.futures import ThreadPoolExecutor @@ -202,7 +203,9 @@ def train(args): } transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None) + # ======================================== # Load tokenizers and set strategies + # ======================================== logger.info("Loading tokenizers...") qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder( args.qwen3_path, dtype=weight_dtype, device="cpu" @@ -228,11 +231,15 @@ def train(args): ) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) + # ======================================== # Prepare text encoder (always frozen for Anima) + # ======================================== qwen3_text_encoder.to(weight_dtype) qwen3_text_encoder.requires_grad_(False) + # ======================================== # Cache text encoder outputs + # ======================================== sample_prompts_te_outputs = None if args.cache_text_encoder_outputs: qwen3_text_encoder.to(accelerator.device) @@ -274,7 +281,9 @@ def train(args): qwen3_text_encoder = None clean_memory_on_device(accelerator.device) + # ======================================== # Load VAE and cache latents + # ======================================== logger.info("Loading Anima VAE...") vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu") @@ -289,7 +298,9 @@ def train(args): clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone() + # ======================================== # Load DiT (MiniTrainDIT + optional LLM Adapter) + # ======================================== logger.info("Loading Anima DiT...") dit = anima_utils.load_anima_dit( args.dit_path, @@ -314,7 +325,9 @@ def train(args): if not train_dit: dit.to(accelerator.device, dtype=weight_dtype) + # ======================================== # Block swap + # ======================================== is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 if is_swapping_blocks: logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}") @@ -327,7 +340,9 @@ def train(args): # Move scale tensors to same device as VAE for on-the-fly encoding vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale] + # ======================================== # Setup optimizer with parameter groups + # ======================================== if train_dit: param_groups = anima_train_utils.get_anima_param_groups( dit, @@ -461,20 +476,16 @@ def train(args): clean_memory_on_device(accelerator.device) + # ======================================== # Prepare with accelerator - # Diagnostic: check for meta-device parameters/buffers that would cause DDP to hang - if train_dit: - meta_params = [n for n, p in dit.named_parameters() if p.device == torch.device('meta')] - meta_buffers = [n for n, b in dit.named_buffers() if b.device == torch.device('meta')] - if meta_params: - logger.error(f"[rank {accelerator.process_index}] FATAL: {len(meta_params)} parameters on meta device: {meta_params[:10]}") - if meta_buffers: - logger.error(f"[rank {accelerator.process_index}] FATAL: {len(meta_buffers)} buffers on meta device: {meta_buffers[:10]}") - n_params_total = sum(p.numel() for p in dit.parameters()) - n_buffers_total = sum(b.numel() for b in dit.buffers()) - logger.info(f"[rank {accelerator.process_index}] dit: {n_params_total:,} params, {n_buffers_total:,} buffer elements, device={next(dit.parameters()).device}") + # ======================================== + # Temporarily move non-training models off GPU to reduce memory during DDP init + if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None: + qwen3_text_encoder.to("cpu") + if not cache_latents and vae is not None: + vae.to("cpu") + clean_memory_on_device(accelerator.device) - logger.info(f"[rank {accelerator.process_index}] entering accelerator.prepare()") if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit) ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( @@ -483,14 +494,16 @@ def train(args): training_models = [ds_model] else: if train_dit: - logger.info(f"[rank {accelerator.process_index}] preparing dit with DDP...") dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks]) - logger.info(f"[rank {accelerator.process_index}] dit prepared.") if is_swapping_blocks: accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device) - logger.info(f"[rank {accelerator.process_index}] preparing optimizer, dataloader, lr_scheduler...") optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) - logger.info(f"[rank {accelerator.process_index}] all prepared.") + + # Move non-training models back to GPU + if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None: + qwen3_text_encoder.to(accelerator.device) + if not cache_latents and vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) if args.full_fp16: train_util.patch_accelerator_for_fp16_training(accelerator) @@ -548,8 +561,9 @@ def train(args): parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 - + # ======================================== # Training loop + # ======================================== num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): @@ -619,14 +633,15 @@ def train(args): optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step with accelerator.accumulate(*training_models): + # ============================== # Get latents + # ============================== if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) else: with torch.no_grad(): - # images are [0, 1], need [-1, 1] and add temporal dim + # images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim images = batch["images"].to(accelerator.device, dtype=weight_dtype) - images = images * 2.0 - 1.0 images = images.unsqueeze(2) # (B, C, 1, H, W) latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype) @@ -634,7 +649,9 @@ def train(args): accelerator.print("NaN found in latents, replacing with zeros") latents = torch.nan_to_num(latents, 0, out=latents) + # ============================== # Get text encoder outputs + # ============================== text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Cached outputs @@ -659,7 +676,9 @@ def train(args): t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long) t5_attn_mask = t5_attn_mask.to(accelerator.device) + # ============================== # Noise and timesteps + # ============================== noise = torch.randn_like(latents) noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps( @@ -671,7 +690,9 @@ def train(args): accelerator.print("NaN found in noisy_model_input, replacing with zeros") noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) + # ============================== # Create padding mask + # ============================== # padding_mask: (B, 1, H_latent, W_latent) bs = latents.shape[0] h_latent = latents.shape[-2] @@ -681,7 +702,9 @@ def train(args): dtype=weight_dtype, device=accelerator.device ) + # ============================== # DiT forward (LLM adapter runs inside forward for DDP gradient sync) + # ============================== if is_swapping_blocks: accelerator.unwrap_model(dit).prepare_block_swap_before_forward() @@ -696,7 +719,9 @@ def train(args): t5_attn_mask=t5_attn_mask, ) + # ============================== # Compute loss (rectified flow: target = noise - latents) + # ============================== target = noise - latents # Weighting @@ -810,7 +835,9 @@ def train(args): sample_prompts_te_outputs, ) + # ======================================== # End training + # ======================================== is_main_process = accelerator.is_main_process dit = accelerator.unwrap_model(dit) diff --git a/anima_train_network.py b/anima_train_network.py index 5d532ae0..e4fb46e3 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -267,8 +267,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): return noise_scheduler def encode_images_to_latents(self, args, vae, images): - # images are [0,1], need [-1,1] and temporal dim - images = images * 2.0 - 1.0 + # images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim images = images.unsqueeze(2) # (B, C, 1, H, W) # Ensure scale tensors are on the same device as images vae_device = images.device diff --git a/library/strategy_anima.py b/library/strategy_anima.py index f93adf43..a545f9e9 100644 --- a/library/strategy_anima.py +++ b/library/strategy_anima.py @@ -162,6 +162,15 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): prompt_embeds[non_drop_indices] = nd_encoded_text attn_mask[non_drop_indices] = nd_attn_mask + # Zero out t5_input_ids and t5_attn_mask for dropped items + # so the LLM adapter sees a consistent unconditional signal + t5_input_ids = t5_input_ids.clone() + t5_attn_mask = t5_attn_mask.clone() + drop_indices = [i for i in range(batch_size) if i not in non_drop_indices] + for i in drop_indices: + t5_input_ids[i] = torch.zeros_like(t5_input_ids[i]) + t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] def drop_cached_text_encoder_outputs( @@ -181,6 +190,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): prompt_embeds = prompt_embeds.clone() if attn_mask is not None: attn_mask = attn_mask.clone() + if t5_input_ids is not None: + t5_input_ids = t5_input_ids.clone() if t5_attn_mask is not None: t5_attn_mask = t5_attn_mask.clone() @@ -189,6 +200,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): prompt_embeds[i] = torch.zeros_like(prompt_embeds[i]) if attn_mask is not None: attn_mask[i] = torch.zeros_like(attn_mask[i]) + if t5_input_ids is not None: + t5_input_ids[i] = torch.zeros_like(t5_input_ids[i]) if t5_attn_mask is not None: t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) @@ -350,11 +363,9 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy): def encode_by_vae(img_tensor): """Encode image tensor to latents. - img_tensor: (B, C, H, W) in [0, 1] range - Need to convert to (B, C, T=1, H, W) in [-1, 1] range for WanVAE + img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS) + Need to add temporal dim to get (B, C, T=1, H, W) for WanVAE """ - # Convert [0, 1] -> [-1, 1] - img_tensor = img_tensor * 2.0 - 1.0 # Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W) img_tensor = img_tensor.unsqueeze(2) img_tensor = img_tensor.to(vae_device, dtype=vae_dtype) diff --git a/library/train_util.py b/library/train_util.py index a1900609..6874076d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6138,7 +6138,8 @@ def conditional_loss( elif loss_type == "huber": if huber_c is None: raise NotImplementedError("huber_c not implemented correctly") - huber_c = huber_c.view(-1, 1, 1, 1) + # Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors) + huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1))) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) @@ -6147,7 +6148,8 @@ def conditional_loss( elif loss_type == "smooth_l1": if huber_c is None: raise NotImplementedError("huber_c not implemented correctly") - huber_c = huber_c.view(-1, 1, 1, 1) + # Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors) + huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1))) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss)