From df6b1bda335acb08321e7a084d55c54a7d983f8b Mon Sep 17 00:00:00 2001
From: Duoong
Date: Fri, 6 Feb 2026 17:49:09 +0700
Subject: [PATCH] Remove section separator comments and comment out pre-DDP CPU offload

---
 anima_train.py | 43 +++++--------------------------------------
 1 file changed, 5 insertions(+), 38 deletions(-)

diff --git a/anima_train.py b/anima_train.py
index 10fa3fbf..9661cd90 100644
--- a/anima_train.py
+++ b/anima_train.py
@@ -1,5 +1,4 @@
 # Anima full finetune training script
-# Reference pattern: sd3_train.py
 
 import argparse
 from concurrent.futures import ThreadPoolExecutor
@@ -203,9 +202,7 @@ def train(args):
     }
     transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
 
-    # ========================================
     # Load tokenizers and set strategies
-    # ========================================
     logger.info("Loading tokenizers...")
     qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
         args.qwen3_path, dtype=weight_dtype, device="cpu"
@@ -231,15 +228,11 @@ def train(args):
     )
     strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
 
-    # ========================================
     # Prepare text encoder (always frozen for Anima)
-    # ========================================
     qwen3_text_encoder.to(weight_dtype)
     qwen3_text_encoder.requires_grad_(False)
 
-    # ========================================
     # Cache text encoder outputs
-    # ========================================
     sample_prompts_te_outputs = None
     if args.cache_text_encoder_outputs:
         qwen3_text_encoder.to(accelerator.device)
@@ -281,9 +274,7 @@ def train(args):
         qwen3_text_encoder = None
         clean_memory_on_device(accelerator.device)
 
-    # ========================================
     # Load VAE and cache latents
-    # ========================================
     logger.info("Loading Anima VAE...")
     vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
 
@@ -298,9 +289,7 @@ def train(args):
         clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()
 
-    # ========================================
     # Load DiT (MiniTrainDIT + optional LLM Adapter)
-    # ========================================
     logger.info("Loading Anima DiT...")
     dit = anima_utils.load_anima_dit(
         args.dit_path,
@@ -325,9 +314,7 @@ def train(args):
     if not train_dit:
         dit.to(accelerator.device, dtype=weight_dtype)
 
-    # ========================================
     # Block swap
-    # ========================================
     is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
     if is_swapping_blocks:
         logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
@@ -340,9 +327,7 @@ def train(args):
         # Move scale tensors to same device as VAE for on-the-fly encoding
         vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
 
-    # ========================================
     # Setup optimizer with parameter groups
-    # ========================================
     if train_dit:
         param_groups = anima_train_utils.get_anima_param_groups(
             dit,
@@ -476,15 +461,13 @@ def train(args):
 
     clean_memory_on_device(accelerator.device)
 
-    # ========================================
     # Prepare with accelerator
-    # ========================================
     # Temporarily move non-training models off GPU to reduce memory during DDP init
-    if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
-        qwen3_text_encoder.to("cpu")
-    if not cache_latents and vae is not None:
-        vae.to("cpu")
-    clean_memory_on_device(accelerator.device)
+    # if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
+    #     qwen3_text_encoder.to("cpu")
+    # if not cache_latents and vae is not None:
+    #     vae.to("cpu")
+    # clean_memory_on_device(accelerator.device)
 
     if args.deepspeed:
         ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
@@ -561,9 +544,7 @@ def train(args):
                         parameter_optimizer_map[parameter] = opt_idx
                         num_parameters_per_group[opt_idx] += 1
 
-    # ========================================
     # Training loop
-    # ========================================
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
@@ -633,9 +614,7 @@ def train(args):
                 optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
 
             with accelerator.accumulate(*training_models):
-                # ==============================
                 # Get latents
-                # ==============================
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
                 else:
@@ -649,9 +628,7 @@ def train(args):
                     accelerator.print("NaN found in latents, replacing with zeros")
                     latents = torch.nan_to_num(latents, 0, out=latents)
 
-                # ==============================
                 # Get text encoder outputs
-                # ==============================
                 text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                 if text_encoder_outputs_list is not None:
                     # Cached outputs
@@ -676,9 +653,7 @@ def train(args):
                     t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
                     t5_attn_mask = t5_attn_mask.to(accelerator.device)
 
-                # ==============================
                 # Noise and timesteps
-                # ==============================
                 noise = torch.randn_like(latents)
 
                 noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
@@ -690,9 +665,7 @@ def train(args):
                     accelerator.print("NaN found in noisy_model_input, replacing with zeros")
                     noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
 
-                # ==============================
                 # Create padding mask
-                # ==============================
                 # padding_mask: (B, 1, H_latent, W_latent)
                 bs = latents.shape[0]
                 h_latent = latents.shape[-2]
@@ -702,9 +675,7 @@ def train(args):
                     dtype=weight_dtype, device=accelerator.device
                 )
 
-                # ==============================
                 # DiT forward (LLM adapter runs inside forward for DDP gradient sync)
-                # ==============================
                 if is_swapping_blocks:
                     accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
 
@@ -719,9 +690,7 @@ def train(args):
                     t5_attn_mask=t5_attn_mask,
                 )
 
-                # ==============================
                 # Compute loss (rectified flow: target = noise - latents)
-                # ==============================
                 target = noise - latents
 
                 # Weighting
@@ -835,9 +804,7 @@ def train(args):
                         sample_prompts_te_outputs,
                     )
 
-    # ========================================
     # End training
-    # ========================================
     is_main_process = accelerator.is_main_process
     dit = accelerator.unwrap_model(dit)