Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 14:34:23 +00:00
Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs
* fix: exclude anima tests from pytest
* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE
* fix: update default value for --discrete_flow_shift in anima training guide
* feat: add Qwen-Image VAE
* feat: simplify encode_tokens
* feat: use unified attention module, add wrapper for state dict compatibility
* feat: support loading with dynamic fp8 optimization and LoRA
* feat: add anima minimal inference script (WIP)
* format: apply code formatting
* feat: simplify target module selection by regular expression patterns
* feat: keep caption dropout rate in the cache and handle it in the training script (a sketch of the idea follows this list)
* feat: update train_llm_adapter and verbose default values to string type
* fix: use strategy classes instead of using tokenizers directly
* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock
* feat: support 5D tensors in get_noisy_model_input_and_timesteps
* feat: update loss calculation to support 5D tensors
* fix: update argument names in anima_train_utils to align with other architectures
* feat: simplify Anima training script and update empty caption handling
* feat: support LoRA format without `net.` prefix
* fix: make the fp8_scaled option work
* feat: add regex-based learning rates and dimensions handling in create_network (see the matching sketch after this list)
* fix: improve regex matching for module selection and learning rates in LoRANetwork
* fix: update logging message for regex match in LoRANetwork
* fix: keep latents 4D except for the DiT call
* feat: enhance block swap functionality for inference and training in Anima model
* feat: refactor Anima training script
* feat: optimize VAE processing by adjusting tensor dimensions and data types
* fix: wait for all block transfers before switching offloader mode
* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!
* feat: support LoRA for Qwen3
* feat: update Anima SAI model spec metadata handling
* fix: remove unused code
* feat: split CFG processing in do_sample function to reduce memory usage
* feat: add VAE chunking and caching options to reduce memory usage
* feat: optimize RMSNorm forward method and remove unused torch_attention_op
* Update library/strategy_anima.py: use torch.all instead of all (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update library/safetensors_utils.py: fix duplicated new_key for concat_hook (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update anima_minimal_inference.py: remove unused code (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update anima_train.py: remove unused import (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update library/anima_train_utils.py: remove unused import (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* fix: address review comments from Copilot
* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)
* feat: add process_escape function to handle escape sequences in prompts
* feat: enhance LoRA weight handling in model loading and add text encoder loading function
* feat: improve ComfyUI conversion script with prefix constants and module name adjustments
* feat: update caption dropout documentation to clarify cache regeneration requirement
* feat: add clarification on learning rate adjustments
* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
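The caption dropout items above refer to applying dropout at the embedding level: the per-subset dropout rate is stored alongside the cached text encoder outputs and applied when a batch is assembled (the diff below passes caption_dropout_rates into drop_cached_text_encoder_outputs). A minimal sketch of that idea, with generic tensor shapes and a hypothetical uncond_embeds placeholder rather than the repository's actual strategy API:

import torch

def drop_cached_embeddings(prompt_embeds, dropout_rates, uncond_embeds):
    # prompt_embeds: (B, L, D) cached text encoder outputs for a batch
    # dropout_rates: (B,) per-sample caption dropout probabilities restored from the cache
    # uncond_embeds: (L, D) pre-cached unconditional (empty caption) embeddings
    drop = torch.rand(dropout_rates.shape[0], device=prompt_embeds.device) < dropout_rates.to(prompt_embeds.device)
    out = prompt_embeds.clone()
    # replace dropped captions with the unconditional embedding (broadcast over the masked rows)
    out[drop] = uncond_embeds.to(device=out.device, dtype=out.dtype)
    return out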
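The regex-based module selection and per-pattern learning rates mentioned above boil down to matching module names against user-supplied patterns. A rough illustration of that matching with a stand-alone helper and made-up pattern names (not the create_network interface in this repository):

import re

def resolve_lr(module_name, pattern_lrs, default_lr):
    # pattern_lrs maps regex patterns to learning rates; the first matching pattern wins
    for pattern, lr in pattern_lrs.items():
        if re.search(pattern, module_name):
            return lr
    return default_lr

# usage sketch
pattern_lrs = {r"blocks\.\d+\.self_attn": 5e-5, r"llm_adapter\.": 1e-4}
print(resolve_lr("blocks.3.self_attn.q_proj", pattern_lrs, 1e-4))  # -> 5e-05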
anima_train.py (342)
@@ -3,6 +3,7 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import gc
import math
import os
from multiprocessing import Value
@@ -12,8 +13,9 @@ import toml
from tqdm import tqdm

import torch
from library import utils
from library import flux_train_utils, qwen_image_autoencoder_kl
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

init_ipex()

@@ -49,21 +51,18 @@ def train(args):
args.skip_cache_check = args.skip_latents_validity_check

if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True

if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True

if getattr(args, 'unsloth_offload_checkpointing', False):
if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"

assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
@@ -71,17 +70,7 @@ def train(args):

assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not getattr(args, 'unsloth_offload_checkpointing', False), \
"blocks_to_swap is not supported with unsloth_offload_checkpointing"

# Flash attention: validate availability
if getattr(args, 'flash_attn', False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
args.flash_attn = False
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"

cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -104,9 +93,7 @@ def train(args):
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0}".format(", ".join(ignored))
)
logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
@@ -145,26 +132,13 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8

# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so we save the rate and zero out subset-level
# caption_dropout_rate to allow text encoder output caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
for dataset in train_dataset_group.datasets:
for subset in dataset.subsets:
subset.caption_dropout_rate = 0.0
train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2

if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
False,
False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
train_dataset_group.set_current_strategies()
@@ -175,13 +149,11 @@ def train(args):
return

if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used"
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"

if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
assert train_dataset_group.is_text_encoder_output_cacheable(
cache_supports_dropout=True
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"

# prepare accelerator
@@ -191,24 +163,10 @@ def train(args):
# mixed precision dtype
weight_dtype, save_dtype = train_util.prepare_dtype(args)

# parse transformer_dtype
transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)

# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu"
)
t5_tokenizer = anima_utils.load_t5_tokenizer(
getattr(args, 't5_tokenizer_path', None)
)
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)

# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@@ -219,11 +177,7 @@ def train(args):
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

# Set text encoding strategy
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

# Prepare text encoder (always frozen for Anima)
@@ -237,10 +191,7 @@ def train(args):
qwen3_text_encoder.eval()

text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

@@ -259,27 +210,21 @@ def train(args):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_text_encoder],
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
)

# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0.0:
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])

accelerator.wait_for_everyone()

# free text encoder memory
qwen3_text_encoder = None
gc.collect() # Force garbage collection to free memory
clean_memory_on_device(accelerator.device)

# Load VAE and cache latents
logger.info("Loading Anima VAE...")
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)

if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
@@ -294,24 +239,16 @@ def train(args):

# Load DiT (MiniTrainDIT + optional LLM Adapter)
logger.info("Loading Anima DiT...")
dit = anima_utils.load_anima_dit(
args.dit_path,
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
dit = anima_utils.load_anima_model(
"cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
)

if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
unsloth_offload=args.unsloth_offload_checkpointing,
)

if getattr(args, 'flash_attn', False):
dit.set_flash_attn(True)

train_dit = args.learning_rate != 0
dit.requires_grad_(train_dit)
if not train_dit:
@@ -327,19 +264,17 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# Move scale tensors to same device as VAE for on-the-fly encoding
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]

# Setup optimizer with parameter groups
if train_dit:
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
self_attn_lr=getattr(args, 'self_attn_lr', None),
cross_attn_lr=getattr(args, 'cross_attn_lr', None),
mlp_lr=getattr(args, 'mlp_lr', None),
mod_lr=getattr(args, 'mod_lr', None),
llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
self_attn_lr=args.self_attn_lr,
cross_attn_lr=args.cross_attn_lr,
mlp_lr=args.mlp_lr,
mod_lr=args.mod_lr,
llm_adapter_lr=args.llm_adapter_lr,
)
else:
param_groups = []
@@ -361,57 +296,7 @@ def train(args):
# prepare optimizer
accelerator.print("prepare optimizer, data loader etc.")

if args.blockwise_fused_optimizers:
# Split params into per-block groups for blockwise fused optimizer
# Build param_id → lr mapping from param_groups to propagate per-component LRs
param_lr_map = {}
for group in param_groups:
for p in group['params']:
param_lr_map[id(p)] = group['lr']

grouped_params = []
param_group = {}
named_parameters = list(dit.named_parameters())
for name, p in named_parameters:
if not p.requires_grad:
continue
# Determine block type and index
if name.startswith("blocks."):
block_index = int(name.split(".")[1])
block_type = "blocks"
elif name.startswith("llm_adapter.blocks."):
block_index = int(name.split(".")[2])
block_type = "llm_adapter"
else:
block_index = -1
block_type = "other"

param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)

for param_group_key, params in param_group.items():
# Use per-component LR from param_groups if available
lr = param_lr_map.get(id(params[0]), args.learning_rate)
grouped_params.append({"params": params, "lr": lr})
num_params = sum(p.numel() for p in params)
accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")

# Create per-group optimizers
optimizers = []
for group in grouped_params:
_, _, opt = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(opt)
optimizer = optimizers[0] # avoid error in following code

logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")

if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None
optimizer_eval_fn = lambda: None
elif args.fused_backward_pass:
if args.fused_backward_pass:
# Pass per-component param_groups directly to preserve per-component LRs
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
@@ -442,21 +327,19 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)

# lr scheduler
if args.blockwise_fused_optimizers:
lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

# full fp16/bf16 training
dit_weight_dtype = weight_dtype
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
accelerator.print("enable full fp16 training.")
dit.to(weight_dtype)
elif args.full_bf16:
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
accelerator.print("enable full bf16 training.")
dit.to(weight_dtype)
else:
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
dit.to(dit_weight_dtype) # convert dit to target weight dtype

# move text encoder to GPU if not cached
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
@@ -498,6 +381,7 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)

if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused

library.adafactor_fused.patch_adafactor_fused(optimizer)
@@ -517,55 +401,29 @@ def train(args):

parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))

elif args.blockwise_fused_optimizers:
# Prepare additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])

# Counters for blockwise gradient hook
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}

for opt_idx, opt in enumerate(optimizers):
for param_group in opt.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:

def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)

i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)

parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1

# Training loop
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

accelerator.print("running training")
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
accelerator.print(f" num epochs: {num_train_epochs}")
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps: {args.max_train_steps}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0

noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
noise_scheduler_copy = copy.deepcopy(noise_scheduler)

if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
@@ -580,6 +438,7 @@ def train(args):

if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb

wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")

@@ -589,8 +448,15 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, 0, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
0,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
optimizer_train_fn()
@@ -600,11 +466,11 @@ def train(args):
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
if qwen3_text_encoder is not None:
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
if vae is not None:
logger.info(f"vae device: {next(vae.parameters()).device}")
logger.info(f"vae device: {vae.device}")

loss_recorder = train_util.LossRecorder()
epoch = 0
@@ -618,19 +484,17 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step

if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step

with accelerator.accumulate(*training_models):
# Get latents
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
if latents.ndim == 5: # Fallback for 5D latents (old cache)
latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
else:
with torch.no_grad():
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
images = images.unsqueeze(2) # (B, C, 1, H, W)
latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)

if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
@@ -640,23 +504,24 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]

# Apply caption dropout to cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
input_ids_list = batch["input_ids_list"]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_text_encoder],
[qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
tokenize_strategy, [qwen3_text_encoder], input_ids_list
)

# Move to device
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
@@ -664,9 +529,11 @@ def train(args):
# Noise and timesteps
noise = torch.randn_like(latents)

noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32

# NaN checks
if torch.any(torch.isnan(noisy_model_input)):
@@ -678,15 +545,10 @@ def train(args):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)

# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
if is_swapping_blocks:
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()

noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
with accelerator.autocast():
model_pred = dit(
noisy_model_input,
@@ -697,6 +559,7 @@ def train(args):
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)

# Compute loss (rectified flow: target = noise - latents)
target = noise - latents
@@ -708,12 +571,10 @@ def train(args):

# Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
loss = train_util.conditional_loss(
model_pred.float(), target.float(), args.loss_type, "none", huber_c
)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)

if weighting is not None:
loss = loss * weighting
@@ -724,7 +585,7 @@ def train(args):

accelerator.backward(loss)

if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
@@ -737,9 +598,6 @@ def train(args):
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()

# Checks if the accelerator has performed an optimization step
if accelerator.sync_gradients:
@@ -748,8 +606,15 @@ def train(args):

optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, None, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
None,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)

@@ -773,8 +638,10 @@ def train(args):
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names(
logs, lr_scheduler, args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
logs,
lr_scheduler,
args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
)
accelerator.log(logs, step=global_step)

@@ -807,8 +674,15 @@ def train(args):
)

anima_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
epoch + 1,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)

@@ -852,11 +726,6 @@ def setup_parser() -> argparse.ArgumentParser:
anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)

parser.add_argument(
"--blockwise_fused_optimizers",
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step",
)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
@@ -884,4 +753,7 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)

if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility

train(args)