mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Fix typo
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
# Anima full finetune training script
|
# Anima full finetune training script
|
||||||
# Reference pattern: sd3_train.py
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
@@ -203,9 +202,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Load tokenizers and set strategies
|
# Load tokenizers and set strategies
|
||||||
# ========================================
|
|
||||||
logger.info("Loading tokenizers...")
|
logger.info("Loading tokenizers...")
|
||||||
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
|
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
|
||||||
args.qwen3_path, dtype=weight_dtype, device="cpu"
|
args.qwen3_path, dtype=weight_dtype, device="cpu"
|
||||||
@@ -231,15 +228,11 @@ def train(args):
|
|||||||
)
|
)
|
||||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Prepare text encoder (always frozen for Anima)
|
# Prepare text encoder (always frozen for Anima)
|
||||||
# ========================================
|
|
||||||
qwen3_text_encoder.to(weight_dtype)
|
qwen3_text_encoder.to(weight_dtype)
|
||||||
qwen3_text_encoder.requires_grad_(False)
|
qwen3_text_encoder.requires_grad_(False)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Cache text encoder outputs
|
# Cache text encoder outputs
|
||||||
# ========================================
|
|
||||||
sample_prompts_te_outputs = None
|
sample_prompts_te_outputs = None
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
qwen3_text_encoder.to(accelerator.device)
|
qwen3_text_encoder.to(accelerator.device)
|
||||||
@@ -281,9 +274,7 @@ def train(args):
|
|||||||
qwen3_text_encoder = None
|
qwen3_text_encoder = None
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Load VAE and cache latents
|
# Load VAE and cache latents
|
||||||
# ========================================
|
|
||||||
logger.info("Loading Anima VAE...")
|
logger.info("Loading Anima VAE...")
|
||||||
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
|
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
|
||||||
|
|
||||||
@@ -298,9 +289,7 @@ def train(args):
|
|||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Load DiT (MiniTrainDIT + optional LLM Adapter)
|
# Load DiT (MiniTrainDIT + optional LLM Adapter)
|
||||||
# ========================================
|
|
||||||
logger.info("Loading Anima DiT...")
|
logger.info("Loading Anima DiT...")
|
||||||
dit = anima_utils.load_anima_dit(
|
dit = anima_utils.load_anima_dit(
|
||||||
args.dit_path,
|
args.dit_path,
|
||||||
@@ -325,9 +314,7 @@ def train(args):
|
|||||||
if not train_dit:
|
if not train_dit:
|
||||||
dit.to(accelerator.device, dtype=weight_dtype)
|
dit.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Block swap
|
# Block swap
|
||||||
# ========================================
|
|
||||||
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||||
if is_swapping_blocks:
|
if is_swapping_blocks:
|
||||||
logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||||
@@ -340,9 +327,7 @@ def train(args):
|
|||||||
# Move scale tensors to same device as VAE for on-the-fly encoding
|
# Move scale tensors to same device as VAE for on-the-fly encoding
|
||||||
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
|
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Setup optimizer with parameter groups
|
# Setup optimizer with parameter groups
|
||||||
# ========================================
|
|
||||||
if train_dit:
|
if train_dit:
|
||||||
param_groups = anima_train_utils.get_anima_param_groups(
|
param_groups = anima_train_utils.get_anima_param_groups(
|
||||||
dit,
|
dit,
|
||||||
@@ -476,15 +461,13 @@ def train(args):
|
|||||||
|
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Prepare with accelerator
|
# Prepare with accelerator
|
||||||
# ========================================
|
|
||||||
# Temporarily move non-training models off GPU to reduce memory during DDP init
|
# Temporarily move non-training models off GPU to reduce memory during DDP init
|
||||||
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
# if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||||
qwen3_text_encoder.to("cpu")
|
# qwen3_text_encoder.to("cpu")
|
||||||
if not cache_latents and vae is not None:
|
# if not cache_latents and vae is not None:
|
||||||
vae.to("cpu")
|
# vae.to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
# clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
if args.deepspeed:
|
if args.deepspeed:
|
||||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
|
||||||
@@ -561,9 +544,7 @@ def train(args):
|
|||||||
parameter_optimizer_map[parameter] = opt_idx
|
parameter_optimizer_map[parameter] = opt_idx
|
||||||
num_parameters_per_group[opt_idx] += 1
|
num_parameters_per_group[opt_idx] += 1
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Training loop
|
# Training loop
|
||||||
# ========================================
|
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
@@ -633,9 +614,7 @@ def train(args):
|
|||||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||||
|
|
||||||
with accelerator.accumulate(*training_models):
|
with accelerator.accumulate(*training_models):
|
||||||
# ==============================
|
|
||||||
# Get latents
|
# Get latents
|
||||||
# ==============================
|
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
||||||
else:
|
else:
|
||||||
@@ -649,9 +628,7 @@ def train(args):
|
|||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# Get text encoder outputs
|
# Get text encoder outputs
|
||||||
# ==============================
|
|
||||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
if text_encoder_outputs_list is not None:
|
if text_encoder_outputs_list is not None:
|
||||||
# Cached outputs
|
# Cached outputs
|
||||||
@@ -676,9 +653,7 @@ def train(args):
|
|||||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# Noise and timesteps
|
# Noise and timesteps
|
||||||
# ==============================
|
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
|
|
||||||
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
||||||
@@ -690,9 +665,7 @@ def train(args):
|
|||||||
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
|
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
|
||||||
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
|
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# Create padding mask
|
# Create padding mask
|
||||||
# ==============================
|
|
||||||
# padding_mask: (B, 1, H_latent, W_latent)
|
# padding_mask: (B, 1, H_latent, W_latent)
|
||||||
bs = latents.shape[0]
|
bs = latents.shape[0]
|
||||||
h_latent = latents.shape[-2]
|
h_latent = latents.shape[-2]
|
||||||
@@ -702,9 +675,7 @@ def train(args):
|
|||||||
dtype=weight_dtype, device=accelerator.device
|
dtype=weight_dtype, device=accelerator.device
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
||||||
# ==============================
|
|
||||||
if is_swapping_blocks:
|
if is_swapping_blocks:
|
||||||
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
@@ -719,9 +690,7 @@ def train(args):
|
|||||||
t5_attn_mask=t5_attn_mask,
|
t5_attn_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ==============================
|
|
||||||
# Compute loss (rectified flow: target = noise - latents)
|
# Compute loss (rectified flow: target = noise - latents)
|
||||||
# ==============================
|
|
||||||
target = noise - latents
|
target = noise - latents
|
||||||
|
|
||||||
# Weighting
|
# Weighting
|
||||||
@@ -835,9 +804,7 @@ def train(args):
|
|||||||
sample_prompts_te_outputs,
|
sample_prompts_te_outputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# End training
|
# End training
|
||||||
# ========================================
|
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
dit = accelerator.unwrap_model(dit)
|
dit = accelerator.unwrap_model(dit)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user