mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Fix latent normlization
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# Anima full finetune training script
|
||||
# Reference pattern: sd3_train.py
|
||||
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@@ -202,7 +203,9 @@ def train(args):
|
||||
}
|
||||
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
||||
|
||||
# ========================================
|
||||
# Load tokenizers and set strategies
|
||||
# ========================================
|
||||
logger.info("Loading tokenizers...")
|
||||
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
|
||||
args.qwen3_path, dtype=weight_dtype, device="cpu"
|
||||
@@ -228,11 +231,15 @@ def train(args):
|
||||
)
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
# ========================================
|
||||
# Prepare text encoder (always frozen for Anima)
|
||||
# ========================================
|
||||
qwen3_text_encoder.to(weight_dtype)
|
||||
qwen3_text_encoder.requires_grad_(False)
|
||||
|
||||
# ========================================
|
||||
# Cache text encoder outputs
|
||||
# ========================================
|
||||
sample_prompts_te_outputs = None
|
||||
if args.cache_text_encoder_outputs:
|
||||
qwen3_text_encoder.to(accelerator.device)
|
||||
@@ -274,7 +281,9 @@ def train(args):
|
||||
qwen3_text_encoder = None
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# ========================================
|
||||
# Load VAE and cache latents
|
||||
# ========================================
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
|
||||
|
||||
@@ -289,7 +298,9 @@ def train(args):
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# ========================================
|
||||
# Load DiT (MiniTrainDIT + optional LLM Adapter)
|
||||
# ========================================
|
||||
logger.info("Loading Anima DiT...")
|
||||
dit = anima_utils.load_anima_dit(
|
||||
args.dit_path,
|
||||
@@ -314,7 +325,9 @@ def train(args):
|
||||
if not train_dit:
|
||||
dit.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# ========================================
|
||||
# Block swap
|
||||
# ========================================
|
||||
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if is_swapping_blocks:
|
||||
logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
@@ -327,7 +340,9 @@ def train(args):
|
||||
# Move scale tensors to same device as VAE for on-the-fly encoding
|
||||
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
|
||||
|
||||
# ========================================
|
||||
# Setup optimizer with parameter groups
|
||||
# ========================================
|
||||
if train_dit:
|
||||
param_groups = anima_train_utils.get_anima_param_groups(
|
||||
dit,
|
||||
@@ -461,20 +476,16 @@ def train(args):
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# ========================================
|
||||
# Prepare with accelerator
|
||||
# Diagnostic: check for meta-device parameters/buffers that would cause DDP to hang
|
||||
if train_dit:
|
||||
meta_params = [n for n, p in dit.named_parameters() if p.device == torch.device('meta')]
|
||||
meta_buffers = [n for n, b in dit.named_buffers() if b.device == torch.device('meta')]
|
||||
if meta_params:
|
||||
logger.error(f"[rank {accelerator.process_index}] FATAL: {len(meta_params)} parameters on meta device: {meta_params[:10]}")
|
||||
if meta_buffers:
|
||||
logger.error(f"[rank {accelerator.process_index}] FATAL: {len(meta_buffers)} buffers on meta device: {meta_buffers[:10]}")
|
||||
n_params_total = sum(p.numel() for p in dit.parameters())
|
||||
n_buffers_total = sum(b.numel() for b in dit.buffers())
|
||||
logger.info(f"[rank {accelerator.process_index}] dit: {n_params_total:,} params, {n_buffers_total:,} buffer elements, device={next(dit.parameters()).device}")
|
||||
# ========================================
|
||||
# Temporarily move non-training models off GPU to reduce memory during DDP init
|
||||
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||
qwen3_text_encoder.to("cpu")
|
||||
if not cache_latents and vae is not None:
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
logger.info(f"[rank {accelerator.process_index}] entering accelerator.prepare()")
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -483,14 +494,16 @@ def train(args):
|
||||
training_models = [ds_model]
|
||||
else:
|
||||
if train_dit:
|
||||
logger.info(f"[rank {accelerator.process_index}] preparing dit with DDP...")
|
||||
dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks])
|
||||
logger.info(f"[rank {accelerator.process_index}] dit prepared.")
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
|
||||
logger.info(f"[rank {accelerator.process_index}] preparing optimizer, dataloader, lr_scheduler...")
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
logger.info(f"[rank {accelerator.process_index}] all prepared.")
|
||||
|
||||
# Move non-training models back to GPU
|
||||
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||
qwen3_text_encoder.to(accelerator.device)
|
||||
if not cache_latents and vae is not None:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
@@ -548,8 +561,9 @@ def train(args):
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1
|
||||
|
||||
|
||||
# ========================================
|
||||
# Training loop
|
||||
# ========================================
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
@@ -619,14 +633,15 @@ def train(args):
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
# ==============================
|
||||
# Get latents
|
||||
# ==============================
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# images are [0, 1], need [-1, 1] and add temporal dim
|
||||
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
|
||||
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||
images = images * 2.0 - 1.0
|
||||
images = images.unsqueeze(2) # (B, C, 1, H, W)
|
||||
latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
@@ -634,7 +649,9 @@ def train(args):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
# ==============================
|
||||
# Get text encoder outputs
|
||||
# ==============================
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
# Cached outputs
|
||||
@@ -659,7 +676,9 @@ def train(args):
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# ==============================
|
||||
# Noise and timesteps
|
||||
# ==============================
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
||||
@@ -671,7 +690,9 @@ def train(args):
|
||||
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
|
||||
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
|
||||
|
||||
# ==============================
|
||||
# Create padding mask
|
||||
# ==============================
|
||||
# padding_mask: (B, 1, H_latent, W_latent)
|
||||
bs = latents.shape[0]
|
||||
h_latent = latents.shape[-2]
|
||||
@@ -681,7 +702,9 @@ def train(args):
|
||||
dtype=weight_dtype, device=accelerator.device
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
||||
# ==============================
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||
|
||||
@@ -696,7 +719,9 @@ def train(args):
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# Compute loss (rectified flow: target = noise - latents)
|
||||
# ==============================
|
||||
target = noise - latents
|
||||
|
||||
# Weighting
|
||||
@@ -810,7 +835,9 @@ def train(args):
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
# ========================================
|
||||
# End training
|
||||
# ========================================
|
||||
is_main_process = accelerator.is_main_process
|
||||
dit = accelerator.unwrap_model(dit)
|
||||
|
||||
|
||||
@@ -267,8 +267,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
# images are [0,1], need [-1,1] and temporal dim
|
||||
images = images * 2.0 - 1.0
|
||||
# images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim
|
||||
images = images.unsqueeze(2) # (B, C, 1, H, W)
|
||||
# Ensure scale tensors are on the same device as images
|
||||
vae_device = images.device
|
||||
|
||||
@@ -162,6 +162,15 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
prompt_embeds[non_drop_indices] = nd_encoded_text
|
||||
attn_mask[non_drop_indices] = nd_attn_mask
|
||||
|
||||
# Zero out t5_input_ids and t5_attn_mask for dropped items
|
||||
# so the LLM adapter sees a consistent unconditional signal
|
||||
t5_input_ids = t5_input_ids.clone()
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
|
||||
for i in drop_indices:
|
||||
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
|
||||
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
||||
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
def drop_cached_text_encoder_outputs(
|
||||
@@ -181,6 +190,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
prompt_embeds = prompt_embeds.clone()
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.clone()
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids = t5_input_ids.clone()
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
|
||||
@@ -189,6 +200,8 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
|
||||
if attn_mask is not None:
|
||||
attn_mask[i] = torch.zeros_like(attn_mask[i])
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
||||
|
||||
@@ -350,11 +363,9 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
def encode_by_vae(img_tensor):
|
||||
"""Encode image tensor to latents.
|
||||
|
||||
img_tensor: (B, C, H, W) in [0, 1] range
|
||||
Need to convert to (B, C, T=1, H, W) in [-1, 1] range for WanVAE
|
||||
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
|
||||
Need to add temporal dim to get (B, C, T=1, H, W) for WanVAE
|
||||
"""
|
||||
# Convert [0, 1] -> [-1, 1]
|
||||
img_tensor = img_tensor * 2.0 - 1.0
|
||||
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
|
||||
img_tensor = img_tensor.unsqueeze(2)
|
||||
img_tensor = img_tensor.to(vae_device, dtype=vae_dtype)
|
||||
|
||||
@@ -6138,7 +6138,8 @@ def conditional_loss(
|
||||
elif loss_type == "huber":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
|
||||
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
|
||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
@@ -6147,7 +6148,8 @@ def conditional_loss(
|
||||
elif loss_type == "smooth_l1":
|
||||
if huber_c is None:
|
||||
raise NotImplementedError("huber_c not implemented correctly")
|
||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
||||
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
|
||||
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
|
||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||
if reduction == "mean":
|
||||
loss = torch.mean(loss)
|
||||
|
||||
Reference in New Issue
Block a user