mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Compare commits
7 Commits
f3b6e59900
...
f355a97a32
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f355a97a32 | ||
|
|
dbb40ae4c0 | ||
|
|
4992aae311 | ||
|
|
8d2d286a13 | ||
|
|
6d08c93b23 | ||
|
|
02a75944b3 | ||
|
|
6a4e392445 |
@@ -286,6 +286,8 @@ def decode_latent(vae: WanVAE_, latent: torch.Tensor, device: torch.device) -> t
|
||||
with torch.no_grad():
|
||||
pixels = vae.decode_to_pixels(latent.to(device, dtype=vae.dtype))
|
||||
# pixels = vae.decode(latent.to(device, dtype=torch.bfloat16), scale=vae_scale)
|
||||
if pixels.ndim == 5: # remove frame dimension if exists, [B, C, F, H, W] -> [B, C, H, W]
|
||||
pixels = pixels.squeeze(2)
|
||||
|
||||
pixels = pixels.to("cpu", dtype=torch.float32) # move to CPU and convert to float32 (bfloat16 is not supported by numpy)
|
||||
vae.to("cpu")
|
||||
@@ -719,6 +721,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
|
||||
# 1. Prepare VAE
|
||||
logger.info("Loading VAE for batch generation...")
|
||||
vae_for_batch = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae_for_batch.to(torch.bfloat16)
|
||||
vae_for_batch.eval()
|
||||
|
||||
all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first
|
||||
@@ -840,6 +843,7 @@ def process_interactive(args: argparse.Namespace) -> None:
|
||||
shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
|
||||
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae.to(torch.bfloat16)
|
||||
vae.eval()
|
||||
|
||||
print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
|
||||
@@ -965,6 +969,7 @@ def main():
|
||||
args.seed = seeds[i]
|
||||
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device=device, disable_mmap=True)
|
||||
vae.to(torch.bfloat16)
|
||||
vae.eval()
|
||||
save_output(args, vae, latent, device, original_base_names[i])
|
||||
|
||||
@@ -1009,7 +1014,7 @@ def main():
|
||||
|
||||
# Save latent and video
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
|
||||
vae.to(torch.bfloat16)
|
||||
vae.eval()
|
||||
save_output(args, vae, latent, device)
|
||||
|
||||
|
||||
290
anima_train.py
290
anima_train.py
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
@@ -12,8 +13,9 @@ import toml
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library import utils
|
||||
from library import flux_train_utils, qwen_image_autoencoder_kl, utils
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -56,7 +58,7 @@ def train(args):
|
||||
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
|
||||
if getattr(args, "unsloth_offload_checkpointing", False):
|
||||
if args.unsloth_offload_checkpointing:
|
||||
if not args.gradient_checkpointing:
|
||||
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
@@ -66,19 +68,19 @@ def train(args):
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
|
||||
|
||||
assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr(
|
||||
args, "unsloth_offload_checkpointing", False
|
||||
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
# Flash attention: validate availability
|
||||
if getattr(args, "flash_attn", False):
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
# # Flash attention: validate availability
|
||||
# if args.flash_attn:
|
||||
# try:
|
||||
# import flash_attn # noqa: F401
|
||||
|
||||
logger.info("Flash Attention enabled for DiT blocks")
|
||||
except ImportError:
|
||||
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||
args.flash_attn = False
|
||||
# logger.info("Flash Attention enabled for DiT blocks")
|
||||
# except ImportError:
|
||||
# logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||
# args.flash_attn = False
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
@@ -140,26 +142,13 @@ def train(args):
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
|
||||
|
||||
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
|
||||
# dataset-level caption dropout, so we save the rate and zero out subset-level
|
||||
# caption_dropout_rate to allow text encoder output caching.
|
||||
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
|
||||
if caption_dropout_rate > 0:
|
||||
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
|
||||
for dataset in train_dataset_group.datasets:
|
||||
for subset in dataset.subsets:
|
||||
subset.caption_dropout_rate = 0.0
|
||||
train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2
|
||||
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
False,
|
||||
False,
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
|
||||
)
|
||||
)
|
||||
train_dataset_group.set_current_strategies()
|
||||
@@ -173,8 +162,8 @@ def train(args):
|
||||
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
assert train_dataset_group.is_text_encoder_output_cacheable(
|
||||
cache_supports_dropout=True
|
||||
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
|
||||
|
||||
# prepare accelerator
|
||||
@@ -184,20 +173,10 @@ def train(args):
|
||||
# mixed precision dtype
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# parse transformer_dtype
|
||||
transformer_dtype = None
|
||||
if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None:
|
||||
transformer_dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
||||
|
||||
# Load tokenizers and set strategies
|
||||
logger.info("Loading tokenizers...")
|
||||
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu")
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
|
||||
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
|
||||
|
||||
# Set tokenize strategy
|
||||
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||
@@ -208,11 +187,7 @@ def train(args):
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
|
||||
# Set text encoding strategy
|
||||
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
|
||||
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
|
||||
dropout_rate=caption_dropout_rate,
|
||||
)
|
||||
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
# Prepare text encoder (always frozen for Anima)
|
||||
@@ -226,10 +201,7 @@ def train(args):
|
||||
qwen3_text_encoder.eval()
|
||||
|
||||
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
is_partial=False,
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
@@ -248,25 +220,19 @@ def train(args):
|
||||
logger.info(f" cache TE outputs for: {p}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
[qwen3_text_encoder],
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
|
||||
)
|
||||
|
||||
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
|
||||
with accelerator.autocast():
|
||||
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# free text encoder memory
|
||||
qwen3_text_encoder = None
|
||||
gc.collect() # Force garbage collection to free memory
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Load VAE and cache latents
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu")
|
||||
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -281,24 +247,16 @@ def train(args):
|
||||
|
||||
# Load DiT (MiniTrainDIT + optional LLM Adapter)
|
||||
logger.info("Loading Anima DiT...")
|
||||
dit = anima_utils.load_anima_dit(
|
||||
args.dit_path,
|
||||
dtype=weight_dtype,
|
||||
device="cpu",
|
||||
transformer_dtype=transformer_dtype,
|
||||
llm_adapter_path=getattr(args, "llm_adapter_path", None),
|
||||
disable_mmap=getattr(args, "disable_mmap_load_safetensors", False),
|
||||
dit = anima_utils.load_anima_model(
|
||||
"cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
dit.enable_gradient_checkpointing(
|
||||
cpu_offload=args.cpu_offload_checkpointing,
|
||||
unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False),
|
||||
unsloth_offload=args.unsloth_offload_checkpointing,
|
||||
)
|
||||
|
||||
if getattr(args, "flash_attn", False):
|
||||
dit.set_flash_attn(True)
|
||||
|
||||
train_dit = args.learning_rate != 0
|
||||
dit.requires_grad_(train_dit)
|
||||
if not train_dit:
|
||||
@@ -314,19 +272,17 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
# Move scale tensors to same device as VAE for on-the-fly encoding
|
||||
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
|
||||
|
||||
# Setup optimizer with parameter groups
|
||||
if train_dit:
|
||||
param_groups = anima_train_utils.get_anima_param_groups(
|
||||
dit,
|
||||
base_lr=args.learning_rate,
|
||||
self_attn_lr=getattr(args, "self_attn_lr", None),
|
||||
cross_attn_lr=getattr(args, "cross_attn_lr", None),
|
||||
mlp_lr=getattr(args, "mlp_lr", None),
|
||||
mod_lr=getattr(args, "mod_lr", None),
|
||||
llm_adapter_lr=getattr(args, "llm_adapter_lr", None),
|
||||
self_attn_lr=args.self_attn_lr,
|
||||
cross_attn_lr=args.cross_attn_lr,
|
||||
mlp_lr=args.mlp_lr,
|
||||
mod_lr=args.mod_lr,
|
||||
llm_adapter_lr=args.llm_adapter_lr,
|
||||
)
|
||||
else:
|
||||
param_groups = []
|
||||
@@ -348,57 +304,7 @@ def train(args):
|
||||
# prepare optimizer
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
# Split params into per-block groups for blockwise fused optimizer
|
||||
# Build param_id → lr mapping from param_groups to propagate per-component LRs
|
||||
param_lr_map = {}
|
||||
for group in param_groups:
|
||||
for p in group["params"]:
|
||||
param_lr_map[id(p)] = group["lr"]
|
||||
|
||||
grouped_params = []
|
||||
param_group = {}
|
||||
named_parameters = list(dit.named_parameters())
|
||||
for name, p in named_parameters:
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
# Determine block type and index
|
||||
if name.startswith("blocks."):
|
||||
block_index = int(name.split(".")[1])
|
||||
block_type = "blocks"
|
||||
elif name.startswith("llm_adapter.blocks."):
|
||||
block_index = int(name.split(".")[2])
|
||||
block_type = "llm_adapter"
|
||||
else:
|
||||
block_index = -1
|
||||
block_type = "other"
|
||||
|
||||
param_group_key = (block_type, block_index)
|
||||
if param_group_key not in param_group:
|
||||
param_group[param_group_key] = []
|
||||
param_group[param_group_key].append(p)
|
||||
|
||||
for param_group_key, params in param_group.items():
|
||||
# Use per-component LR from param_groups if available
|
||||
lr = param_lr_map.get(id(params[0]), args.learning_rate)
|
||||
grouped_params.append({"params": params, "lr": lr})
|
||||
num_params = sum(p.numel() for p in params)
|
||||
accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")
|
||||
|
||||
# Create per-group optimizers
|
||||
optimizers = []
|
||||
for group in grouped_params:
|
||||
_, _, opt = train_util.get_optimizer(args, trainable_params=[group])
|
||||
optimizers.append(opt)
|
||||
optimizer = optimizers[0] # avoid error in following code
|
||||
|
||||
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
||||
|
||||
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
|
||||
optimizer_train_fn = lambda: None
|
||||
optimizer_eval_fn = lambda: None
|
||||
elif args.fused_backward_pass:
|
||||
if args.fused_backward_pass:
|
||||
# Pass per-component param_groups directly to preserve per-component LRs
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
@@ -429,21 +335,19 @@ def train(args):
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr scheduler
|
||||
if args.blockwise_fused_optimizers:
|
||||
lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
|
||||
lr_scheduler = lr_schedulers[0] # avoid error in following code
|
||||
else:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# full fp16/bf16 training
|
||||
dit_weight_dtype = weight_dtype
|
||||
if args.full_fp16:
|
||||
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
dit.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
dit.to(weight_dtype)
|
||||
else:
|
||||
dit_weight_dtype = torch.float32 # Default to float32
|
||||
dit.to(dit_weight_dtype) # convert dit to target weight dtype
|
||||
|
||||
# move text encoder to GPU if not cached
|
||||
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||
@@ -485,6 +389,7 @@ def train(args):
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
@@ -504,53 +409,28 @@ def train(args):
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
|
||||
|
||||
elif args.blockwise_fused_optimizers:
|
||||
# Prepare additional optimizers and lr schedulers
|
||||
for i in range(1, len(optimizers)):
|
||||
optimizers[i] = accelerator.prepare(optimizers[i])
|
||||
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
||||
|
||||
# Counters for blockwise gradient hook
|
||||
optimizer_hooked_count = {}
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
for opt_idx, opt in enumerate(optimizers):
|
||||
for param_group in opt.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
|
||||
def grad_hook(parameter: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1
|
||||
|
||||
# Training loop
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
accelerator.print("running training")
|
||||
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs: {num_train_epochs}")
|
||||
accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps: {args.max_train_steps}")
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
@@ -581,7 +461,6 @@ def train(args):
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
vae_scale,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
@@ -594,13 +473,11 @@ def train(args):
|
||||
# Show model info
|
||||
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
|
||||
if unwrapped_dit is not None:
|
||||
logger.info(
|
||||
f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}"
|
||||
)
|
||||
logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
|
||||
if qwen3_text_encoder is not None:
|
||||
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
|
||||
logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
|
||||
if vae is not None:
|
||||
logger.info(f"vae device: {next(vae.parameters()).device}")
|
||||
logger.info(f"vae device: {vae.device}")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
epoch = 0
|
||||
@@ -614,19 +491,17 @@ def train(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
# Get latents
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
||||
latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
|
||||
if latents.ndim == 5: # Fallback for 5D latents (old cache)
|
||||
latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
|
||||
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||
images = images.unsqueeze(2) # (B, C, 1, H, W)
|
||||
latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
|
||||
latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
|
||||
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
@@ -636,21 +511,24 @@ def train(args):
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
# Cached outputs
|
||||
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||
caption_dropout_rates = text_encoder_outputs_list[-1]
|
||||
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
|
||||
|
||||
# Apply caption dropout to cached outputs
|
||||
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
|
||||
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
|
||||
)
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
|
||||
else:
|
||||
# Encode on-the-fly
|
||||
input_ids_list = batch["input_ids_list"]
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
|
||||
with torch.no_grad():
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
[qwen3_text_encoder],
|
||||
[qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
|
||||
tokenize_strategy, [qwen3_text_encoder], input_ids_list
|
||||
)
|
||||
|
||||
# Move to device
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
@@ -658,9 +536,11 @@ def train(args):
|
||||
# Noise and timesteps
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, latents, noise, accelerator.device, weight_dtype
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, dit_weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
# NaN checks
|
||||
if torch.any(torch.isnan(noisy_model_input)):
|
||||
@@ -672,12 +552,10 @@ def train(args):
|
||||
bs = latents.shape[0]
|
||||
h_latent = latents.shape[-2]
|
||||
w_latent = latents.shape[-1]
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
|
||||
|
||||
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||
|
||||
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
|
||||
with accelerator.autocast():
|
||||
model_pred = dit(
|
||||
noisy_model_input,
|
||||
@@ -688,6 +566,7 @@ def train(args):
|
||||
t5_input_ids=t5_input_ids,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
)
|
||||
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
|
||||
|
||||
# Compute loss (rectified flow: target = noise - latents)
|
||||
target = noise - latents
|
||||
@@ -702,7 +581,7 @@ def train(args):
|
||||
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
|
||||
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
@@ -713,7 +592,7 @@ def train(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
|
||||
if not args.fused_backward_pass:
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
@@ -726,9 +605,6 @@ def train(args):
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
if args.blockwise_fused_optimizers:
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step
|
||||
if accelerator.sync_gradients:
|
||||
@@ -743,7 +619,6 @@ def train(args):
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
vae_scale,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
@@ -812,7 +687,6 @@ def train(args):
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
vae_scale,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
@@ -859,11 +733,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
|
||||
# parser.add_argument(
|
||||
# "--blockwise_fused_optimizers",
|
||||
# action="store_true",
|
||||
# help="enable blockwise optimizers for fused backward pass and optimizer step",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
@@ -891,4 +760,7 @@ if __name__ == "__main__":
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
train(args)
|
||||
|
||||
@@ -41,16 +41,11 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled:
|
||||
logger.warning(
|
||||
"fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください"
|
||||
)
|
||||
if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet):
|
||||
logger.info(
|
||||
"fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます"
|
||||
)
|
||||
if args.fp8_base or args.fp8_base_unet:
|
||||
logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
|
||||
args.fp8_base = False
|
||||
args.fp8_base_unet = False
|
||||
args.fp8_scaled = False # Anima DiT does not support fp8_scaled
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
|
||||
@@ -91,6 +86,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
# Load VAE
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae.to(weight_dtype)
|
||||
vae.eval()
|
||||
|
||||
# Return format: (model_type, text_encoders, vae, unet)
|
||||
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
|
||||
@@ -249,7 +246,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
|
||||
return vae.encode_pixels_to_latents(images)
|
||||
return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
# Latents already normalized by vae.encode with scale
|
||||
@@ -272,6 +269,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
anima: anima_models.Anima = unet
|
||||
|
||||
# Sample noise
|
||||
if latents.ndim == 5: # Fallback for 5D latents (old cache)
|
||||
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
@@ -302,11 +301,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
w_latent = latents.shape[-1]
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
|
||||
|
||||
# Prepare block swap
|
||||
if self.is_swapping_blocks:
|
||||
accelerator.unwrap_model(anima).prepare_block_swap_before_forward()
|
||||
|
||||
# Call model
|
||||
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
model_pred = anima(
|
||||
noisy_model_input,
|
||||
@@ -317,6 +313,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
target_attention_mask=t5_attn_mask,
|
||||
source_attention_mask=attn_mask,
|
||||
)
|
||||
model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
|
||||
|
||||
# Rectified flow target: noise - latents
|
||||
target = noise - latents
|
||||
@@ -344,10 +341,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
"""Override base process_batch for caption dropout with cached text encoder outputs.
|
||||
|
||||
Base class now supports 4D and 5D latents, so we only need to handle caption dropout here.
|
||||
"""
|
||||
"""Override base process_batch for caption dropout with cached text encoder outputs."""
|
||||
|
||||
# Text encoder conditions
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
@@ -418,6 +412,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
|
||||
@@ -425,7 +420,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
||||
# parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
||||
parser.add_argument(
|
||||
"--unsloth_offload_checkpointing",
|
||||
action="store_true",
|
||||
|
||||
@@ -37,14 +37,14 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
|
||||
## 2. Differences from `train_network.py` / `train_network.py` との違い
|
||||
|
||||
`anima_train_network.py` is based on `train_network.py` but modified for Anima . Main differences are:
|
||||
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
|
||||
|
||||
* **Target models:** Anima DiT models.
|
||||
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
|
||||
* **Arguments:** Options exist to specify the Anima DiT model, Qwen3 text encoder, WanVAE, LLM adapter, and T5 tokenizer separately.
|
||||
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
|
||||
* **Anima specific options:** Additional parameters for component-wise learning rates (self_attn, cross_attn, mlp, mod, llm_adapter), timestep sampling, discrete flow shift, and flash attention.
|
||||
* **6 Parameter Groups:** Independent learning rates for `base`, `self_attn`, `cross_attn`, `mlp`, `adaln_modulation`, and `llm_adapter` components.
|
||||
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the WanVAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
|
||||
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
|
||||
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
|
||||
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -53,10 +53,10 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
|
||||
* **対象モデル:** Anima DiTモデルを対象とします。
|
||||
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
|
||||
* **引数:** Anima DiTモデル、Qwen3テキストエンコーダー、WanVAE、LLM Adapter、T5トークナイザーを個別に指定する引数があります。
|
||||
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はAnimaの学習では使用されません。
|
||||
* **Anima特有の引数:** コンポーネント別学習率(self_attn, cross_attn, mlp, mod, llm_adapter)、タイムステップサンプリング、離散フローシフト、Flash Attentionに関する引数が追加されています。
|
||||
* **6パラメータグループ:** `base`、`self_attn`、`cross_attn`、`mlp`、`adaln_modulation`、`llm_adapter`の各コンポーネントに対して独立した学習率を設定できます。
|
||||
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、WanVAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
|
||||
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
|
||||
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。
|
||||
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。
|
||||
</details>
|
||||
|
||||
## 3. Preparation / 準備
|
||||
@@ -74,7 +74,6 @@ The following files are required before starting training:
|
||||
**Notes:**
|
||||
* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
|
||||
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
|
||||
* Models are saved with a `net.` prefix on all keys for ComfyUI compatibility.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -92,7 +91,6 @@ The following files are required before starting training:
|
||||
**注意:**
|
||||
* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json`、`tokenizer.json`、`tokenizer_config.json`、`vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。
|
||||
* T5トークナイザーはトークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
|
||||
* モデルはComfyUI互換のため、すべてのキーに`net.`プレフィックスを付けて保存されます。
|
||||
</details>
|
||||
|
||||
## 4. Running the Training / 学習の実行
|
||||
@@ -103,9 +101,9 @@ Example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
--dit_path="<path to Anima DiT model>" \
|
||||
--qwen3_path="<path to Qwen3-0.6B model or directory>" \
|
||||
--vae_path="<path to WanVAE model>" \
|
||||
--pretrained_model_name_or_path="<path to Anima DiT model>" \
|
||||
--qwen3="<path to Qwen3-0.6B model or directory>" \
|
||||
--vae="<path to WanVAE model>" \
|
||||
--llm_adapter_path="<path to LLM adapter model>" \
|
||||
--dataset_config="my_anima_dataset_config.toml" \
|
||||
--output_dir="<output directory>" \
|
||||
@@ -117,7 +115,7 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
--learning_rate=1e-4 \
|
||||
--optimizer_type="AdamW8bit" \
|
||||
--lr_scheduler="constant" \
|
||||
--timestep_sample_method="logit_normal" \
|
||||
--timestep_sampling="sigmoid" \
|
||||
--discrete_flow_shift=1.0 \
|
||||
--max_train_epochs=10 \
|
||||
--save_every_n_epochs=1 \
|
||||
@@ -146,11 +144,11 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
|
||||
|
||||
#### Model Options [Required] / モデル関連 [必須]
|
||||
|
||||
* `--dit_path="<path to Anima DiT model>"` **[Required]**
|
||||
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[Required]**
|
||||
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
|
||||
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[Required]**
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
|
||||
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
|
||||
* `--vae_path="<path to WanVAE model>"` **[Required]**
|
||||
* `--vae="<path to WanVAE model>"` **[Required]**
|
||||
- Path to the WanVAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
|
||||
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
|
||||
@@ -159,53 +157,54 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
|
||||
|
||||
#### Anima Training Parameters / Anima 学習パラメータ
|
||||
|
||||
* `--timestep_sample_method=<choice>`
|
||||
- Timestep sampling method. Choose from `logit_normal` (default) or `uniform`.
|
||||
* `--timestep_sampling=<choice>`
|
||||
- Timestep sampling method. Choose from `sigma`, `uniform`, `sigmoid` (default), `shift`, `flux_shift`. Same options as FLUX training. See the [flux_train_network.py guide](flux_train_network.md) for details on each method.
|
||||
* `--discrete_flow_shift=<float>`
|
||||
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`. 1.0 means no shift.
|
||||
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. This value is used when `--timestep_sampling` is set to **`shift`**. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||
* `--sigmoid_scale=<float>`
|
||||
- Scale factor for `logit_normal` timestep sampling. Default `1.0`.
|
||||
- Scale factor when `--timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default `1.0`.
|
||||
* `--qwen3_max_token_length=<integer>`
|
||||
- Maximum token length for the Qwen3 tokenizer. Default `512`.
|
||||
* `--t5_max_token_length=<integer>`
|
||||
- Maximum token length for the T5 tokenizer. Default `512`.
|
||||
* `--flash_attn`
|
||||
- Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
|
||||
* `--transformer_dtype=<choice>`
|
||||
- Separate dtype for transformer blocks. Choose from `float16`, `bfloat16`, `float32`. If not specified, uses the same dtype as `--mixed_precision`.
|
||||
* `--attn_mode=<choice>`
|
||||
- Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
|
||||
* `--split_attn`
|
||||
- Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.
|
||||
|
||||
#### Component-wise Learning Rates / コンポーネント別学習率
|
||||
|
||||
Anima supports 6 independent learning rate groups. Set to `0` to freeze a component:
|
||||
These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:
|
||||
|
||||
* `--self_attn_lr=<float>` - Learning rate for self-attention layers. Default: same as `--learning_rate`.
|
||||
* `--cross_attn_lr=<float>` - Learning rate for cross-attention layers. Default: same as `--learning_rate`.
|
||||
* `--mlp_lr=<float>` - Learning rate for MLP layers. Default: same as `--learning_rate`.
|
||||
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`.
|
||||
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`. Note: modulation layers are not included in LoRA by default.
|
||||
* `--llm_adapter_lr=<float>` - Learning rate for LLM adapter layers. Default: same as `--learning_rate`.
|
||||
|
||||
For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Section 5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御).
|
||||
|
||||
#### Memory and Speed / メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap=<integer>` **[Experimental]**
|
||||
* `--blocks_to_swap=<integer>`
|
||||
- Number of Transformer blocks to swap between CPU and GPU. Swapping more blocks reduces VRAM usage but slows down training. The maximum value depends on the model size:
|
||||
- 28-block model: max **26**
|
||||
- 28-block model: max **26** (Anima-Preview)
|
||||
- 36-block model: max **34**
|
||||
- 20-block model: max **18**
|
||||
- Cannot be used with `--cpu_offload_checkpointing` or `--unsloth_offload_checkpointing`.
|
||||
* `--unsloth_offload_checkpointing`
|
||||
- Offload activations to CPU RAM using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
|
||||
- Offload activations to CPU RAM using async non-blocking transfers (faster than `--cpu_offload_checkpointing`). Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
|
||||
* `--cache_text_encoder_outputs`
|
||||
- Cache Qwen3 text encoder outputs to reduce VRAM usage. Recommended when not training text encoder LoRA.
|
||||
* `--cache_text_encoder_outputs_to_disk`
|
||||
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
|
||||
* `--cache_latents`, `--cache_latents_to_disk`
|
||||
- Cache WanVAE latent outputs.
|
||||
* `--fp8_base`
|
||||
- Use FP8 precision for the base model to reduce VRAM usage.
|
||||
|
||||
#### Incompatible or Deprecated Options / 非互換・非推奨の引数
|
||||
#### Incompatible or Unsupported Options / 非互換・非サポートの引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
|
||||
* `--fp8_base` - Not supported for Anima. If specified, it will be disabled with a warning.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -214,39 +213,45 @@ Anima supports 6 independent learning rate groups. Set to `0` to freeze a compon
|
||||
|
||||
#### モデル関連 [必須]
|
||||
|
||||
* `--dit_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。
|
||||
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。
|
||||
* `--vae_path="<path to WanVAE model>"` **[必須]** - WanVAEモデルのパスを指定します。
|
||||
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
|
||||
* `--vae="<path to WanVAE model>"` **[必須]** - WanVAEモデルのパスを指定します。
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
|
||||
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
|
||||
|
||||
#### Anima 学習パラメータ
|
||||
|
||||
* `--timestep_sample_method` - タイムステップのサンプリング方法。`logit_normal`(デフォルト)または`uniform`。
|
||||
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`3.0`。
|
||||
* `--sigmoid_scale` - logit_normalタイムステップサンプリングのスケール係数。デフォルト`1.0`。
|
||||
* `--timestep_sampling` - タイムステップのサンプリング方法。`sigma`、`uniform`、`sigmoid`(デフォルト)、`shift`、`flux_shift`から選択。FLUX学習と同じオプションです。各方法の詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
|
||||
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`1.0`。`--timestep_sampling`が`shift`の場合に使用されます。
|
||||
* `--sigmoid_scale` - `sigmoid`、`shift`、`flux_shift`タイムステップサンプリングのスケール係数。デフォルト`1.0`。
|
||||
* `--qwen3_max_token_length` - Qwen3トークナイザーの最大トークン長。デフォルト`512`。
|
||||
* `--t5_max_token_length` - T5トークナイザーの最大トークン長。デフォルト`512`。
|
||||
* `--flash_attn` - DiTのself/cross-attentionにFlash Attentionを使用。`pip install flash-attn`が必要。
|
||||
* `--transformer_dtype` - Transformerブロック用の個別dtype。
|
||||
* `--attn_mode` - 使用するAttentionの実装。`torch`(デフォルト)、`xformers`、`flash`、`sageattn`から選択。`xformers`は`--split_attn`の指定が必要です。`sageattn`はトレーニングをサポートしていません(推論のみ)。
|
||||
* `--split_attn` - メモリ使用量を減らすためにattention時にバッチを分割します。`--attn_mode xformers`使用時に必要です。
|
||||
|
||||
#### コンポーネント別学習率
|
||||
|
||||
Animaは6つの独立した学習率グループをサポートします。`0`に設定するとそのコンポーネントをフリーズします:
|
||||
これらのオプションは、Animaモデルの各コンポーネントに個別の学習率を設定します。主にフルファインチューニング用です。`0`に設定するとそのコンポーネントをフリーズします:
|
||||
|
||||
* `--self_attn_lr` - Self-attention層の学習率。
|
||||
* `--cross_attn_lr` - Cross-attention層の学習率。
|
||||
* `--mlp_lr` - MLP層の学習率。
|
||||
* `--mod_lr` - AdaLNモジュレーション層の学習率。
|
||||
* `--mod_lr` - AdaLNモジュレーション層の学習率。モジュレーション層はデフォルトではLoRAに含まれません。
|
||||
* `--llm_adapter_lr` - LLM Adapter層の学習率。
|
||||
|
||||
LoRA学習の場合は、`--network_args`の`network_reg_lrs`を使用してください。[セクション5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御)を参照。
|
||||
|
||||
#### メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap` **[実験的機能]** - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。
|
||||
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。
|
||||
* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
|
||||
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
|
||||
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
|
||||
* `--cache_latents`, `--cache_latents_to_disk` - WanVAEの出力をキャッシュ。
|
||||
* `--fp8_base` - ベースモデルにFP8精度を使用。
|
||||
|
||||
#### 非互換・非サポートの引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` - Stable Diffusion v1/v2向けの引数。Animaの学習では使用されません。
|
||||
* `--fp8_base` - Animaではサポートされていません。指定した場合、警告とともに無効化されます。
|
||||
</details>
|
||||
|
||||
### 4.2. Starting Training / 学習の開始
|
||||
@@ -262,67 +267,64 @@ After setting the required arguments, run the command to begin training. The ove
|
||||
|
||||
## 5. LoRA Target Modules / LoRAの学習対象モジュール
|
||||
|
||||
When training LoRA with `anima_train_network.py`, the following modules are targeted:
|
||||
When training LoRA with `anima_train_network.py`, the following modules are targeted by default:
|
||||
|
||||
* **DiT Blocks (`Block`)**: Self-attention, cross-attention, MLP, and AdaLN modulation layers within each transformer block.
|
||||
* **DiT Blocks (`Block`)**: Self-attention (`self_attn`), cross-attention (`cross_attn`), and MLP (`mlp`) layers within each transformer block. Modulation (`adaln_modulation`), norm, embedder, and final layers are excluded by default.
|
||||
* **Embedding layers (`PatchEmbed`, `TimestepEmbedding`) and Final layer (`FinalLayer`)**: Excluded by default but can be included using `include_patterns`.
|
||||
* **LLM Adapter Blocks (`LLMAdapterTransformerBlock`)**: Only when `--network_args "train_llm_adapter=True"` is specified.
|
||||
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified.
|
||||
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified and `--cache_text_encoder_outputs` is NOT used.
|
||||
|
||||
The LoRA network module is `networks.lora_anima`.
|
||||
|
||||
### 5.1. Layer-specific Rank Configuration / 各層に対するランク指定
|
||||
### 5.1. Module Selection with Patterns / パターンによるモジュール選択
|
||||
|
||||
You can specify different ranks (network_dim) for each component of the Anima model. Setting `0` disables LoRA for that component.
|
||||
|
||||
| network_args | Target Component |
|
||||
|---|---|
|
||||
| `self_attn_dim` | Self-attention layers in DiT blocks |
|
||||
| `cross_attn_dim` | Cross-attention layers in DiT blocks |
|
||||
| `mlp_dim` | MLP layers in DiT blocks |
|
||||
| `mod_dim` | AdaLN modulation layers in DiT blocks |
|
||||
| `llm_adapter_dim` | LLM adapter layers (requires `train_llm_adapter=True`) |
|
||||
|
||||
Example usage:
|
||||
By default, the following modules are excluded from LoRA via the built-in exclude pattern:
|
||||
```
|
||||
--network_args "self_attn_dim=8" "cross_attn_dim=4" "mlp_dim=8" "mod_dim=4"
|
||||
.*(_modulation|_norm|_embedder|final_layer).*
|
||||
```
|
||||
|
||||
### 5.2. Embedding Layer LoRA / 埋め込み層LoRA
|
||||
You can customize which modules are included or excluded using regex patterns in `--network_args`:
|
||||
|
||||
You can apply LoRA to embedding/output layers by specifying `emb_dims` in network_args as a comma-separated list of 3 numbers:
|
||||
* `exclude_patterns` - Exclude modules matching these patterns (in addition to the default exclusion).
|
||||
* `include_patterns` - Force-include modules matching these patterns, overriding exclusion.
|
||||
|
||||
Patterns are matched against the full module name using `re.fullmatch()`.
|
||||
|
||||
Example to include the final layer:
|
||||
```
|
||||
--network_args "emb_dims=[8,4,8]"
|
||||
--network_args "include_patterns=['.*final_layer.*']"
|
||||
```
|
||||
|
||||
Each number corresponds to:
|
||||
1. `x_embedder` (patch embedding)
|
||||
2. `t_embedder` (timestep embedding)
|
||||
3. `final_layer` (output layer)
|
||||
|
||||
Setting `0` disables LoRA for that layer.
|
||||
|
||||
### 5.3. Block Selection for Training / 学習するブロックの指定
|
||||
|
||||
You can specify which DiT blocks to train using `train_block_indices` in network_args. The indices are 0-based. Default is to train all blocks.
|
||||
|
||||
Specify indices as comma-separated integers or ranges:
|
||||
|
||||
Example to additionally exclude MLP layers:
|
||||
```
|
||||
--network_args "train_block_indices=0-5,10,15-27"
|
||||
--network_args "exclude_patterns=['.*mlp.*']"
|
||||
```
|
||||
|
||||
Special values: `all` (train all blocks), `none` (skip all blocks).
|
||||
### 5.2. Regex-based Rank and Learning Rate Control / 正規表現によるランク・学習率の制御
|
||||
|
||||
### 5.4. LLM Adapter LoRA / LLM Adapter LoRA
|
||||
You can specify different ranks (network_dim) and learning rates for modules matching specific regex patterns:
|
||||
|
||||
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
|
||||
* Example: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
|
||||
* This sets the rank to 8 for self-attention modules, 4 for cross-attention modules, and 8 for MLP modules.
|
||||
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
|
||||
* Example: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
|
||||
* This sets the learning rate to `1e-4` for self-attention modules and `5e-5` for cross-attention modules.
|
||||
|
||||
**Notes:**
|
||||
|
||||
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
|
||||
* Patterns are matched using `re.fullmatch()` against the module's original name (e.g., `blocks.0.self_attn.q_proj`).
|
||||
|
||||
### 5.3. LLM Adapter LoRA / LLM Adapter LoRA
|
||||
|
||||
To apply LoRA to the LLM Adapter blocks:
|
||||
|
||||
```
|
||||
--network_args "train_llm_adapter=True" "llm_adapter_dim=4"
|
||||
--network_args "train_llm_adapter=True"
|
||||
```
|
||||
|
||||
### 5.5. Other Network Args / その他のネットワーク引数
|
||||
### 5.4. Other Network Args / その他のネットワーク引数
|
||||
|
||||
* `--network_args "verbose=True"` - Print all LoRA module names and their dimensions.
|
||||
* `--network_args "rank_dropout=0.1"` - Rank dropout rate.
|
||||
@@ -336,48 +338,56 @@ To apply LoRA to the LLM Adapter blocks:
|
||||
|
||||
`anima_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。
|
||||
|
||||
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention、Cross-attention、MLP、AdaLNモジュレーション層。
|
||||
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention(`self_attn`)、Cross-attention(`cross_attn`)、MLP(`mlp`)層。モジュレーション(`adaln_modulation`)、norm、embedder、final layerはデフォルトで除外されます。
|
||||
* **埋め込み層 (`PatchEmbed`, `TimestepEmbedding`) と最終層 (`FinalLayer`)**: デフォルトで除外されますが、`include_patterns`で含めることができます。
|
||||
* **LLM Adapterブロック (`LLMAdapterTransformerBlock`)**: `--network_args "train_llm_adapter=True"`を指定した場合のみ。
|
||||
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定しない場合のみ。
|
||||
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定せず、かつ`--cache_text_encoder_outputs`を使用しない場合のみ。
|
||||
|
||||
### 5.1. 各層のランクを指定する
|
||||
### 5.1. パターンによるモジュール選択
|
||||
|
||||
`--network_args`で各コンポーネントに異なるランクを指定できます。`0`を指定するとその層にはLoRAが適用されません。
|
||||
デフォルトでは以下のモジュールが組み込みの除外パターンによりLoRAから除外されます:
|
||||
```
|
||||
.*(_modulation|_norm|_embedder|final_layer).*
|
||||
```
|
||||
|
||||
|network_args|対象コンポーネント|
|
||||
|---|---|
|
||||
|`self_attn_dim`|DiTブロック内のSelf-attention層|
|
||||
|`cross_attn_dim`|DiTブロック内のCross-attention層|
|
||||
|`mlp_dim`|DiTブロック内のMLP層|
|
||||
|`mod_dim`|DiTブロック内のAdaLNモジュレーション層|
|
||||
|`llm_adapter_dim`|LLM Adapter層(`train_llm_adapter=True`が必要)|
|
||||
`--network_args`で正規表現パターンを使用して、含めるモジュールと除外するモジュールをカスタマイズできます:
|
||||
|
||||
### 5.2. 埋め込み層LoRA
|
||||
* `exclude_patterns` - これらのパターンにマッチするモジュールを除外(デフォルトの除外に追加)。
|
||||
* `include_patterns` - これらのパターンにマッチするモジュールを強制的に含める(除外を上書き)。
|
||||
|
||||
`emb_dims`で埋め込み/出力層にLoRAを適用できます。3つの数値をカンマ区切りで指定します。
|
||||
パターンは`re.fullmatch()`を使用して完全なモジュール名に対してマッチングされます。
|
||||
|
||||
各数値は `x_embedder`(パッチ埋め込み)、`t_embedder`(タイムステップ埋め込み)、`final_layer`(出力層)に対応します。
|
||||
### 5.2. 正規表現によるランク・学習率の制御
|
||||
|
||||
### 5.3. 学習するブロックの指定
|
||||
正規表現にマッチするモジュールに対して、異なるランクや学習率を指定できます:
|
||||
|
||||
`train_block_indices`でLoRAを適用するDiTブロックを指定できます。
|
||||
* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank`形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
|
||||
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr`形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
|
||||
|
||||
### 5.4. LLM Adapter LoRA
|
||||
**注意点:**
|
||||
* `network_reg_dims`および`network_reg_lrs`での設定は、全体設定である`--network_dim`や`--learning_rate`よりも優先されます。
|
||||
* パターンはモジュールのオリジナル名(例: `blocks.0.self_attn.q_proj`)に対して`re.fullmatch()`でマッチングされます。
|
||||
|
||||
LLM AdapterブロックにLoRAを適用するには:`--network_args "train_llm_adapter=True" "llm_adapter_dim=4"`
|
||||
### 5.3. LLM Adapter LoRA
|
||||
|
||||
### 5.5. その他のネットワーク引数
|
||||
LLM AdapterブロックにLoRAを適用するには:`--network_args "train_llm_adapter=True"`
|
||||
|
||||
### 5.4. その他のネットワーク引数
|
||||
|
||||
* `verbose=True` - 全LoRAモジュール名とdimを表示
|
||||
* `rank_dropout` - ランクドロップアウト率
|
||||
* `module_dropout` - モジュールドロップアウト率
|
||||
* `loraplus_lr_ratio` - LoRA+学習率比率
|
||||
* `loraplus_unet_lr_ratio` - DiT専用のLoRA+学習率比率
|
||||
* `loraplus_text_encoder_lr_ratio` - テキストエンコーダー専用のLoRA+学習率比率
|
||||
|
||||
</details>
|
||||
|
||||
## 6. Using the Trained Model / 学習済みモデルの利用
|
||||
|
||||
When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima , such as ComfyUI with appropriate nodes.
|
||||
When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima, such as ComfyUI with appropriate nodes.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -394,8 +404,6 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||
|
||||
#### Key VRAM Reduction Options
|
||||
|
||||
- **`--fp8_base`**: Enables training in FP8 format for the DiT model.
|
||||
|
||||
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
|
||||
|
||||
- **`--unsloth_offload_checkpointing`**: Offloads activations to CPU RAM using async non-blocking transfers (faster than `--cpu_offload_checkpointing`). Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
|
||||
@@ -417,7 +425,6 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||
Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。
|
||||
|
||||
主要なVRAM削減オプション:
|
||||
- `--fp8_base`: FP8形式での学習を有効化
|
||||
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
|
||||
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
|
||||
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
|
||||
@@ -431,21 +438,24 @@ Animaモデルは大きい場合があるため、VRAMが限られたGPUでは
|
||||
|
||||
#### Timestep Sampling
|
||||
|
||||
The `--timestep_sample_method` option specifies how timesteps (0-1) are sampled:
|
||||
The `--timestep_sampling` option specifies how timesteps are sampled. The available methods are the same as FLUX training:
|
||||
|
||||
- `logit_normal` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
|
||||
- `sigma`: Sigma-based sampling like SD3.
|
||||
- `uniform`: Uniform random sampling from [0, 1].
|
||||
- `sigmoid` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
|
||||
- `shift`: Like `sigmoid`, but applies the discrete flow shift formula: `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||
- `flux_shift`: Resolution-dependent shift used in FLUX training.
|
||||
|
||||
See the [flux_train_network.py guide](flux_train_network.md) for detailed descriptions.
|
||||
|
||||
#### Discrete Flow Shift
|
||||
|
||||
The `--discrete_flow_shift` option (default `3.0`) shifts the timestep distribution toward higher noise levels. The formula is:
|
||||
The `--discrete_flow_shift` option (default `1.0`) only applies when `--timestep_sampling` is set to `shift`. The formula is:
|
||||
|
||||
```
|
||||
t_shifted = (t * shift) / (1 + (shift - 1) * t)
|
||||
```
|
||||
|
||||
Timesteps are clamped to `[1e-5, 1-1e-5]` after shifting.
|
||||
|
||||
#### Loss Weighting
|
||||
|
||||
The `--weighting_scheme` option specifies loss weighting by timestep:
|
||||
@@ -454,23 +464,30 @@ The `--weighting_scheme` option specifies loss weighting by timestep:
|
||||
- `sigma_sqrt`: Weight by `sigma^(-2)`.
|
||||
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
|
||||
- `none`: Same as uniform.
|
||||
- `logit_normal`, `mode`: Additional schemes from SD3 training. See the [`sd3_train_network.md` guide](sd3_train_network.md) for details.
|
||||
|
||||
#### Caption Dropout
|
||||
|
||||
Use `--caption_dropout_rate` for embedding-level caption dropout. This is handled by `AnimaTextEncodingStrategy` and is compatible with text encoder output caching. The subset-level `caption_dropout_rate` is automatically zeroed when this is set.
|
||||
Caption dropout uses the `caption_dropout_rate` setting from the dataset configuration (per-subset in TOML). When using `--cache_text_encoder_outputs`, the dropout rate is stored with each cached entry and applied during training, so caption dropout is compatible with text encoder output caching.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
#### タイムステップサンプリング
|
||||
|
||||
`--timestep_sample_method`でタイムステップのサンプリング方法を指定します:
|
||||
- `logit_normal`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。
|
||||
`--timestep_sampling`でタイムステップのサンプリング方法を指定します。FLUX学習と同じ方法が利用できます:
|
||||
|
||||
- `sigma`: SD3と同様のシグマベースサンプリング。
|
||||
- `uniform`: [0, 1]の一様分布からサンプリング。
|
||||
- `sigmoid`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。汎用的なオプション。
|
||||
- `shift`: `sigmoid`と同様だが、離散フローシフトの式を適用。
|
||||
- `flux_shift`: FLUX学習で使用される解像度依存のシフト。
|
||||
|
||||
詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
|
||||
|
||||
#### 離散フローシフト
|
||||
|
||||
`--discrete_flow_shift`(デフォルト`3.0`)はタイムステップ分布を高ノイズ側にシフトします。
|
||||
`--discrete_flow_shift`(デフォルト`1.0`)は`--timestep_sampling`が`shift`の場合のみ適用されます。
|
||||
|
||||
#### 損失の重み付け
|
||||
|
||||
@@ -478,7 +495,7 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
|
||||
|
||||
#### キャプションドロップアウト
|
||||
|
||||
`--caption_dropout_rate`で埋め込みレベルのキャプションドロップアウトを使用します。テキストエンコーダー出力のキャッシュと互換性があります。
|
||||
キャプションドロップアウトにはデータセット設定(TOMLでのサブセット単位)の`caption_dropout_rate`を使用します。`--cache_text_encoder_outputs`使用時は、ドロップアウト率が各キャッシュエントリとともに保存され、学習中に適用されるため、テキストエンコーダー出力キャッシュとの互換性があります。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -487,17 +504,23 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
|
||||
Anima LoRA training supports training Qwen3 text encoder LoRA:
|
||||
|
||||
- To train only DiT: specify `--network_train_unet_only`
|
||||
- To train DiT and Qwen3: omit `--network_train_unet_only`
|
||||
- To train DiT and Qwen3: omit `--network_train_unet_only` and do NOT use `--cache_text_encoder_outputs`
|
||||
|
||||
You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.
|
||||
|
||||
Note: When `--cache_text_encoder_outputs` is used, text encoder outputs are pre-computed and the text encoder is removed from GPU, so text encoder LoRA cannot be trained.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。
|
||||
|
||||
- DiTのみ学習: `--network_train_unet_only`を指定
|
||||
- DiTとQwen3を学習: `--network_train_unet_only`を省略
|
||||
- DiTとQwen3を学習: `--network_train_unet_only`を省略し、`--cache_text_encoder_outputs`を使用しない
|
||||
|
||||
Qwen3に個別の学習率を指定するには`--text_encoder_lr`を使用します。未指定の場合は`--learning_rate`が使われます。
|
||||
|
||||
注意: `--cache_text_encoder_outputs`を使用する場合、テキストエンコーダーの出力が事前に計算され、テキストエンコーダー自体はGPUから解放されるため、テキストエンコーダーLoRAは学習できません。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -532,12 +555,15 @@ Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレー
|
||||
|
||||
### Metadata Saved in LoRA Models
|
||||
|
||||
The following Anima-specific metadata is saved in the LoRA model file:
|
||||
The following metadata is saved in the LoRA model file:
|
||||
|
||||
* `ss_weighting_scheme`
|
||||
* `ss_discrete_flow_shift`
|
||||
* `ss_timestep_sample_method`
|
||||
* `ss_logit_mean`
|
||||
* `ss_logit_std`
|
||||
* `ss_mode_scale`
|
||||
* `ss_timestep_sampling`
|
||||
* `ss_sigmoid_scale`
|
||||
* `ss_discrete_flow_shift`
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -546,11 +572,14 @@ The following Anima-specific metadata is saved in the LoRA model file:
|
||||
|
||||
### LoRAモデルに保存されるメタデータ
|
||||
|
||||
以下のAnima固有のメタデータがLoRAモデルファイルに保存されます:
|
||||
以下のメタデータがLoRAモデルファイルに保存されます:
|
||||
|
||||
* `ss_weighting_scheme`
|
||||
* `ss_discrete_flow_shift`
|
||||
* `ss_timestep_sample_method`
|
||||
* `ss_logit_mean`
|
||||
* `ss_logit_std`
|
||||
* `ss_mode_scale`
|
||||
* `ss_timestep_sampling`
|
||||
* `ss_sigmoid_scale`
|
||||
* `ss_discrete_flow_shift`
|
||||
|
||||
</details>
|
||||
|
||||
@@ -401,6 +401,12 @@ class Attention(nn.Module):
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
if q.dtype != v.dtype:
|
||||
if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled():
|
||||
# FlashAttention requires fp16/bf16 and xformers requires q, k and v to share one dtype; only cast when autocast is active.
|
||||
target_dtype = v.dtype # v has fp16/bf16 dtype
|
||||
q = q.to(target_dtype)
|
||||
k = k.to(target_dtype)
|
||||
# return self.compute_attention(q, k, v)
|
||||
qkv = [q, k, v]
|
||||
del q, k, v
|
||||
@@ -1304,6 +1310,20 @@ class Anima(nn.Module):
|
||||
if self.blocks_to_swap:
|
||||
self.blocks = save_blocks
|
||||
|
||||
def switch_block_swap_for_inference(self):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
self.offloader.set_forward_only(True)
|
||||
self.prepare_block_swap_before_forward()
|
||||
print(f"Anima: Block swap set to forward only.")
|
||||
|
||||
def switch_block_swap_for_training(self):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
self.offloader.set_forward_only(False)
|
||||
self.prepare_block_swap_before_forward()
|
||||
print(f"Anima: Block swap set to forward and backward.")
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
|
||||
@@ -444,7 +444,7 @@ def sample_images(
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
dit,
|
||||
dit: anima_models.Anima,
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenize_strategy,
|
||||
@@ -479,6 +479,8 @@ def sample_images(
|
||||
if text_encoder is not None:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
dit.switch_block_swap_for_inference()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
save_dir = os.path.join(args.output_dir, "sample")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -493,6 +495,7 @@ def sample_images(
|
||||
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
dit.prepare_block_swap_before_forward()
|
||||
_sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
@@ -514,6 +517,7 @@ def sample_images(
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
dit.switch_block_swap_for_training()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
|
||||
@@ -154,7 +154,9 @@ def load_anima_model(
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
|
||||
assert (
|
||||
not fp8_scaled and dit_weight_dtype is not None
|
||||
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
|
||||
|
||||
device = torch.device(device)
|
||||
loading_device = torch.device(loading_device)
|
||||
|
||||
@@ -37,6 +37,14 @@ class AttentionParams:
|
||||
cu_seqlens: Optional[torch.Tensor] = None
|
||||
max_seqlen: Optional[int] = None
|
||||
|
||||
@property
|
||||
def supports_fp32(self) -> bool:
|
||||
return self.attn_mode not in ["flash"]
|
||||
|
||||
@property
|
||||
def requires_same_dtype(self) -> bool:
|
||||
return self.attn_mode in ["xformers"]
|
||||
|
||||
@staticmethod
|
||||
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
|
||||
return AttentionParams(attn_mode, split_attn)
|
||||
@@ -95,7 +103,7 @@ def attention(
|
||||
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
|
||||
k: Key tensor [B, L, H, D].
|
||||
v: Value tensor [B, L, H, D].
|
||||
attn_param: Attention parameters including mask and sequence lengths.
|
||||
attn_params: Attention parameters including mask and sequence lengths.
|
||||
drop_rate: Attention dropout rate.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
|
||||
self.remove_handles.append(handle)
|
||||
|
||||
def set_forward_only(self, forward_only: bool):
|
||||
# switching must wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
self.forward_only = forward_only
|
||||
|
||||
def __del__(self):
|
||||
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
|
||||
if self.debug:
|
||||
print(f"Prepare block devices before forward")
|
||||
|
||||
# wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
|
||||
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
||||
b.to(self.device)
|
||||
weighs_to_device(b, self.device) # make sure weights are on device
|
||||
|
||||
@@ -1008,14 +1008,19 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
|
||||
return {"sample": decoded}
|
||||
|
||||
def decode_to_pixels(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
vae_scale_factor = 2 ** len(self.temperal_downsample)
|
||||
# latents = qwen_image_utils.unpack_latents(latent, height, width, vae_scale_factor)
|
||||
is_4d = latents.dim() == 4
|
||||
if is_4d:
|
||||
latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
|
||||
|
||||
latents = latents.to(self.dtype)
|
||||
latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
image = self.decode(latents, return_dict=False)[0][:, :, 0] # -1 to 1
|
||||
# return (image * 0.5 + 0.5).clamp(0.0, 1.0) # Convert to [0, 1] range
|
||||
|
||||
image = self.decode(latents, return_dict=False)[0] # -1 to 1
|
||||
if is_4d:
|
||||
image = image.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
|
||||
|
||||
return image.clamp(-1.0, 1.0)
|
||||
|
||||
def encode_pixels_to_latents(self, pixels: torch.Tensor) -> torch.Tensor:
|
||||
@@ -1032,7 +1037,8 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
|
||||
# pixels = (pixels * 2.0 - 1.0).clamp(-1.0, 1.0)
|
||||
|
||||
# Handle 2D (single-image) input, i.e. a 4D [B, C, H, W] tensor, by adding a temporal dimension
|
||||
if pixels.dim() == 4:
|
||||
is_4d = pixels.dim() == 4
|
||||
if is_4d:
|
||||
pixels = pixels.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
|
||||
|
||||
pixels = pixels.to(self.dtype)
|
||||
@@ -1047,6 +1053,9 @@ class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOrigina
|
||||
latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * latents_std
|
||||
|
||||
if is_4d:
|
||||
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
|
||||
|
||||
return latents
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
|
||||
@@ -291,7 +291,7 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
|
||||
Returns latents on CPU with the same rank as the input: (B, 16, H/8, W/8) for 4D input, or (B, 16, 1, H/8, W/8) for 5D input.
|
||||
"""
|
||||
latents = vae.encode_pixels_to_latents(img_tensor)
|
||||
latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
|
||||
return latents.to("cpu")
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
|
||||
@@ -284,7 +284,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
if re.fullmatch(reg, original_name):
|
||||
dim = d
|
||||
alpha_val = self.alpha
|
||||
logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
|
||||
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
|
||||
break
|
||||
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
|
||||
if dim is None:
|
||||
|
||||
Reference in New Issue
Block a user