This commit is contained in:
Duoong
2026-02-06 17:49:09 +07:00
parent 0c636e8a6f
commit df6b1bda33

View File

@@ -1,5 +1,4 @@
# Anima full finetune training script # Anima full finetune training script
# Reference pattern: sd3_train.py
import argparse import argparse
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@@ -203,9 +202,7 @@ def train(args):
} }
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None) transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
# ========================================
# Load tokenizers and set strategies # Load tokenizers and set strategies
# ========================================
logger.info("Loading tokenizers...") logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder( qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu" args.qwen3_path, dtype=weight_dtype, device="cpu"
@@ -231,15 +228,11 @@ def train(args):
) )
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy) strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# ========================================
# Prepare text encoder (always frozen for Anima) # Prepare text encoder (always frozen for Anima)
# ========================================
qwen3_text_encoder.to(weight_dtype) qwen3_text_encoder.to(weight_dtype)
qwen3_text_encoder.requires_grad_(False) qwen3_text_encoder.requires_grad_(False)
# ========================================
# Cache text encoder outputs # Cache text encoder outputs
# ========================================
sample_prompts_te_outputs = None sample_prompts_te_outputs = None
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
qwen3_text_encoder.to(accelerator.device) qwen3_text_encoder.to(accelerator.device)
@@ -281,9 +274,7 @@ def train(args):
qwen3_text_encoder = None qwen3_text_encoder = None
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
# ========================================
# Load VAE and cache latents # Load VAE and cache latents
# ========================================
logger.info("Loading Anima VAE...") logger.info("Loading Anima VAE...")
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu") vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
@@ -298,9 +289,7 @@ def train(args):
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# ========================================
# Load DiT (MiniTrainDIT + optional LLM Adapter) # Load DiT (MiniTrainDIT + optional LLM Adapter)
# ========================================
logger.info("Loading Anima DiT...") logger.info("Loading Anima DiT...")
dit = anima_utils.load_anima_dit( dit = anima_utils.load_anima_dit(
args.dit_path, args.dit_path,
@@ -325,9 +314,7 @@ def train(args):
if not train_dit: if not train_dit:
dit.to(accelerator.device, dtype=weight_dtype) dit.to(accelerator.device, dtype=weight_dtype)
# ========================================
# Block swap # Block swap
# ========================================
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks: if is_swapping_blocks:
logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}") logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
@@ -340,9 +327,7 @@ def train(args):
# Move scale tensors to same device as VAE for on-the-fly encoding # Move scale tensors to same device as VAE for on-the-fly encoding
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale] vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
# ========================================
# Setup optimizer with parameter groups # Setup optimizer with parameter groups
# ========================================
if train_dit: if train_dit:
param_groups = anima_train_utils.get_anima_param_groups( param_groups = anima_train_utils.get_anima_param_groups(
dit, dit,
@@ -476,15 +461,13 @@ def train(args):
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
# ========================================
# Prepare with accelerator # Prepare with accelerator
# ========================================
# Temporarily move non-training models off GPU to reduce memory during DDP init # Temporarily move non-training models off GPU to reduce memory during DDP init
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None: # if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
qwen3_text_encoder.to("cpu") # qwen3_text_encoder.to("cpu")
if not cache_latents and vae is not None: # if not cache_latents and vae is not None:
vae.to("cpu") # vae.to("cpu")
clean_memory_on_device(accelerator.device) # clean_memory_on_device(accelerator.device)
if args.deepspeed: if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit) ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
@@ -561,9 +544,7 @@ def train(args):
parameter_optimizer_map[parameter] = opt_idx parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1 num_parameters_per_group[opt_idx] += 1
# ========================================
# Training loop # Training loop
# ========================================
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
@@ -633,9 +614,7 @@ def train(args):
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models): with accelerator.accumulate(*training_models):
# ==============================
# Get latents # Get latents
# ==============================
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype) latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else: else:
@@ -649,9 +628,7 @@ def train(args):
accelerator.print("NaN found in latents, replacing with zeros") accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents) latents = torch.nan_to_num(latents, 0, out=latents)
# ==============================
# Get text encoder outputs # Get text encoder outputs
# ==============================
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
# Cached outputs # Cached outputs
@@ -676,9 +653,7 @@ def train(args):
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long) t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device) t5_attn_mask = t5_attn_mask.to(accelerator.device)
# ==============================
# Noise and timesteps # Noise and timesteps
# ==============================
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps( noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
@@ -690,9 +665,7 @@ def train(args):
accelerator.print("NaN found in noisy_model_input, replacing with zeros") accelerator.print("NaN found in noisy_model_input, replacing with zeros")
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input) noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
# ==============================
# Create padding mask # Create padding mask
# ==============================
# padding_mask: (B, 1, H_latent, W_latent) # padding_mask: (B, 1, H_latent, W_latent)
bs = latents.shape[0] bs = latents.shape[0]
h_latent = latents.shape[-2] h_latent = latents.shape[-2]
@@ -702,9 +675,7 @@ def train(args):
dtype=weight_dtype, device=accelerator.device dtype=weight_dtype, device=accelerator.device
) )
# ==============================
# DiT forward (LLM adapter runs inside forward for DDP gradient sync) # DiT forward (LLM adapter runs inside forward for DDP gradient sync)
# ==============================
if is_swapping_blocks: if is_swapping_blocks:
accelerator.unwrap_model(dit).prepare_block_swap_before_forward() accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
@@ -719,9 +690,7 @@ def train(args):
t5_attn_mask=t5_attn_mask, t5_attn_mask=t5_attn_mask,
) )
# ==============================
# Compute loss (rectified flow: target = noise - latents) # Compute loss (rectified flow: target = noise - latents)
# ==============================
target = noise - latents target = noise - latents
# Weighting # Weighting
@@ -835,9 +804,7 @@ def train(args):
sample_prompts_te_outputs, sample_prompts_te_outputs,
) )
# ========================================
# End training # End training
# ========================================
is_main_process = accelerator.is_main_process is_main_process = accelerator.is_main_process
dit = accelerator.unwrap_model(dit) dit = accelerator.unwrap_model(dit)