This commit is contained in:
Kohya S.
2026-02-09 14:31:40 +00:00
committed by GitHub
20 changed files with 3636 additions and 1114 deletions

View File

@@ -32,6 +32,7 @@ hime="hime"
OT="OT" OT="OT"
byt="byt" byt="byt"
tak="tak" tak="tak"
temperal="temperal"
[files] [files]
extend-exclude = ["_typos.toml", "venv"] extend-exclude = ["_typos.toml", "venv", "configs"]

1020
anima_minimal_inference.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -49,35 +49,32 @@ def train(args):
args.skip_cache_check = args.skip_latents_validity_check args.skip_cache_check = args.skip_latents_validity_check
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning( logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing: if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True args.gradient_checkpointing = True
if getattr(args, 'unsloth_offload_checkpointing', False): if getattr(args, "unsloth_offload_checkpointing", False):
if not args.gradient_checkpointing: if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \ assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert ( assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0 args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
assert ( assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr(
args.blocks_to_swap is None or args.blocks_to_swap == 0 args, "unsloth_offload_checkpointing", False
) or not getattr(args, 'unsloth_offload_checkpointing', False), \ ), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
"blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability # Flash attention: validate availability
if getattr(args, 'flash_attn', False): if getattr(args, "flash_attn", False):
try: try:
import flash_attn # noqa: F401 import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks") logger.info("Flash Attention enabled for DiT blocks")
except ImportError: except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA") logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
@@ -104,9 +101,7 @@ def train(args):
user_config = config_util.load_user_config(args.dataset_config) user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"] ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored): if any(getattr(args, attr) is not None for attr in ignored):
logger.warning( logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
"ignore following options because config file is found: {0}".format(", ".join(ignored))
)
else: else:
if use_dreambooth_method: if use_dreambooth_method:
logger.info("Using DreamBooth method.") logger.info("Using DreamBooth method.")
@@ -150,7 +145,7 @@ def train(args):
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so we save the rate and zero out subset-level # dataset-level caption dropout, so we save the rate and zero out subset-level
# caption_dropout_rate to allow text encoder output caching. # caption_dropout_rate to allow text encoder output caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
if caption_dropout_rate > 0: if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}") logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
for dataset in train_dataset_group.datasets: for dataset in train_dataset_group.datasets:
@@ -175,9 +170,7 @@ def train(args):
return return
if cache_latents: if cache_latents:
assert ( assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
assert ( assert (
@@ -193,7 +186,7 @@ def train(args):
# parse transformer_dtype # parse transformer_dtype
transformer_dtype = None transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None: if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None:
transformer_dtype_map = { transformer_dtype_map = {
"float16": torch.float16, "float16": torch.float16,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
@@ -203,12 +196,8 @@ def train(args):
# Load tokenizers and set strategies # Load tokenizers and set strategies
logger.info("Loading tokenizers...") logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder( qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu")
args.qwen3_path, dtype=weight_dtype, device="cpu" t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
)
t5_tokenizer = anima_utils.load_t5_tokenizer(
getattr(args, 't5_tokenizer_path', None)
)
# Set tokenize strategy # Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@@ -220,7 +209,7 @@ def train(args):
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Set text encoding strategy # Set text encoding strategy
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy( text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate, dropout_rate=caption_dropout_rate,
) )
@@ -266,10 +255,8 @@ def train(args):
) )
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) with accelerator.autocast():
if caption_dropout_rate > 0.0: text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -299,17 +286,17 @@ def train(args):
dtype=weight_dtype, dtype=weight_dtype,
device="cpu", device="cpu",
transformer_dtype=transformer_dtype, transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None), llm_adapter_path=getattr(args, "llm_adapter_path", None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False), disable_mmap=getattr(args, "disable_mmap_load_safetensors", False),
) )
if args.gradient_checkpointing: if args.gradient_checkpointing:
dit.enable_gradient_checkpointing( dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing, cpu_offload=args.cpu_offload_checkpointing,
unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False), unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False),
) )
if getattr(args, 'flash_attn', False): if getattr(args, "flash_attn", False):
dit.set_flash_attn(True) dit.set_flash_attn(True)
train_dit = args.learning_rate != 0 train_dit = args.learning_rate != 0
@@ -335,11 +322,11 @@ def train(args):
param_groups = anima_train_utils.get_anima_param_groups( param_groups = anima_train_utils.get_anima_param_groups(
dit, dit,
base_lr=args.learning_rate, base_lr=args.learning_rate,
self_attn_lr=getattr(args, 'self_attn_lr', None), self_attn_lr=getattr(args, "self_attn_lr", None),
cross_attn_lr=getattr(args, 'cross_attn_lr', None), cross_attn_lr=getattr(args, "cross_attn_lr", None),
mlp_lr=getattr(args, 'mlp_lr', None), mlp_lr=getattr(args, "mlp_lr", None),
mod_lr=getattr(args, 'mod_lr', None), mod_lr=getattr(args, "mod_lr", None),
llm_adapter_lr=getattr(args, 'llm_adapter_lr', None), llm_adapter_lr=getattr(args, "llm_adapter_lr", None),
) )
else: else:
param_groups = [] param_groups = []
@@ -366,8 +353,8 @@ def train(args):
# Build param_id → lr mapping from param_groups to propagate per-component LRs # Build param_id → lr mapping from param_groups to propagate per-component LRs
param_lr_map = {} param_lr_map = {}
for group in param_groups: for group in param_groups:
for p in group['params']: for p in group["params"]:
param_lr_map[id(p)] = group['lr'] param_lr_map[id(p)] = group["lr"]
grouped_params = [] grouped_params = []
param_group = {} param_group = {}
@@ -557,9 +544,7 @@ def train(args):
accelerator.print(f" num examples: {train_dataset_group.num_train_images}") accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch: {len(train_dataloader)}") accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
accelerator.print(f" num epochs: {num_train_epochs}") accelerator.print(f" num epochs: {num_train_epochs}")
accelerator.print( accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}") accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps: {args.max_train_steps}") accelerator.print(f" total optimization steps: {args.max_train_steps}")
@@ -580,6 +565,7 @@ def train(args):
if "wandb" in [tracker.name for tracker in accelerator.trackers]: if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb import wandb
wandb.define_metric("epoch") wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch") wandb.define_metric("loss/epoch", step_metric="epoch")
@@ -589,8 +575,16 @@ def train(args):
# For --sample_at_first # For --sample_at_first
optimizer_eval_fn() optimizer_eval_fn()
anima_train_utils.sample_images( anima_train_utils.sample_images(
accelerator, args, 0, global_step, dit, vae, vae_scale, accelerator,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, args,
0,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs, sample_prompts_te_outputs,
) )
optimizer_train_fn() optimizer_train_fn()
@@ -600,7 +594,9 @@ def train(args):
# Show model info # Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None: if unwrapped_dit is not None:
logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}") logger.info(
f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}"
)
if qwen3_text_encoder is not None: if qwen3_text_encoder is not None:
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}") logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
if vae is not None: if vae is not None:
@@ -640,9 +636,7 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
# Cached outputs # Cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs( text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
*text_encoder_outputs_list
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else: else:
# Encode on-the-fly # Encode on-the-fly
@@ -678,10 +672,7 @@ def train(args):
bs = latents.shape[0] bs = latents.shape[0]
h_latent = latents.shape[-2] h_latent = latents.shape[-2]
w_latent = latents.shape[-1] w_latent = latents.shape[-1]
padding_mask = torch.zeros( padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync) # DiT forward (LLM adapter runs inside forward for DDP gradient sync)
if is_swapping_blocks: if is_swapping_blocks:
@@ -708,9 +699,7 @@ def train(args):
# Loss # Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
loss = train_util.conditional_loss( loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
model_pred.float(), target.float(), args.loss_type, "none", huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch) loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,) loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
@@ -748,8 +737,16 @@ def train(args):
optimizer_eval_fn() optimizer_eval_fn()
anima_train_utils.sample_images( anima_train_utils.sample_images(
accelerator, args, None, global_step, dit, vae, vae_scale, accelerator,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, args,
None,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs, sample_prompts_te_outputs,
) )
@@ -773,8 +770,10 @@ def train(args):
if len(accelerator.trackers) > 0: if len(accelerator.trackers) > 0:
logs = {"loss": current_loss} logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names( train_util.append_lr_to_logs_with_names(
logs, lr_scheduler, args.optimizer_type, logs,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [] lr_scheduler,
args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
) )
accelerator.log(logs, step=global_step) accelerator.log(logs, step=global_step)
@@ -807,8 +806,16 @@ def train(args):
) )
anima_train_utils.sample_images( anima_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, dit, vae, vae_scale, accelerator,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, args,
epoch + 1,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs, sample_prompts_te_outputs,
) )
@@ -852,11 +859,11 @@ def setup_parser() -> argparse.ArgumentParser:
anima_train_utils.add_anima_training_arguments(parser) anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser) sai_model_spec.add_model_spec_arguments(parser)
parser.add_argument( # parser.add_argument(
"--blockwise_fused_optimizers", # "--blockwise_fused_optimizers",
action="store_true", # action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step", # help="enable blockwise optimizers for fused backward pass and optimizer step",
) # )
parser.add_argument( parser.add_argument(
"--cpu_offload_checkpointing", "--cpu_offload_checkpointing",
action="store_true", action="store_true",

View File

@@ -1,16 +1,26 @@
# Anima LoRA training script # Anima LoRA training script
import argparse import argparse
import math
from typing import Any, Optional, Union from typing import Any, Optional, Union
import torch import torch
import torch.nn as nn
from accelerate import Accelerator from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
init_ipex() init_ipex()
from library import anima_models, anima_train_utils, anima_utils, strategy_anima, strategy_base, train_util from library import (
anima_models,
anima_train_utils,
anima_utils,
flux_train_utils,
qwen_image_autoencoder_kl,
sd3_train_utils,
strategy_anima,
strategy_base,
train_util,
)
import train_network import train_network
from library.utils import setup_logging from library.utils import setup_logging
@@ -24,13 +34,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.sample_prompts_te_outputs = None self.sample_prompts_te_outputs = None
self.vae = None
self.vae_scale = None
self.qwen3_text_encoder = None
self.qwen3_tokenizer = None
self.t5_tokenizer = None
self.tokenize_strategy = None
self.text_encoding_strategy = None
def assert_extra_args( def assert_extra_args(
self, self,
@@ -38,137 +41,113 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup], val_dataset_group: Optional[train_util.DatasetGroup],
): ):
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled:
logger.warning( logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled" "fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください"
) )
if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet):
logger.info(
"fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます"
)
args.fp8_base = False
args.fp8_base_unet = False
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True args.cache_text_encoder_outputs = True
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so zero out subset-level rates to allow caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
if hasattr(train_dataset_group, 'datasets'):
for dataset in train_dataset_group.datasets:
for subset in dataset.subsets:
subset.caption_dropout_rate = 0.0
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
assert ( assert train_dataset_group.is_text_encoder_output_cacheable(
train_dataset_group.is_text_encoder_output_cacheable() cache_supports_dropout=True
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used" ), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
assert ( assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0 args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
if getattr(args, 'unsloth_offload_checkpointing', False): if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing: if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \ assert (
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" not args.cpu_offload_checkpointing
), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert ( assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0 args.blocks_to_swap is None or args.blocks_to_swap == 0
), "blocks_to_swap is not supported with unsloth_offload_checkpointing" ), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
if getattr(args, 'flash_attn', False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
args.flash_attn = False
if getattr(args, 'blockwise_fused_optimizers', False):
raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
if val_dataset_group is not None: if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(8) val_dataset_group.verify_bucket_reso_steps(16)
def load_target_model(self, args, weight_dtype, accelerator): def load_target_model(self, args, weight_dtype, accelerator):
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy) # Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
logger.info("Loading Qwen3 text encoder...") logger.info("Loading Qwen3 text encoder...")
self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder( qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
args.qwen3_path, dtype=weight_dtype, device="cpu" qwen3_text_encoder.eval()
)
self.qwen3_text_encoder.eval()
# Parse transformer_dtype # Load VAE
transformer_dtype = None logger.info("Loading Anima VAE...")
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None: vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
transformer_dtype_map = {
"float16": torch.float16, # Return format: (model_type, text_encoders, vae, unet)
"bfloat16": torch.bfloat16, return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
"float32": torch.float32,
} def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None) loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
attn_mode = "torch"
if args.xformers:
attn_mode = "xformers"
if args.attn_mode is not None:
attn_mode = args.attn_mode
# Load DiT # Load DiT
logger.info("Loading Anima DiT...") logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
dit = anima_utils.load_anima_dit( model = anima_utils.load_anima_model(
args.dit_path, accelerator.device,
dtype=weight_dtype, args.pretrained_model_name_or_path,
device="cpu", attn_mode,
transformer_dtype=transformer_dtype, args.split_attn,
llm_adapter_path=getattr(args, 'llm_adapter_path', None), loading_device,
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False), loading_dtype,
args.fp8_scaled,
) )
# Flash attention
if getattr(args, 'flash_attn', False):
dit.set_flash_attn(True)
# Store unsloth preference so that when the base NetworkTrainer calls # Store unsloth preference so that when the base NetworkTrainer calls
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth. # dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
# The base trainer only passes cpu_offload, so we store the flag on the model. # The base trainer only passes cpu_offload, so we store the flag on the model.
self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False) self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
# Block swap # Block swap
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks: if self.is_swapping_blocks:
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
dit.enable_block_swap(args.blocks_to_swap, accelerator.device) model.enable_block_swap(args.blocks_to_swap, accelerator.device)
# Load VAE return model, text_encoders
logger.info("Loading Anima VAE...")
self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
args.vae_path, dtype=weight_dtype, device="cpu"
)
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [self.qwen3_text_encoder], self.vae, dit
def get_tokenize_strategy(self, args): def get_tokenize_strategy(self, args):
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet) # Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.qwen3_path, qwen3_path=args.qwen3,
t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None), t5_tokenizer_path=args.t5_tokenizer_path,
qwen3_max_length=args.qwen3_max_token_length, qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length, t5_max_length=args.t5_max_token_length,
) )
# Store references so load_target_model can reuse them return tokenize_strategy
self.qwen3_tokenizer = self.tokenize_strategy.qwen3_tokenizer
self.t5_tokenizer = self.tokenize_strategy.t5_tokenizer
return self.tokenize_strategy
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy): def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
return [tokenize_strategy.qwen3_tokenizer] return [tokenize_strategy.qwen3_tokenizer]
def get_latents_caching_strategy(self, args): def get_latents_caching_strategy(self, args):
return strategy_anima.AnimaLatentsCachingStrategy( return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
def get_text_encoding_strategy(self, args): def get_text_encoding_strategy(self, args):
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) return strategy_anima.AnimaTextEncodingStrategy()
self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
return self.text_encoding_strategy
def post_process_network(self, args, accelerator, network, text_encoders, unet): def post_process_network(self, args, accelerator, network, text_encoders, unet):
# Qwen3 text encoder is always frozen for Anima # Qwen3 text encoder is always frozen for Anima
@@ -188,10 +167,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args): def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy( return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
) )
return None return None
@@ -200,15 +176,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
): ):
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
if not args.lowram: if not args.lowram:
logger.info("move vae and unet to cpu to save memory") # We cannot move DiT to CPU because of block swap, so only move VAE
org_vae_device = next(vae.parameters()).device logger.info("move vae to cpu to save memory")
org_unet_device = unet.device org_vae_device = vae.device
vae.to("cpu") vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
logger.info("move text encoder to gpu") logger.info("move text encoder to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) text_encoders[0].to(accelerator.device)
with accelerator.autocast(): with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@@ -229,59 +204,52 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
logger.info(f" cache TE outputs for: {p}") logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p) tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens( sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, tokenize_strategy, text_encoders, tokens_and_masks
text_encoders,
tokens_and_masks,
enable_dropout=False,
) )
self.sample_prompts_te_outputs = sample_prompts_te_outputs self.sample_prompts_te_outputs = sample_prompts_te_outputs
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
if caption_dropout_rate > 0.0:
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# move text encoder back to cpu # move text encoder back to cpu
logger.info("move text encoder back to cpu") logger.info("move text encoder back to cpu")
text_encoders[0].to("cpu") text_encoders[0].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram: if not args.lowram:
logger.info("move vae and unet back to original device") logger.info("move vae back to original device")
vae.to(org_vae_device) vae.to(org_vae_device)
unet.to(org_unet_device)
clean_memory_on_device(accelerator.device)
else: else:
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # move text encoder to device for encoding during training/validation
text_encoders[0].to(accelerator.device)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
te = self.get_models_for_text_encoding(args, accelerator, text_encoders) te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
qwen3_te = te[0] if te is not None else None qwen3_te = te[0] if te is not None else None
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
anima_train_utils.sample_images( anima_train_utils.sample_images(
accelerator, args, epoch, global_step, unet, vae, self.vae_scale, accelerator,
qwen3_te, self.tokenize_strategy, self.text_encoding_strategy, args,
epoch,
global_step,
unet,
vae,
qwen3_te,
tokenize_strategy,
text_encoding_strategy,
self.sample_prompts_te_outputs, self.sample_prompts_te_outputs,
) )
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = anima_train_utils.FlowMatchEulerDiscreteScheduler( noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
num_train_timesteps=1000, shift=args.discrete_flow_shift
)
return noise_scheduler return noise_scheduler
def encode_images_to_latents(self, args, vae, images): def encode_images_to_latents(self, args, vae, images):
# images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
images = images.unsqueeze(2) # (B, C, 1, H, W) return vae.encode_pixels_to_latents(images)
# Ensure scale tensors are on the same device as images
vae_device = images.device
scale = [s.to(vae_device) if isinstance(s, torch.Tensor) else s for s in self.vae_scale]
return vae.encode(images, scale)
def shift_scale_latents(self, args, latents): def shift_scale_latents(self, args, latents):
# Latents already normalized by vae.encode with scale # Latents already normalized by vae.encode with scale
@@ -301,13 +269,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_unet, train_unet,
is_train=True, is_train=True,
): ):
anima: anima_models.Anima = unet
# Sample noise # Sample noise
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
# Get noisy model input and timesteps # Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps( noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
) )
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# Gradient checkpointing support # Gradient checkpointing support
if args.gradient_checkpointing: if args.gradient_checkpointing:
@@ -329,147 +300,86 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
bs = latents.shape[0] bs = latents.shape[0]
h_latent = latents.shape[-2] h_latent = latents.shape[-2]
w_latent = latents.shape[-1] w_latent = latents.shape[-1]
padding_mask = torch.zeros( padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
# Prepare block swap # Prepare block swap
if self.is_swapping_blocks: if self.is_swapping_blocks:
accelerator.unwrap_model(unet).prepare_block_swap_before_forward() accelerator.unwrap_model(anima).prepare_block_swap_before_forward()
# Call model (LLM adapter runs inside forward for DDP gradient sync) # Call model
with torch.set_grad_enabled(is_train), accelerator.autocast(): with torch.set_grad_enabled(is_train), accelerator.autocast():
model_pred = unet( model_pred = anima(
noisy_model_input, noisy_model_input,
timesteps, timesteps,
prompt_embeds, prompt_embeds,
padding_mask=padding_mask, padding_mask=padding_mask,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask, source_attention_mask=attn_mask,
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
) )
# Rectified flow target: noise - latents # Rectified flow target: noise - latents
target = noise - latents target = noise - latents
# Loss weighting # Loss weighting
weighting = anima_train_utils.compute_loss_weighting_for_anima( weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
weighting_scheme=args.weighting_scheme, sigmas=sigmas
)
# Differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
if self.is_swapping_blocks:
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
model_pred_prior = unet(
noisy_model_input[diff_output_pr_indices],
timesteps[diff_output_pr_indices],
prompt_embeds[diff_output_pr_indices],
padding_mask=padding_mask[diff_output_pr_indices],
source_attention_mask=attn_mask[diff_output_pr_indices],
t5_input_ids=t5_input_ids[diff_output_pr_indices],
t5_attn_mask=t5_attn_mask[diff_output_pr_indices],
)
network.set_multiplier(1.0)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting return model_pred, target, timesteps, weighting
def process_batch( def process_batch(
self, batch, text_encoders, unet, network, vae, noise_scheduler, self,
vae_dtype, weight_dtype, accelerator, args, batch,
text_encoding_strategy, tokenize_strategy, text_encoders,
is_train=True, train_text_encoder=True, train_unet=True, unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor: ) -> torch.Tensor:
"""Override base process_batch for 5D video latents (B, C, T, H, W). """Override base process_batch for caption dropout with cached text encoder outputs.
Base class assumes 4D (B, C, H, W) for loss.mean([1,2,3]) and weighting broadcast. Base class now supports 4D and 5D latents, so we only need to handle caption dropout here.
""" """
import typing
from library.custom_train_functions import apply_masked_loss
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
latents = self.shift_scale_latents(args, latents)
# Text encoder conditions # Text encoder conditions
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: # Apply caption dropout to cached outputs
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( )
tokenize_strategy, batch["text_encoder_outputs_list"] = text_encoder_outputs_list
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
if len(text_encoder_conds) == 0: return super().process_batch(
text_encoder_conds = encoded_text_encoder_conds batch,
else: text_encoders,
for i in range(len(encoded_text_encoder_conds)): unet,
if encoded_text_encoder_conds[i] is not None: network,
text_encoder_conds[i] = encoded_text_encoder_conds[i] vae,
noise_scheduler,
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( vae_dtype,
args, accelerator, noise_scheduler, latents, batch, weight_dtype,
text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train, accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train,
train_text_encoder,
train_unet,
) )
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
# Reduce all non-batch dims: (B, C, T, H, W) -> (B,) for 5D, (B, C, H, W) -> (B,) for 4D
reduce_dims = list(range(1, loss.ndim))
loss = loss.mean(reduce_dims)
# Apply weighting after reducing to (B,)
if weighting is not None:
loss = loss * weighting
loss_weights = batch["loss_weights"]
loss = loss * loss_weights
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
def post_process_loss(self, loss, args, timesteps, noise_scheduler): def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss return loss
@@ -478,12 +388,15 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def update_metadata(self, metadata, args): def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0)
def is_text_encoder_not_needed_for_training(self, args): def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_unet_with_accelerator( def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@@ -496,20 +409,12 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
if not self.is_swapping_blocks: if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet) return super().prepare_unet_with_accelerator(args, accelerator, unet)
dit = unet model = unet
dit = accelerator.prepare(dit, device_placement=[not self.is_swapping_blocks]) model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device) accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
accelerator.unwrap_model(dit).prepare_block_swap_before_forward() accelerator.unwrap_model(model).prepare_block_swap_before_forward()
return dit return model
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
# Drop cached text encoder outputs for caption dropout
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks: if self.is_swapping_blocks:
@@ -520,6 +425,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser() parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser) train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser) anima_train_utils.add_anima_training_arguments(parser)
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument( parser.add_argument(
"--unsloth_offload_checkpointing", "--unsloth_offload_checkpointing",
action="store_true", action="store_true",
@@ -536,5 +442,8 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args) train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser) args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
trainer = AnimaNetworkTrainer() trainer = AnimaNetworkTrainer()
trainer.train(args) trainer.train(args)

View File

@@ -118,7 +118,7 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
--optimizer_type="AdamW8bit" \ --optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \ --lr_scheduler="constant" \
--timestep_sample_method="logit_normal" \ --timestep_sample_method="logit_normal" \
--discrete_flow_shift=3.0 \ --discrete_flow_shift=1.0 \
--max_train_epochs=10 \ --max_train_epochs=10 \
--save_every_n_epochs=1 \ --save_every_n_epochs=1 \
--mixed_precision="bf16" \ --mixed_precision="bf16" \
@@ -162,7 +162,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
* `--timestep_sample_method=<choice>` * `--timestep_sample_method=<choice>`
- Timestep sampling method. Choose from `logit_normal` (default) or `uniform`. - Timestep sampling method. Choose from `logit_normal` (default) or `uniform`.
* `--discrete_flow_shift=<float>` * `--discrete_flow_shift=<float>`
- Shift for the timestep distribution in Rectified Flow training. Default `3.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`. - Shift for the timestep distribution in Rectified Flow training. Default `1.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`. 1.0 means no shift.
* `--sigmoid_scale=<float>` * `--sigmoid_scale=<float>`
- Scale factor for `logit_normal` timestep sampling. Default `1.0`. - Scale factor for `logit_normal` timestep sampling. Default `1.0`.
* `--qwen3_max_token_length=<integer>` * `--qwen3_max_token_length=<integer>`

View File

@@ -13,11 +13,10 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint from torch.utils.checkpoint import checkpoint as torch_checkpoint
from library import custom_offloading_utils from library import custom_offloading_utils, attention
from library.device_utils import clean_memory_on_device from library.device_utils import clean_memory_on_device
def to_device(x, device): def to_device(x, device):
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
return x.to(device) return x.to(device)
@@ -39,11 +38,13 @@ def to_cpu(x):
else: else:
return x return x
# Unsloth Offloaded Gradient Checkpointing # Unsloth Offloaded Gradient Checkpointing
# Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team # Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team
try: try:
from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable
except ImportError: except ImportError:
def detach_variable(inputs, device=None): def detach_variable(inputs, device=None):
"""Detach tensors from computation graph, optionally moving to a device. """Detach tensors from computation graph, optionally moving to a device.
@@ -80,11 +81,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
""" """
@staticmethod @staticmethod
@torch.amp.custom_fwd(device_type='cuda') @torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, forward_function, hidden_states, *args): def forward(ctx, forward_function, hidden_states, *args):
# Remember the original device for backward pass (multi-GPU support) # Remember the original device for backward pass (multi-GPU support)
ctx.input_device = hidden_states.device ctx.input_device = hidden_states.device
saved_hidden_states = hidden_states.to('cpu', non_blocking=True) saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad(): with torch.no_grad():
output = forward_function(hidden_states, *args) output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states) ctx.save_for_backward(saved_hidden_states)
@@ -96,7 +97,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
return output return output
@staticmethod @staticmethod
@torch.amp.custom_bwd(device_type='cuda') @torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, *grads): def backward(ctx, *grads):
(hidden_states,) = ctx.saved_tensors (hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach() hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach()
@@ -108,8 +109,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
output_tensors = [] output_tensors = []
grad_tensors = [] grad_tensors = []
for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,), for out, grad in zip(
grads if isinstance(grads, tuple) else (grads,)): outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,)
):
if isinstance(out, torch.Tensor) and out.requires_grad: if isinstance(out, torch.Tensor) and out.requires_grad:
output_tensors.append(out) output_tensors.append(out)
grad_tensors.append(grad) grad_tensors.append(grad)
@@ -123,24 +125,24 @@ def unsloth_checkpoint(function, *args):
return UnslothOffloadedGradientCheckpointer.apply(function, *args) return UnslothOffloadedGradientCheckpointer.apply(function, *args)
# Flash Attention support # # Flash Attention support
try: # try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
FLASH_ATTN_AVAILABLE = True # FLASH_ATTN_AVAILABLE = True
except ImportError: # except ImportError:
_flash_attn_func = None # _flash_attn_func = None
FLASH_ATTN_AVAILABLE = False # FLASH_ATTN_AVAILABLE = False
def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: # def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using Flash Attention. # """Computes multi-head attention using Flash Attention.
Input format: (batch, seq_len, n_heads, head_dim) # Input format: (batch, seq_len, n_heads, head_dim)
Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output. # Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
""" # """
# flash_attn_func expects (B, S, H, D) and returns (B, S, H, D) # # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D) # out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
return rearrange(out, "b s h d -> b s (h d)") # return rearrange(out, "b s h d -> b s (h d)")
from .utils import setup_logging from .utils import setup_logging
@@ -174,14 +176,10 @@ def _apply_rotary_pos_emb_base(
if start_positions is not None: if start_positions is not None:
max_offset = torch.max(start_positions) max_offset = torch.max(start_positions)
assert ( assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
max_offset + cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
assert ( assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len] freqs = freqs[:cur_seq_len]
if tensor_format == "bshd": if tensor_format == "bshd":
@@ -205,13 +203,9 @@ def apply_rotary_pos_emb(
cu_seqlens: Union[torch.Tensor, None] = None, cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1, cp_size: int = 1,
) -> torch.Tensor: ) -> torch.Tensor:
assert not ( assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!"
cp_size > 1 and start_positions is not None
), "start_positions != None with CP SIZE > 1 is not supported!"
assert ( assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'."
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
assert fused == False assert fused == False
@@ -223,9 +217,7 @@ def apply_rotary_pos_emb(
_apply_rotary_pos_emb_base( _apply_rotary_pos_emb_base(
x.unsqueeze(1), x.unsqueeze(1),
freqs, freqs,
start_positions=( start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None),
start_positions[idx : idx + 1] if start_positions is not None else None
),
interleaved=interleaved, interleaved=interleaved,
) )
for idx, x in enumerate(torch.split(t, seqlens)) for idx, x in enumerate(torch.split(t, seqlens))
@@ -262,7 +254,7 @@ class RMSNorm(torch.nn.Module):
def _norm(self, x: torch.Tensor) -> torch.Tensor: def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.amp.autocast(device_type='cuda', dtype=torch.float32) @torch.amp.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x) output = self._norm(x.float()).type_as(x)
return output * self.weight return output * self.weight
@@ -308,9 +300,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange( result_B_S_HD = rearrange(F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)")
F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)"
)
return result_B_S_HD return result_B_S_HD
@@ -399,18 +389,23 @@ class Attention(nn.Module):
return q, k, v return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: # def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D] # result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result)) # return self.output_dropout(self.output_proj(result))
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
attn_params: attention.AttentionParams,
context: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v) # return self.compute_attention(q, k, v)
qkv = [q, k, v]
del q, k, v
result = attention.attention(qkv, attn_params=attn_params)
return self.output_dropout(self.output_proj(result))
# Positional Embeddings # Positional Embeddings
@@ -484,12 +479,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
dim_t = self._dim_t dim_t = self._dim_t
self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device) self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device)
self.dim_spatial_range = ( self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
)
self.dim_temporal_range = (
torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
)
def generate_embeddings( def generate_embeddings(
self, self,
@@ -679,9 +670,7 @@ class FourierFeatures(nn.Module):
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
generator = torch.Generator() generator = torch.Generator()
generator.manual_seed(0) generator.manual_seed(0)
self.freqs = ( self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
)
self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device) self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor: def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
@@ -713,9 +702,7 @@ class PatchEmbed(nn.Module):
m=spatial_patch_size, m=spatial_patch_size,
n=spatial_patch_size, n=spatial_patch_size,
), ),
nn.Linear( nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False),
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False
),
) )
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
@@ -765,9 +752,7 @@ class FinalLayer(nn.Module):
nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
) )
else: else:
self.adaln_modulation = nn.Sequential( self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False))
nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
)
self.init_weights() self.init_weights()
@@ -790,9 +775,9 @@ class FinalLayer(nn.Module):
): ):
if self.use_adaln_lora: if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = ( shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] 2, dim=-1
).chunk(2, dim=-1) )
else: else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
@@ -833,7 +818,11 @@ class Block(nn.Module):
self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.cross_attn = Attention( self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd", x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_format="bshd",
) )
self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
@@ -904,6 +893,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor, x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor, crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -919,13 +909,13 @@ class Block(nn.Module):
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1) ).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D 3, dim=-1
).chunk(3, dim=-1) )
else: else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
emb_B_T_D 3, dim=-1
).chunk(3, dim=-1) )
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D emb_B_T_D
).chunk(3, dim=-1) ).chunk(3, dim=-1)
@@ -954,11 +944,14 @@ class Block(nn.Module):
result = rearrange( result = rearrange(
self.self_attn( self.self_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"), rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
None, None,
rope_emb=rope_emb_L_1_1_D, rope_emb=rope_emb_L_1_1_D,
), ),
"b (t h w) d -> b t h w d", "b (t h w) d -> b t h w d",
t=T, h=H, w=W, t=T,
h=H,
w=W,
) )
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result
@@ -967,11 +960,14 @@ class Block(nn.Module):
result = rearrange( result = rearrange(
self.cross_attn( self.cross_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"), rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
crossattn_emb, crossattn_emb,
rope_emb=rope_emb_L_1_1_D, rope_emb=rope_emb_L_1_1_D,
), ),
"b (t h w) d -> b t h w d", "b (t h w) d -> b t h w d",
t=T, h=H, w=W, t=T,
h=H,
w=W,
) )
x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -987,6 +983,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor, x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor, emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor, crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None, rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -996,8 +993,13 @@ class Block(nn.Module):
# Unsloth: async non-blocking CPU RAM offload (fastest offload method) # Unsloth: async non-blocking CPU RAM offload (fastest offload method)
return unsloth_checkpoint( return unsloth_checkpoint(
self._forward, self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, x_B_T_H_W_D,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
) )
elif self.cpu_offload_checkpointing: elif self.cpu_offload_checkpointing:
# Standard cpu offload: blocking transfers # Standard cpu offload: blocking transfers
@@ -1008,36 +1010,54 @@ class Block(nn.Module):
device_inputs = to_device(inputs, device) device_inputs = to_device(inputs, device)
outputs = func(*device_inputs) outputs = func(*device_inputs)
return to_cpu(outputs) return to_cpu(outputs)
return custom_forward return custom_forward
return torch_checkpoint( return torch_checkpoint(
create_custom_forward(self._forward), create_custom_forward(self._forward),
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, x_B_T_H_W_D,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False, use_reentrant=False,
) )
else: else:
# Standard gradient checkpointing (no offload) # Standard gradient checkpointing (no offload)
return torch_checkpoint( return torch_checkpoint(
self._forward, self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, x_B_T_H_W_D,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False, use_reentrant=False,
) )
else: else:
return self._forward( return self._forward(
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, x_B_T_H_W_D,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
) )
# Main DiT Model: MiniTrainDIT # Main DiT Model: MiniTrainDIT (renamed to Anima)
class MiniTrainDIT(nn.Module): class Anima(nn.Module):
"""Cosmos-Predict2 DiT model for image/video generation. """Cosmos-Predict2 DiT model for image/video generation.
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter. 28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
""" """
LATENT_CHANNELS = 16
def __init__( def __init__(
self, self,
max_img_h: int, max_img_h: int,
@@ -1069,6 +1089,8 @@ class MiniTrainDIT(nn.Module):
extra_t_extrapolation_ratio: float = 1.0, extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True, rope_enable_fps_modulation: bool = True,
use_llm_adapter: bool = False, use_llm_adapter: bool = False,
attn_mode: str = "torch",
split_attn: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.max_img_h = max_img_h self.max_img_h = max_img_h
@@ -1097,6 +1119,9 @@ class MiniTrainDIT(nn.Module):
self.rope_enable_fps_modulation = rope_enable_fps_modulation self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.use_llm_adapter = use_llm_adapter self.use_llm_adapter = use_llm_adapter
self.attn_mode = attn_mode
self.split_attn = split_attn
# Block swap support # Block swap support
self.blocks_to_swap = None self.blocks_to_swap = None
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
@@ -1156,7 +1181,6 @@ class MiniTrainDIT(nn.Module):
self.final_layer.init_weights() self.final_layer.init_weights()
self.t_embedding_norm.reset_parameters() self.t_embedding_norm.reset_parameters()
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False): def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
for block in self.blocks: for block in self.blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload) block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload)
@@ -1169,18 +1193,21 @@ class MiniTrainDIT(nn.Module):
def device(self): def device(self):
return next(self.parameters()).device return next(self.parameters()).device
@property
def dtype(self):
    """Return the dtype of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.dtype
def set_flash_attn(self, use_flash_attn: bool): # def set_flash_attn(self, use_flash_attn: bool):
"""Toggle flash attention for all DiT blocks (self-attn + cross-attn). # """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn). # LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
""" # """
if use_flash_attn and not FLASH_ATTN_AVAILABLE: # if use_flash_attn and not FLASH_ATTN_AVAILABLE:
raise ImportError("flash_attn package is required for --flash_attn but is not installed") # raise ImportError("flash_attn package is required for --flash_attn but is not installed")
attn_op = flash_attention_op if use_flash_attn else torch_attention_op # attn_op = flash_attention_op if use_flash_attn else torch_attention_op
for block in self.blocks: # for block in self.blocks:
block.self_attn.attn_op = attn_op # block.self_attn.attn_op = attn_op
block.cross_attn.attn_op = attn_op # block.cross_attn.attn_op = attn_op
def build_patch_embed(self) -> None: def build_patch_embed(self) -> None:
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
@@ -1232,9 +1259,7 @@ class MiniTrainDIT(nn.Module):
padding_mask = transforms.functional.resize( padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
) )
x_B_C_T_H_W = torch.cat( x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1)
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb: if self.extra_per_block_abs_pos_emb:
@@ -1258,7 +1283,6 @@ class MiniTrainDIT(nn.Module):
) )
return x_B_C_Tt_Hp_Wp return x_B_C_Tt_Hp_Wp
def enable_block_swap(self, num_blocks: int, device: torch.device): def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks self.blocks_to_swap = num_blocks
@@ -1266,9 +1290,7 @@ class MiniTrainDIT(nn.Module):
self.blocks_to_swap <= self.num_blocks - 2 self.blocks_to_swap <= self.num_blocks - 2
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader( self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device)
self.blocks, self.blocks_to_swap, device
)
logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
def move_to_device_except_swap_blocks(self, device: torch.device): def move_to_device_except_swap_blocks(self, device: torch.device):
@@ -1287,7 +1309,7 @@ class MiniTrainDIT(nn.Module):
return return
self.offloader.prepare_block_devices_before_forward(self.blocks) self.offloader.prepare_block_devices_before_forward(self.blocks)
def forward( def forward_mini_train_dit(
self, self,
x_B_C_T_H_W: torch.Tensor, x_B_C_T_H_W: torch.Tensor,
timesteps_B_T: torch.Tensor, timesteps_B_T: torch.Tensor,
@@ -1310,7 +1332,7 @@ class MiniTrainDIT(nn.Module):
t5_attn_mask: Optional T5 attention mask t5_attn_mask: Optional T5 attention mask
""" """
# Run LLM adapter inside forward for correct DDP gradient synchronization # Run LLM adapter inside forward for correct DDP gradient synchronization
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'): if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"):
crossattn_emb = self.llm_adapter( crossattn_emb = self.llm_adapter(
source_hidden_states=crossattn_emb, source_hidden_states=crossattn_emb,
target_input_ids=t5_input_ids, target_input_ids=t5_input_ids,
@@ -1337,6 +1359,8 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb, "extra_per_block_pos_emb": extra_pos_emb,
} }
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
for block_idx, block in enumerate(self.blocks): for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap: if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx) self.offloader.wait_for_block(block_idx)
@@ -1345,6 +1369,7 @@ class MiniTrainDIT(nn.Module):
x_B_T_H_W_D, x_B_T_H_W_D,
t_embedding_B_T_D, t_embedding_B_T_D,
crossattn_emb, crossattn_emb,
attn_params,
**block_kwargs, **block_kwargs,
) )
@@ -1355,6 +1380,41 @@ class MiniTrainDIT(nn.Module):
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp return x_B_C_Tt_Hp_Wp
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: Optional[torch.Tensor] = None,
    fps: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.Tensor] = None,
    target_input_ids: Optional[torch.Tensor] = None,
    target_attention_mask: Optional[torch.Tensor] = None,
    source_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Preprocess the text context, then delegate to ``forward_mini_train_dit``.

    ``context`` and the three adapter-related arguments are passed to
    ``_preprocess_text_embeds``; its result replaces ``context`` in the call to
    ``forward_mini_train_dit``. ``fps``, ``padding_mask`` and any extra keyword
    arguments are forwarded unchanged.
    """
    adapted_context = self._preprocess_text_embeds(
        context,
        target_input_ids,
        target_attention_mask,
        source_attention_mask,
    )
    return self.forward_mini_train_dit(
        x,
        timesteps,
        adapted_context,
        fps=fps,
        padding_mask=padding_mask,
        **kwargs,
    )
def _preprocess_text_embeds(
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
if target_input_ids is not None:
# print(
# f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
# )
# print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
context = self.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
# print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
return context
else:
return source_hidden_states
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space # LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
class LLMAdapterRMSNorm(nn.Module): class LLMAdapterRMSNorm(nn.Module):
@@ -1485,24 +1545,37 @@ class LLMAdapterTransformerBlock(nn.Module):
self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim) self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim)
nn.GELU(),
nn.Linear(int(model_dim * mlp_ratio), model_dim)
) )
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, def forward(
position_embeddings=None, position_embeddings_context=None): self,
x,
context,
target_attention_mask=None,
source_attention_mask=None,
position_embeddings=None,
position_embeddings_context=None,
):
if self.has_self_attn: if self.has_self_attn:
# Self-attention: target_attention_mask is not expected to be all zeros
normed = self.norm_self_attn(x) normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask, attn_out = self.self_attn(
position_embeddings=position_embeddings, normed,
position_embeddings_context=position_embeddings) mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings,
)
x = x + attn_out x = x + attn_out
normed = self.norm_cross_attn(x) normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, attn_out = self.cross_attn(
position_embeddings=position_embeddings, normed,
position_embeddings_context=position_embeddings_context) mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out x = x + attn_out
x = x + self.mlp(self.norm_mlp(x)) x = x + self.mlp(self.norm_mlp(x))
@@ -1518,8 +1591,9 @@ class LLMAdapter(nn.Module):
Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states. Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states.
""" """
def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, def __init__(
embed=None, self_attn=False, layer_norm=False): self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False
):
super().__init__() super().__init__()
if embed is not None: if embed is not None:
self.embed = nn.Embedding.from_pretrained(embed.weight) self.embed = nn.Embedding.from_pretrained(embed.weight)
@@ -1530,11 +1604,12 @@ class LLMAdapter(nn.Module):
else: else:
self.in_proj = nn.Identity() self.in_proj = nn.Identity()
self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads) self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads)
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList(
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, [
self_attn=self_attn, layer_norm=layer_norm) LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers) for _ in range(num_layers)
]) ]
)
self.out_proj = nn.Linear(model_dim, target_dim) self.out_proj = nn.Linear(model_dim, target_dim)
self.norm = LLMAdapterRMSNorm(target_dim) self.norm = LLMAdapterRMSNorm(target_dim)
@@ -1556,10 +1631,14 @@ class LLMAdapter(nn.Module):
position_embeddings = self.rotary_emb(x, position_ids) position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context) position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks: for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask, x = block(
source_attention_mask=source_attention_mask, x,
position_embeddings=position_embeddings, context,
position_embeddings_context=position_embeddings_context) target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
return self.norm(self.out_proj(x)) return self.norm(self.out_proj(x))
@@ -1567,30 +1646,60 @@ class LLMAdapter(nn.Module):
# VAE normalization constants # VAE normalization constants
ANIMA_VAE_MEAN = [ ANIMA_VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, -0.7571,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 -0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
] ]
ANIMA_VAE_STD = [ ANIMA_VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 2.8184,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
] ]
# DiT config detection from state_dict # DiT config detection from state_dict
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer'] KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
def get_dit_config(state_dict, key_prefix=''): def get_dit_config(state_dict, key_prefix=""):
"""Derive DiT configuration from state_dict weight shapes.""" """Derive DiT configuration from state_dict weight shapes."""
dit_config = {} dit_config = {}
dit_config["max_img_h"] = 512 dit_config["max_img_h"] = 512
dit_config["max_img_w"] = 512 dit_config["max_img_w"] = 512
dit_config["max_frames"] = 128 dit_config["max_frames"] = 128
concat_padding_mask = True concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int(
concat_padding_mask
)
dit_config["out_channels"] = 16 dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2 dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1 dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024 dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d" dit_config["pos_emb_cls"] = "rope3d"

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
from PIL import Image from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
from library import anima_models, anima_utils, strategy_base, train_util, qwen_image_autoencoder_kl
init_ipex() init_ipex()
@@ -25,29 +26,14 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from library import anima_models, anima_utils, strategy_base, train_util
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
# Anima-specific training arguments # Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser): def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser.""" """Add Anima-specific training arguments to the parser."""
parser.add_argument( parser.add_argument(
"--dit_path", "--qwen3",
type=str,
default=None,
help="Path to Anima DiT model safetensors file",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path to WanVAE safetensors/pth file",
)
parser.add_argument(
"--qwen3_path",
type=str, type=str,
default=None, default=None,
help="Path to Qwen3-0.6B model (safetensors file or directory)", help="Path to Qwen3-0.6B model (safetensors file or directory)",
@@ -86,7 +72,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
"--mod_lr", "--mod_lr",
type=float, type=float,
default=None, default=None,
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze", help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
) )
parser.add_argument( parser.add_argument(
"--t5_tokenizer_path", "--t5_tokenizer_path",
@@ -113,34 +99,29 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
help="Timestep distribution shift for rectified flow training (default: 1.0)", help="Timestep distribution shift for rectified flow training (default: 1.0)",
) )
parser.add_argument( parser.add_argument(
"--timestep_sample_method", "--timestep_sampling",
type=str, type=str,
default="logit_normal", default="sigmoid",
choices=["logit_normal", "uniform"], choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
help="Timestep sampling method (default: logit_normal)", help="Timestep sampling method (default: sigmoid (logit normal))",
) )
parser.add_argument( parser.add_argument(
"--sigmoid_scale", "--sigmoid_scale",
type=float, type=float,
default=1.0, default=1.0,
help="Scale factor for logit_normal timestep sampling (default: 1.0)", help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
) )
# Note: --caption_dropout_rate is defined by base add_dataset_arguments().
# Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
# instead of dataset-level caption dropout, so the subset caption_dropout_rate
# is zeroed out in the training scripts to allow caching.
parser.add_argument( parser.add_argument(
"--transformer_dtype", "--attn_mode",
type=str, choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None, default=None,
choices=["float16", "bfloat16", "float32", None], help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
help="Separate dtype for transformer blocks. If None, uses same as mixed_precision", " / 使用するAttentionの実装。デフォルトはNonetorchです。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません推論のみ。このオプションは--xformersまたは--sdpaを上書きします。",
) )
parser.add_argument( parser.add_argument(
"--flash_attn", "--split_attn",
action="store_true", action="store_true",
help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). " help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
"Falls back to PyTorch SDPA if flash-attn is not installed.",
) )
@@ -169,20 +150,20 @@ def get_noisy_model_input_and_timesteps(
""" """
bs = latents.shape[0] bs = latents.shape[0]
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal') timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal")
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0) sigmoid_scale = getattr(args, "sigmoid_scale", 1.0)
shift = getattr(args, 'discrete_flow_shift', 1.0) shift = getattr(args, "discrete_flow_shift", 1.0)
if timestep_sample_method == 'logit_normal': if timestep_sample_method == "logit_normal":
dist = torch.distributions.normal.Normal(0, 1) dist = torch.distributions.normal.Normal(0, 1)
elif timestep_sample_method == 'uniform': elif timestep_sample_method == "uniform":
dist = torch.distributions.uniform.Uniform(0, 1) dist = torch.distributions.uniform.Uniform(0, 1)
else: else:
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}") raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
t = dist.sample((bs,)).to(device) t = dist.sample((bs,)).to(device)
if timestep_sample_method == 'logit_normal': if timestep_sample_method == "logit_normal":
t = t * sigmoid_scale t = t * sigmoid_scale
t = torch.sigmoid(t) t = torch.sigmoid(t)
@@ -196,10 +177,10 @@ def get_noisy_model_input_and_timesteps(
# Create noisy input: (1 - t) * latents + t * noise # Create noisy input: (1 - t) * latents + t * noise
t_expanded = t.view(-1, *([1] * (latents.ndim - 1))) t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None) ip_noise_gamma = getattr(args, "ip_noise_gamma", None)
if ip_noise_gamma: if ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype) xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if getattr(args, 'ip_noise_gamma_random_strength', False): if getattr(args, "ip_noise_gamma_random_strength", False):
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi) noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
else: else:
@@ -213,10 +194,11 @@ def get_noisy_model_input_and_timesteps(
# Loss weighting # Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor: def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training. """Compute loss weighting for Anima training.
Same schemes as SD3 but can add Anima-specific ones. Same schemes as SD3 but can add Anima-specific ones if needed in future.
""" """
if weighting_scheme == "sigma_sqrt": if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float() weighting = (sigmas**-2.0).float()
@@ -243,7 +225,7 @@ def get_anima_param_groups(
"""Create parameter groups for Anima training with separate learning rates. """Create parameter groups for Anima training with separate learning rates.
Args: Args:
dit: MiniTrainDIT model dit: Anima model
base_lr: Base learning rate base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze) self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers cross_attn_lr: LR for cross-attention layers
@@ -276,15 +258,15 @@ def get_anima_param_groups(
# Store original name for debugging # Store original name for debugging
p.original_name = name p.original_name = name
if 'llm_adapter' in name: if "llm_adapter" in name:
llm_adapter_params.append(p) llm_adapter_params.append(p)
elif '.self_attn' in name: elif ".self_attn" in name:
self_attn_params.append(p) self_attn_params.append(p)
elif '.cross_attn' in name: elif ".cross_attn" in name:
cross_attn_params.append(p) cross_attn_params.append(p)
elif '.mlp' in name: elif ".mlp" in name:
mlp_params.append(p) mlp_params.append(p)
elif '.adaln_modulation' in name: elif ".adaln_modulation" in name:
mod_params.append(p) mod_params.append(p)
else: else:
base_params.append(p) base_params.append(p)
@@ -311,9 +293,9 @@ def get_anima_param_groups(
p.requires_grad_(False) p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)") logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0: elif len(params) > 0:
param_groups.append({'params': params, 'lr': lr}) param_groups.append({"params": params, "lr": lr})
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad) total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}") logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups return param_groups
@@ -325,13 +307,12 @@ def save_anima_model_on_train_end(
save_dtype: torch.dtype, save_dtype: torch.dtype,
epoch: int, epoch: int,
global_step: int, global_step: int,
dit: anima_models.MiniTrainDIT, dit: anima_models.Anima,
): ):
"""Save Anima model at the end of training.""" """Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step): def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec( sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
dit_sd = dit.state_dict() dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility # Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -347,13 +328,12 @@ def save_anima_model_on_epoch_end_or_stepwise(
epoch: int, epoch: int,
num_train_epochs: int, num_train_epochs: int,
global_step: int, global_step: int,
dit: anima_models.MiniTrainDIT, dit: anima_models.Anima,
): ):
"""Save Anima model at epoch end or specific steps.""" """Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step): def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec( sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
dit_sd = dit.state_dict() dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -376,12 +356,13 @@ def do_sample(
height: int, height: int,
width: int, width: int,
seed: Optional[int], seed: Optional[int],
dit: anima_models.MiniTrainDIT, dit: anima_models.Anima,
crossattn_emb: torch.Tensor, crossattn_emb: torch.Tensor,
steps: int, steps: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
guidance_scale: float = 1.0, guidance_scale: float = 1.0,
flow_shift: float = 3.0,
neg_crossattn_emb: Optional[torch.Tensor] = None, neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Generate a sample using Euler discrete sampling for rectified flow. """Generate a sample using Euler discrete sampling for rectified flow.
@@ -389,12 +370,13 @@ def do_sample(
Args: Args:
height, width: Output image dimensions height, width: Output image dimensions
seed: Random seed (None for random) seed: Random seed (None for random)
dit: MiniTrainDIT model dit: Anima model
crossattn_emb: Cross-attention embeddings (B, N, D) crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps steps: Number of sampling steps
dtype: Compute dtype dtype: Compute dtype
device: Compute device device: Compute device
guidance_scale: CFG scale (1.0 = no guidance) guidance_scale: CFG scale (1.0 = no guidance)
flow_shift: Flow shift parameter for rectified flow
neg_crossattn_emb: Negative cross-attention embeddings for CFG neg_crossattn_emb: Negative cross-attention embeddings for CFG
Returns: Returns:
@@ -410,12 +392,13 @@ def do_sample(
generator = torch.manual_seed(seed) generator = torch.manual_seed(seed)
else: else:
generator = None generator = None
noise = torch.randn( noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
).to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0 # Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype) sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
flow_shift = float(flow_shift)
if flow_shift != 1.0:
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
# Start from pure noise # Start from pure noise
x = noise.clone() x = noise.clone()
@@ -463,7 +446,6 @@ def sample_images(
steps, steps,
dit, dit,
vae, vae,
vae_scale,
text_encoder, text_encoder,
tokenize_strategy, tokenize_strategy,
text_encoding_strategy, text_encoding_strategy,
@@ -512,10 +494,19 @@ def sample_images(
with torch.no_grad(), accelerator.autocast(): with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts: for prompt_dict in prompts:
_sample_image_inference( _sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale, accelerator,
tokenize_strategy, text_encoding_strategy, args,
save_dir, prompt_dict, epoch, steps, dit,
sample_prompts_te_outputs, prompt_replacement, text_encoder,
vae,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
) )
# Restore RNG state # Restore RNG state
@@ -527,10 +518,19 @@ def sample_images(
def _sample_image_inference( def _sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale, accelerator,
tokenize_strategy, text_encoding_strategy, args,
save_dir, prompt_dict, epoch, steps, dit,
sample_prompts_te_outputs, prompt_replacement, text_encoder,
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
): ):
"""Generate a single sample image.""" """Generate a single sample image."""
prompt = prompt_dict.get("prompt", "") prompt = prompt_dict.get("prompt", "")
@@ -540,6 +540,7 @@ def _sample_image_inference(
height = prompt_dict.get("height", 512) height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5) scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed") seed = prompt_dict.get("seed")
flow_shift = prompt_dict.get("flow_shift", 3.0)
if prompt_replacement is not None: if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
@@ -553,7 +554,9 @@ def _sample_image_inference(
height = max(64, height - height % 16) height = max(64, height - height % 16)
width = max(64, width - width % 16) width = max(64, width - width % 16)
logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}") logger.info(
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
)
# Encode prompt # Encode prompt
def encode_prompt(prpt): def encode_prompt(prpt):
@@ -579,13 +582,13 @@ def _sample_image_inference(
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0) t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0) t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype) prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
attn_mask = attn_mask.to(accelerator.device) attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long) t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device) t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available # Process through LLM adapter if available
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'): if dit.use_llm_adapter:
crossattn_emb = dit.llm_adapter( crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds, source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids, target_input_ids=t5_input_ids,
@@ -608,12 +611,12 @@ def _sample_image_inference(
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0) neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0) neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype) neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
neg_am = neg_am.to(accelerator.device) neg_am = neg_am.to(accelerator.device)
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long) neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device) neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'): if dit.use_llm_adapter:
neg_crossattn_emb = dit.llm_adapter( neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe, source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids, target_input_ids=neg_t5_ids,
@@ -627,16 +630,14 @@ def _sample_image_inference(
# Generate sample # Generate sample
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
latents = do_sample( latents = do_sample(
height, width, seed, dit, crossattn_emb, height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
sample_steps, dit.t_embedding_norm.weight.dtype,
accelerator.device, scale, neg_crossattn_emb,
) )
# Decode latents # Decode latents
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
org_vae_device = next(vae.parameters()).device org_vae_device = vae.device
vae.to(accelerator.device) vae.to(accelerator.device)
decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale) decoded = vae.decode_to_pixels(latents)
vae.to(org_vae_device) vae.to(org_vae_device)
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
@@ -662,4 +663,5 @@ def _sample_image_inference(
if "wandb" in [tracker.name for tracker in accelerator.trackers]: if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb") wandb_tracker = accelerator.get_tracker("wandb")
import wandb import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)

View File

@@ -6,7 +6,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use from accelerate.utils import set_module_tensor_to_device # kept for potential future use
from accelerate import init_empty_weights
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library import anima_models
from library.safetensors_utils import WeightTransformHooks
from .utils import setup_logging from .utils import setup_logging
setup_logging() setup_logging()
@@ -14,11 +19,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from library import anima_models
# Keys that should stay in high precision (float32/bfloat16, not quantized) # Keys that should stay in high precision (float32/bfloat16, not quantized)
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer'] KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]: def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
@@ -36,8 +39,8 @@ def load_anima_dit(
transformer_dtype: Optional[torch.dtype] = None, transformer_dtype: Optional[torch.dtype] = None,
llm_adapter_path: Optional[str] = None, llm_adapter_path: Optional[str] = None,
disable_mmap: bool = False, disable_mmap: bool = False,
) -> anima_models.MiniTrainDIT: ) -> anima_models.Anima:
"""Load the MiniTrainDIT model from safetensors. """Load the Anima model from safetensors.
Args: Args:
dit_path: Path to DiT safetensors file dit_path: Path to DiT safetensors file
@@ -53,6 +56,7 @@ def load_anima_dit(
logger.info(f"Loading Anima DiT from {dit_path}") logger.info(f"Loading Anima DiT from {dit_path}")
if disable_mmap: if disable_mmap:
from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap
state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True) state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True)
else: else:
state_dict = load_file(dit_path, device="cpu") state_dict = load_file(dit_path, device="cpu")
@@ -60,8 +64,8 @@ def load_anima_dit(
# Remove 'net.' prefix if present # Remove 'net.' prefix if present
new_state_dict = {} new_state_dict = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
if k.startswith('net.'): if k.startswith("net."):
k = k[len('net.'):] k = k[len("net.") :]
new_state_dict[k] = v new_state_dict[k] = v
state_dict = new_state_dict state_dict = new_state_dict
@@ -71,21 +75,23 @@ def load_anima_dit(
# Detect LLM adapter # Detect LLM adapter
if llm_adapter_path is not None: if llm_adapter_path is not None:
use_llm_adapter = True use_llm_adapter = True
dit_config['use_llm_adapter'] = True dit_config["use_llm_adapter"] = True
llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu") llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu")
elif 'llm_adapter.out_proj.weight' in state_dict: elif "llm_adapter.out_proj.weight" in state_dict:
use_llm_adapter = True use_llm_adapter = True
dit_config['use_llm_adapter'] = True dit_config["use_llm_adapter"] = True
llm_adapter_state_dict = None # Loaded as part of DiT llm_adapter_state_dict = None # Loaded as part of DiT
else: else:
use_llm_adapter = False use_llm_adapter = False
llm_adapter_state_dict = None llm_adapter_state_dict = None
logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, " logger.info(
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}") f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}"
)
# Build model normally on CPU — buffers get proper values from __init__ # Build model normally on CPU — buffers get proper values from __init__
dit = anima_models.MiniTrainDIT(**dit_config) dit = anima_models.Anima(**dit_config)
# Merge LLM adapter weights into state_dict if loaded separately # Merge LLM adapter weights into state_dict if loaded separately
if use_llm_adapter and llm_adapter_state_dict is not None: if use_llm_adapter and llm_adapter_state_dict is not None:
@@ -96,9 +102,11 @@ def load_anima_dit(
missing, unexpected = dit.load_state_dict(state_dict, strict=False) missing, unexpected = dit.load_state_dict(state_dict, strict=False)
if missing: if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint) # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [k for k in missing if not any( unexpected_missing = [
buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq') k
)] for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing: if unexpected_missing:
logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}") logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}")
if unexpected: if unexpected:
@@ -106,9 +114,7 @@ def load_anima_dit(
# Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest) # Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest)
for name, p in dit.named_parameters(): for name, p in dit.named_parameters():
dtype_to_use = dtype if ( dtype_to_use = dtype if (any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1) else transformer_dtype
any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1
) else transformer_dtype
p.data = p.data.to(dtype=dtype_to_use) p.data = p.data.to(dtype=dtype_to_use)
dit.to(device) dit.to(device)
@@ -116,6 +122,128 @@ def load_anima_dit(
return dit return dit
# Passed as `target_keys` to load_safetensors_with_lora_and_fp8 by load_anima_model.
# NOTE(review): the empty-string entry presumably makes every weight key a candidate
# (subject to the exclude list below) — confirm against the key matcher in the fp8 loader.
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
# Passed as `exclude_keys` to the same loader call.
# ".embed." excludes Embedding in LLMAdapter
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
def load_anima_model(
device: Union[str, torch.device],
dit_path: str,
attn_mode: str,
split_attn: bool,
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
"""
Load an Anima DiT model from the specified checkpoint.
The model is built from a fixed configuration (see `dit_config` below) and its weights are
read through `load_safetensors_with_lora_and_fp8`, which is also handed the LoRA weights
and the fp8 options.
Args:
device (Union[str, torch.device]): Device for optimization or merging
dit_path (str): Path to the DiT model checkpoint.
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
split_attn (bool): Whether to use split attention.
loading_device (Union[str, torch.device]): Device to load the model weights on.
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. If not None, model weights will be cast to this dtype.
Must be None when fp8_scaled is True and must not be None otherwise.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
NOTE(review): the name suggests a list of state dicts while the annotation is a single dict — confirm against the loader.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
Returns:
anima_models.Anima: The model with the checkpoint tensors assigned via `load_state_dict(..., assign=True)`.
Raises:
AssertionError: If `fp8_scaled` and `dit_weight_dtype` are not consistent (see above).
RuntimeError: If the checkpoint has unexpected keys, or is missing keys other than the known buffers.
"""
# dit_weight_dtype is None for fp8_scaled
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
# Normalize both device arguments so they can be compared and so `.type` is available below
device = torch.device(device)
loading_device = torch.device(loading_device)
# We currently support fixed DiT config for Anima models
dit_config = {
"max_img_h": 512,
"max_img_w": 512,
"max_frames": 128,
"in_channels": 16,
"out_channels": 16,
"patch_spatial": 2,
"patch_temporal": 1,
"model_channels": 2048,
"concat_padding_mask": True,
"crossattn_emb_channels": 1024,
"pos_emb_cls": "rope3d",
"pos_emb_learnable": True,
"pos_emb_interpolation": "crop",
"min_fps": 1,
"max_fps": 30,
"use_adaln_lora": True,
"adaln_lora_dim": 256,
"num_blocks": 28,
"num_heads": 16,
"extra_per_block_abs_pos_emb": False,
"rope_h_extrapolation_ratio": 4.0,
"rope_w_extrapolation_ratio": 4.0,
"rope_t_extrapolation_ratio": 1.0,
"extra_h_extrapolation_ratio": 1.0,
"extra_w_extrapolation_ratio": 1.0,
"extra_t_extrapolation_ratio": 1.0,
"rope_enable_fps_modulation": False,
"use_llm_adapter": True,
"attn_mode": attn_mode,
"split_attn": split_attn,
}
# Construct the module under accelerate's init_empty_weights; the actual tensors come from
# the loaded state dict, which is attached with assign=True further down.
with init_empty_weights():
model = anima_models.Anima(**dit_config)
if dit_weight_dtype is not None:
model.to(dit_weight_dtype)
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
# Strip a leading "net." from checkpoint keys; keys without that prefix are passed through unchanged
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
# move_to_device is requested only when the loading device is the same as the calculation device
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
weight_transform_hooks=rename_hooks,
)
if fp8_scaled:
# NOTE(review): presumably patches the model's layers to consume the scaled fp8 weights in `sd` — confirm in fp8_optimization_utils
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu":
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
# strict=False so the missing/unexpected key lists can be inspected and filtered manually below
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
# Raise error to avoid silent failures
raise RuntimeError(
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
)
missing = {} # all missing keys were expected
if unexpected:
# Raise error to avoid silent failures
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
# Both counts logged here are always 0: any non-empty case above either raised or reset `missing`
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
return model
def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"): def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"):
"""Load WanVAE from a safetensors/pth file. """Load WanVAE from a safetensors/pth file.
@@ -139,14 +267,14 @@ def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: st
from library.anima_vae import WanVAE_ from library.anima_vae import WanVAE_
# Build model # Build model
with torch.device('meta'): with torch.device("meta"):
vae = WanVAE_(**vae_config) vae = WanVAE_(**vae_config)
# Load state dict # Load state dict
if vae_path.endswith('.safetensors'): if vae_path.endswith(".safetensors"):
vae_sd = load_file(vae_path, device='cpu') vae_sd = load_file(vae_path, device="cpu")
else: else:
vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True) vae_sd = torch.load(vae_path, map_location="cpu", weights_only=True)
vae.load_state_dict(vae_sd, assign=True) vae.load_state_dict(vae_sd, assign=True)
vae = vae.eval().requires_grad_(False).to(device, dtype=dtype) vae = vae.eval().requires_grad_(False).to(device, dtype=dtype)
@@ -175,7 +303,7 @@ def load_qwen3_tokenizer(qwen3_path: str):
if os.path.isdir(qwen3_path): if os.path.isdir(qwen3_path):
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
else: else:
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b') config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir): if not os.path.exists(config_dir):
raise FileNotFoundError( raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. " f"Qwen3 config directory not found at {config_dir}. "
@@ -209,12 +337,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
if os.path.isdir(qwen3_path): if os.path.isdir(qwen3_path):
# Directory with full model # Directory with full model
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained( model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
qwen3_path, torch_dtype=dtype, local_files_only=True
).model
else: else:
# Single safetensors file - use configs/qwen3_06b/ for config # Single safetensors file - use configs/qwen3_06b/ for config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b') config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir): if not os.path.exists(config_dir):
raise FileNotFoundError( raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. " f"Qwen3 config directory not found at {config_dir}. "
@@ -227,16 +353,16 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
model = transformers.Qwen3ForCausalLM(qwen3_config).model model = transformers.Qwen3ForCausalLM(qwen3_config).model
# Load weights # Load weights
if qwen3_path.endswith('.safetensors'): if qwen3_path.endswith(".safetensors"):
state_dict = load_file(qwen3_path, device='cpu') state_dict = load_file(qwen3_path, device="cpu")
else: else:
state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True) state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present # Remove 'model.' prefix if present
new_sd = {} new_sd = {}
for k, v in state_dict.items(): for k, v in state_dict.items():
if k.startswith('model.'): if k.startswith("model."):
new_sd[k[len('model.'):]] = v new_sd[k[len("model.") :]] = v
else: else:
new_sd[k] = v new_sd[k] = v
@@ -265,11 +391,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True) return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
# Use bundled config # Use bundled config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old') config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
if os.path.exists(config_dir): if os.path.exists(config_dir):
return T5TokenizerFast( return T5TokenizerFast(
vocab_file=os.path.join(config_dir, 'spiece.model'), vocab_file=os.path.join(config_dir, "spiece.model"),
tokenizer_file=os.path.join(config_dir, 'tokenizer.json'), tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
) )
raise FileNotFoundError( raise FileNotFoundError(
@@ -291,9 +417,9 @@ def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dt
for k, v in dit_state_dict.items(): for k, v in dit_state_dict.items():
if dtype is not None: if dtype is not None:
v = v.to(dtype) v = v.to(dtype)
prefixed_sd['net.' + k] = v.contiguous() prefixed_sd["net." + k] = v.contiguous()
save_file(prefixed_sd, save_path, metadata={'format': 'pt'}) save_file(prefixed_sd, save_path, metadata={"format": "pt"})
logger.info(f"Saved Anima model to {save_path}") logger.info(f"Saved Anima model to {save_path}")

View File

@@ -16,8 +16,7 @@ class CausalConv3d(nn.Conv3d):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1], self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0) self.padding = (0, 0, 0)
def forward(self, x, cache_x=None): def forward(self, x, cache_x=None):
@@ -41,12 +40,10 @@ class RMS_norm(nn.Module):
self.channel_first = channel_first self.channel_first = channel_first
self.scale = dim**0.5 self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape)) self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x): def forward(self, x):
return F.normalize( return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample): class Upsample(nn.Upsample):
@@ -61,65 +58,48 @@ class Upsample(nn.Upsample):
class Resample(nn.Module): class Resample(nn.Module):
def __init__(self, dim, mode): def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
'downsample3d')
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.mode = mode self.mode = mode
# layers # layers
if mode == 'upsample2d': if mode == "upsample2d":
self.resample = nn.Sequential( self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'), Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
nn.Conv2d(dim, dim // 2, 3, padding=1)) )
elif mode == 'upsample3d': elif mode == "upsample3d":
self.resample = nn.Sequential( self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'), Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
nn.Conv2d(dim, dim // 2, 3, padding=1)) )
self.time_conv = CausalConv3d( self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d': elif mode == "downsample2d":
self.resample = nn.Sequential( self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
nn.ZeroPad2d((0, 1, 0, 1)), elif mode == "downsample3d":
nn.Conv2d(dim, dim, 3, stride=(2, 2))) self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d': self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else: else:
self.resample = nn.Identity() self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size() b, c, t, h, w = x.size()
if self.mode == 'upsample3d': if self.mode == "upsample3d":
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
if feat_cache[idx] is None: if feat_cache[idx] is None:
feat_cache[idx] = 'Rep' feat_cache[idx] = "Rep"
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat([ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x.device), cache_x cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
], if feat_cache[idx] == "Rep":
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x) x = self.time_conv(x)
else: else:
x = self.time_conv(x, feat_cache[idx]) x = self.time_conv(x, feat_cache[idx])
@@ -127,15 +107,14 @@ class Resample(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w) x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
3)
x = x.reshape(b, c, t * 2, h, w) x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2] t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w') x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x) x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t) x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == 'downsample3d': if self.mode == "downsample3d":
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
if feat_cache[idx] is None: if feat_cache[idx] is None:
@@ -144,8 +123,7 @@ class Resample(nn.Module):
else: else:
cache_x = x[:, :, -1:, :, :].clone() cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv( x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
return x return x
@@ -166,8 +144,8 @@ class Resample(nn.Module):
nn.init.zeros_(conv_weight) nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size() c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2) init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight) conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data) nn.init.zeros_(conv.bias.data)
@@ -181,12 +159,15 @@ class ResidualBlock(nn.Module):
# layers # layers
self.residual = nn.Sequential( self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(), RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1), CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), RMS_norm(out_dim, images=False),
CausalConv3d(out_dim, out_dim, 3, padding=1)) nn.SiLU(),
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ nn.Dropout(dropout),
if in_dim != out_dim else nn.Identity() CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x) h = self.shortcut(x)
@@ -196,11 +177,7 @@ class ResidualBlock(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat([ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx]) x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -229,13 +206,10 @@ class AttentionBlock(nn.Module):
def forward(self, x): def forward(self, x):
identity = x identity = x
b, c, t, h, w = x.size() b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w') x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x) x = self.norm(x)
# compute query, key, value # compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
-1).permute(0, 1, 3,
2).contiguous().chunk(
3, dim=-1)
# apply attention # apply attention
x = F.scaled_dot_product_attention( x = F.scaled_dot_product_attention(
@@ -247,20 +221,22 @@ class AttentionBlock(nn.Module):
# output # output
x = self.proj(x) x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t) x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
return x + identity return x + identity
class Encoder3d(nn.Module): class Encoder3d(nn.Module):
def __init__(self, def __init__(
dim=128, self,
z_dim=4, dim=128,
dim_mult=[1, 2, 4, 4], z_dim=4,
num_res_blocks=2, dim_mult=[1, 2, 4, 4],
attn_scales=[], num_res_blocks=2,
temperal_downsample=[True, True, False], attn_scales=[],
dropout=0.0): temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.z_dim = z_dim self.z_dim = z_dim
@@ -288,21 +264,18 @@ class Encoder3d(nn.Module):
# downsample block # downsample block
if i != len(dim_mult) - 1: if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode)) downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0 scale /= 2.0
self.downsamples = nn.Sequential(*downsamples) self.downsamples = nn.Sequential(*downsamples)
# middle blocks # middle blocks
self.middle = nn.Sequential( self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
ResidualBlock(out_dim, out_dim, dropout)) )
# output blocks # output blocks
self.head = nn.Sequential( self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1))
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None: if feat_cache is not None:
@@ -310,11 +283,7 @@ class Encoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat([ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx]) x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -342,11 +311,7 @@ class Encoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat([ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx]) x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -357,14 +322,16 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module): class Decoder3d(nn.Module):
def __init__(self, def __init__(
dim=128, self,
z_dim=4, dim=128,
dim_mult=[1, 2, 4, 4], z_dim=4,
num_res_blocks=2, dim_mult=[1, 2, 4, 4],
attn_scales=[], num_res_blocks=2,
temperal_upsample=[False, True, True], attn_scales=[],
dropout=0.0): temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.z_dim = z_dim self.z_dim = z_dim
@@ -375,15 +342,15 @@ class Decoder3d(nn.Module):
# dimensions # dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2) scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block # init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks # middle blocks
self.middle = nn.Sequential( self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
ResidualBlock(dims[0], dims[0], dropout)) )
# upsample blocks # upsample blocks
upsamples = [] upsamples = []
@@ -399,15 +366,13 @@ class Decoder3d(nn.Module):
# upsample block # upsample block
if i != len(dim_mult) - 1: if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode)) upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0 scale *= 2.0
self.upsamples = nn.Sequential(*upsamples) self.upsamples = nn.Sequential(*upsamples)
# output blocks # output blocks
self.head = nn.Sequential( self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1 ## conv1
@@ -416,11 +381,7 @@ class Decoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat([ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx]) x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -448,11 +409,7 @@ class Decoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None: if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk # cache last frame of last two chunk
cache_x = torch.cat([ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx]) x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@@ -471,14 +428,16 @@ def count_conv3d(model):
class WanVAE_(nn.Module): class WanVAE_(nn.Module):
def __init__(self, def __init__(
dim=128, self,
z_dim=4, dim=128,
dim_mult=[1, 2, 4, 4], z_dim=4,
num_res_blocks=2, dim_mult=[1, 2, 4, 4],
attn_scales=[], num_res_blocks=2,
temperal_downsample=[True, True, False], attn_scales=[],
dropout=0.0): temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.z_dim = z_dim self.z_dim = z_dim
@@ -489,12 +448,10 @@ class WanVAE_(nn.Module):
self.temperal_upsample = temperal_downsample[::-1] self.temperal_upsample = temperal_downsample[::-1]
# modules # modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
attn_scales, self.temperal_upsample, dropout)
def forward(self, x): def forward(self, x):
mu, log_var = self.encode(x) mu, log_var = self.encode(x)
@@ -510,20 +467,15 @@ class WanVAE_(nn.Module):
for i in range(iter_): for i in range(iter_):
self._enc_conv_idx = [0] self._enc_conv_idx = [0]
if i == 0: if i == 0:
out = self.encoder( out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else: else:
out_ = self.encoder( out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
feat_cache=self._enc_feat_map, )
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1) mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor): if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
1, self.z_dim, 1, 1, 1)
else: else:
mu = (mu - scale[0]) * scale[1] mu = (mu - scale[0]) * scale[1]
self.clear_cache() self.clear_cache()
@@ -533,8 +485,7 @@ class WanVAE_(nn.Module):
self.clear_cache() self.clear_cache()
# z: [b,c,t,h,w] # z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor): if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
1, self.z_dim, 1, 1, 1)
else: else:
z = z / scale[1] + scale[0] z = z / scale[1] + scale[0]
iter_ = z.shape[2] iter_ = z.shape[2]
@@ -542,15 +493,9 @@ class WanVAE_(nn.Module):
for i in range(iter_): for i in range(iter_):
self._conv_idx = [0] self._conv_idx = [0]
if i == 0: if i == 0:
out = self.decoder( out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else: else:
out_ = self.decoder( out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
self.clear_cache() self.clear_cache()
return out return out
@@ -571,7 +516,7 @@ class WanVAE_(nn.Module):
self._conv_num = count_conv3d(self.decoder) self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0] self._conv_idx = [0]
self._feat_map = [None] * self._conv_num self._feat_map = [None] * self._conv_num
#cache encode # cache encode
self._enc_conv_num = count_conv3d(self.encoder) self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0] self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps( def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
assert bsz > 0, "Batch size not large enough" assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape # Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1) sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep # Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) # (this is the forward diffusion process)

View File

@@ -9,7 +9,7 @@ import logging
from tqdm import tqdm from tqdm import tqdm
from library.device_utils import clean_memory_on_device from library.device_utils import clean_memory_on_device
from library.safetensors_utils import MemoryEfficientSafeOpen from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
@@ -220,6 +220,8 @@ def quantize_weight(
tensor_max = torch.max(torch.abs(tensor).view(-1)) tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value scale = tensor_max / max_value
# print(f"Optimizing {key} with scale: {scale}")
# numerical safety # numerical safety
scale = torch.clamp(scale, min=1e-8) scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division scale = scale.to(torch.float32) # ensure scale is in float32 for division
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook=None, weight_hook=None,
quantization_mode: str = "block", quantization_mode: str = "block",
block_size: Optional[int] = 64, block_size: Optional[int] = 64,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict: ) -> dict:
""" """
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization. Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook (callable, optional): Function to apply to each weight tensor before optimization weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block" quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block") block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
Returns: Returns:
dict: FP8 optimized state dict dict: FP8 optimized state dict
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
# Process each file # Process each file
state_dict = {} state_dict = {}
for model_file in model_files: for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f: with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
keys = f.keys() keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"): for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key) value = f.get_tensor(key)
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
value = value.to(calc_device) value = value.to(calc_device)
original_dtype = value.dtype original_dtype = value.dtype
if original_dtype.itemsize == 1:
raise ValueError(
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
)
quantized_weight, scale_tensor = quantize_weight( quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
) )
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
else: else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight) o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1) o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype) return o.to(input_dtype)
else: else:

View File

@@ -5,7 +5,7 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from library.device_utils import synchronize_device from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
@@ -44,7 +44,7 @@ def filter_lora_state_dict(
def load_safetensors_with_lora_and_fp8( def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]], model_files: Union[str, List[str]],
lora_weights_list: Optional[Dict[str, torch.Tensor]], lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
lora_multipliers: Optional[List[float]], lora_multipliers: Optional[List[float]],
fp8_optimization: bool, fp8_optimization: bool,
calc_device: torch.device, calc_device: torch.device,
@@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
dit_weight_dtype: Optional[torch.dtype] = None, dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None, target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
""" """
Merge LoRA weights into the state dict of a model with fp8 optimization if needed. Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args: Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix. model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load. lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization. fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on. calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading. move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization. target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization. exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
""" """
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
@@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
extended_model_files = [] extended_model_files = []
for model_file in model_files: for model_file in model_files:
basename = os.path.basename(model_file) split_filenames = get_split_weight_filenames(model_file)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) if split_filenames is not None:
if match: extended_model_files.extend(split_filenames)
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
else: else:
extended_model_files.append(model_file) extended_model_files.append(model_file)
model_files = extended_model_files model_files = extended_model_files
@@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}") logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging # make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False): def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"): if not model_weight_key.endswith(".weight"):
@@ -145,6 +139,13 @@ def load_safetensors_with_lora_and_fp8(
down_weight = down_weight.to(calc_device) down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device) up_weight = up_weight.to(calc_device)
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
# temporarily convert to float16 for calculation
model_weight = model_weight.to(torch.float16)
down_weight = down_weight.to(torch.float16)
up_weight = up_weight.to(torch.float16)
# W <- W + U * D # W <- W + U * D
if len(model_weight.size()) == 2: if len(model_weight.size()) == 2:
# linear # linear
@@ -166,6 +167,9 @@ def load_safetensors_with_lora_and_fp8(
# logger.info(conved.size(), weight.size(), module.stride, module.padding) # logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale model_weight = model_weight + multiplier * conved * scale
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(original_dtype) # convert back to original dtype
# remove LoRA keys from set # remove LoRA keys from set
lora_weight_keys.remove(down_key) lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key) lora_weight_keys.remove(up_key)
@@ -187,6 +191,8 @@ def load_safetensors_with_lora_and_fp8(
target_keys, target_keys,
exclude_keys, exclude_keys,
weight_hook=weight_hook, weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
) )
for lora_weight_keys in list_of_lora_weight_keys: for lora_weight_keys in list_of_lora_weight_keys:
@@ -208,6 +214,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
target_keys: Optional[List[str]] = None, target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None,
weight_hook: callable = None, weight_hook: callable = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
""" """
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
@@ -218,7 +226,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
) )
# dit_weight_dtype is not used because we use fp8 optimization # dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization( state_dict = load_safetensors_with_fp8_optimization(
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook model_files,
calc_device,
target_keys,
exclude_keys,
move_to_device=move_to_device,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
) )
else: else:
logger.info( logger.info(
@@ -226,7 +241,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
) )
state_dict = {} state_dict = {}
for model_file in model_files: for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f: with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device: if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype) value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
import os import os
import re import re
import numpy as np import numpy as np
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
validated[key] = value validated[key] = value
return validated return validated
# print(f"Using memory efficient save file: {filename}")
header = {} header = {}
offset = 0 offset = 0
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
by using memory mapping for large tensors and avoiding unnecessary copies. by using memory mapping for large tensors and avoiding unnecessary copies.
""" """
def __init__(self, filename): def __init__(self, filename, disable_numpy_memmap=False):
"""Initialize the SafeTensor reader. """Initialize the SafeTensor reader.
Args: Args:
filename (str): Path to the safetensors file to read. filename (str): Path to the safetensors file to read.
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
""" """
self.filename = filename self.filename = filename
self.file = open(filename, "rb") self.file = open(filename, "rb")
self.header, self.header_size = self._read_header() self.header, self.header_size = self._read_header()
self.disable_numpy_memmap = disable_numpy_memmap
def __enter__(self): def __enter__(self):
"""Enter context manager.""" """Enter context manager."""
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
# Use memmap for large tensors to avoid intermediate copies. # Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired. # If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu. # So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu": # If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading # Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,)) mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy byte_tensor = torch.from_numpy(mm) # zero copy
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
def load_safetensors( def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None path: str,
device: Union[str, torch.device],
disable_mmap: bool = False,
dtype: Optional[torch.dtype] = None,
disable_numpy_memmap: bool = False,
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
if disable_mmap: if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read()) # return safetensors.torch.load(open(path, "rb").read())
@@ -293,7 +302,7 @@ def load_safetensors(
# logger.info(f"Loading without mmap (experimental)") # logger.info(f"Loading without mmap (experimental)")
state_dict = {} state_dict = {}
device = torch.device(device) if device is not None else None device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f: with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
for key in f.keys(): for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype) state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device) synchronize_device(device)
@@ -309,6 +318,29 @@ def load_safetensors(
return state_dict return state_dict
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
    """
    Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.

    The shard index is zero-padded to the same width as in the given file name and the shard count
    is reused verbatim, so the usual five-digit form (`00001-of-00004`) resolves exactly as before,
    while other paddings (e.g. `1-of-4`) are resolved too instead of being forced to five digits.

    Args:
        file_path (str): Path to one shard of a (possibly) split safetensors file.

    Returns:
        Optional[list[str]]: Full paths of all shards, ordered by shard index (1..count),
        or None if the file name does not match `<prefix><index>-of-<count>.safetensors`.

    Raises:
        FileNotFoundError: If any of the expected shard files does not exist.
    """
    basename = os.path.basename(file_path)
    match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
    if match is None:
        # not a split weight file
        return None

    prefix = basename[: match.start(2)]
    index_width = len(match.group(2))  # keep the zero padding used by the given file name
    count_text = match.group(3)  # reuse the count exactly as written (including padding)
    count = int(count_text)
    dirname = os.path.dirname(file_path)

    filenames = []
    for shard_index in range(1, count + 1):
        filename = f"{prefix}{shard_index:0{index_width}d}-of-{count_text}.safetensors"
        filepath = os.path.join(dirname, filename)
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File {filepath} not found")
        filenames.append(filepath)
    return filenames
def load_split_weights( def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]: ) -> Dict[str, torch.Tensor]:
@@ -319,19 +351,11 @@ def load_split_weights(
device = torch.device(device) device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path) split_filenames = get_split_weight_filenames(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) if split_filenames is not None:
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {} state_dict = {}
for i in range(count): for filename in split_filenames:
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors" state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
else: else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype) state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict return state_dict
@@ -349,3 +373,107 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)): if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key return key
return None return None
@dataclass
class WeightTransformHooks:
    """Bundle of optional key/tensor conversion hooks read by TensorWeightAdapter.

    Every hook defaults to None, in which case that kind of conversion is skipped.
    """

    # called as split_hook(original_key, tensor_or_None); returns (new_keys, new_tensors)
    split_hook: Optional[callable] = None
    # called as concat_hook(original_key, tensors_dict_or_None); returns (new_key, concatenated_tensor)
    concat_hook: Optional[callable] = None
    # called as rename_hook(original_key); returns the new key
    rename_hook: Optional[callable] = None
class TensorWeightAdapter:
    """
    Adapter that applies weight conversion hooks (split, concat, rename) on top of a MemoryEfficientSafeOpen.

    The adapter exposes the converted key names through `keys()` and maps them back to the original
    tensors in `get_tensor()`, applying the matching hook when the tensor is requested.

    split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
    concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
    rename_hook: A callable that takes (original_key: str) and returns (new_key: str).

    If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
    A hook signals "this key is not handled by me" by returning None as the key(s).

    No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
    Do not use this wrapper as a context manager directly, like `with TensorWeightAdapter(...) as f:`.

    **concat_hook is not tested yet.**
    """

    def __init__(self, weight_convert_hook: "WeightTransformHooks", original_f: "MemoryEfficientSafeOpen"):
        """Scan the keys of `original_f` once and build the new-key -> original-key tables.

        Args:
            weight_convert_hook: object providing `split_hook`, `concat_hook` and `rename_hook` (each may be None).
            original_f: opened file object providing `keys()` and `get_tensor(key, device=..., dtype=...)`.
        """
        self.original_f = original_f
        # for split: new_key -> original_key; for concat: new_key -> list of original_keys;
        # for rename: new_key -> original_key. Keys that are passed through unchanged are not stored.
        self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = {}
        self.concat_key_set = set()  # set of concatenated keys
        self.split_key_set = set()  # set of split keys
        self.new_keys = []
        self.tensor_cache = {}  # cache for split tensors, keyed by the NEW (split) key

        self.split_hook = weight_convert_hook.split_hook
        self.concat_hook = weight_convert_hook.concat_hook
        self.rename_hook = weight_convert_hook.rename_hook

        for key in self.original_f.keys():
            if self.split_hook is not None:
                converted_keys, _ = self.split_hook(key, None)  # get new keys only
                if converted_keys is not None:
                    for converted_key in converted_keys:
                        self.new_key_to_original_key_map[converted_key] = key
                        self.split_key_set.add(converted_key)
                    self.new_keys.extend(converted_keys)
                    continue  # skip concat_hook if split_hook is applied

            if self.concat_hook is not None:
                converted_key, _ = self.concat_hook(key, None)  # get new key only
                if converted_key is not None:
                    if converted_key not in self.concat_key_set:  # first time seeing this concatenated key
                        self.concat_key_set.add(converted_key)
                        self.new_key_to_original_key_map[converted_key] = []
                        # expose the concatenated key only once, even though several original keys feed it
                        self.new_keys.append(converted_key)
                    # multiple original keys map to the same concatenated key
                    self.new_key_to_original_key_map[converted_key].append(key)
                    continue  # skip to next key

            # direct mapping
            if self.rename_hook is not None:
                new_key = self.rename_hook(key)
                self.new_key_to_original_key_map[new_key] = key
            else:
                new_key = key
            self.new_keys.append(new_key)

    def keys(self) -> list[str]:
        """Return the converted key names, in the order the original keys were scanned."""
        return self.new_keys

    def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Load the tensor for `new_key`, applying the split or concat hook as needed.

        Args:
            new_key: one of the names returned by `keys()`.
            device: forwarded to `original_f.get_tensor`.
            dtype: forwarded to `original_f.get_tensor`.

        Returns:
            The tensor for `new_key`. For split keys, the sibling tensors produced by the same split are
            kept in an internal cache until they are requested (they keep the device/dtype of the request
            that triggered the split).
        """
        if new_key not in self.new_key_to_original_key_map:
            # direct mapping: the key was passed through unchanged
            return self.original_f.get_tensor(new_key, device=device, dtype=dtype)

        if new_key in self.split_key_set:
            # split hook: the original tensor yields several new keys, so split once and cache the siblings.
            # The cache is keyed by the new keys, so the lookup must use new_key (not the original key);
            # otherwise the original tensor is reloaded and re-split on every request and the unused
            # siblings pile up in the cache.
            if new_key not in self.tensor_cache:  # not yet split (or already consumed)
                original_key = self.new_key_to_original_key_map[new_key]
                original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
                new_keys, new_tensors = self.split_hook(original_key, original_tensor)  # apply split hook
                for k, t in zip(new_keys, new_tensors):
                    self.tensor_cache[k] = t
            return self.tensor_cache.pop(new_key)  # return and remove from cache

        if new_key in self.concat_key_set:
            # concat hook: concatenated key is requested only once, so we do not cache the result
            original_keys = self.new_key_to_original_key_map[new_key]
            tensors = {
                original_key: self.original_f.get_tensor(original_key, device=device, dtype=dtype)
                for original_key in original_keys
            }
            _, concatenated_tensor = self.concat_hook(original_keys[0], tensors)  # apply concat hook
            return concatenated_tensor

        # renamed key: load the tensor under its original name
        original_key = self.new_key_to_original_key_map[new_key]
        return self.original_f.get_tensor(original_key, device=device, dtype=dtype)

View File

@@ -9,6 +9,7 @@ import torch
from library import anima_utils, train_util from library import anima_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library import qwen_image_autoencoder_kl
from library.utils import setup_logging from library.utils import setup_logging
@@ -45,8 +46,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path) t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer self.qwen3_tokenizer = qwen3_tokenizer
self.t5_tokenizer = t5_tokenizer
self.qwen3_max_length = qwen3_max_length self.qwen3_max_length = qwen3_max_length
self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
@@ -54,26 +55,17 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
# Tokenize with Qwen3 # Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus( qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
text, text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.qwen3_max_length,
) )
qwen3_input_ids = qwen3_encoding["input_ids"] qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"] qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens) # Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus( t5_encoding = self.t5_tokenizer.batch_encode_plus(
text, text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.t5_max_length,
) )
t5_input_ids = t5_encoding["input_ids"] t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"] t5_attn_mask = t5_encoding["attention_mask"]
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask] return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
@@ -84,46 +76,11 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
T5 tokens are passed through unchanged (only used by LLM Adapter). T5 tokens are passed through unchanged (only used by LLM Adapter).
""" """
def __init__( def __init__(self) -> None:
self, super().__init__()
dropout_rate: float = 0.0,
) -> None:
self.dropout_rate = dropout_rate
# Cached unconditional embeddings (from encoding empty caption "")
# Must be initialized via cache_uncond_embeddings() before text encoder is deleted
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
def cache_uncond_embeddings(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
) -> None:
"""Pre-encode empty caption "" and cache the unconditional embeddings.
Must be called before the text encoder is deleted from GPU.
This matches diffusion-pipe-main behavior where empty caption embeddings
are pre-cached and swapped in during caption dropout.
"""
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
self._uncond_attn_mask = uncond_outputs[1].cpu()
self._uncond_t5_input_ids = uncond_outputs[2].cpu()
self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
logger.info(" Unconditional embeddings cached successfully")
def encode_tokens( def encode_tokens(
self, self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
enable_dropout: bool = True,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs. """Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -134,82 +91,20 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns: Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
""" """
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0] qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main) encoder_device = qwen3_text_encoder.device
batch_size = qwen3_input_ids.shape[0]
non_drop_indices = []
for i in range(batch_size):
drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
if not drop:
non_drop_indices.append(i)
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device qwen3_input_ids = qwen3_input_ids.to(encoder_device)
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
prompt_embeds = outputs.last_hidden_state
prompt_embeds[~qwen3_attn_mask.bool()] = 0
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size: return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
# Only encode non-dropped items to save compute
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
elif len(non_drop_indices) == batch_size:
nd_input_ids = qwen3_input_ids.to(encoder_device)
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
else:
nd_input_ids = None
nd_attn_mask = None
if nd_input_ids is not None:
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
# Build full batch: fill non-dropped with encoded, dropped with unconditional
if len(non_drop_indices) == batch_size:
prompt_embeds = nd_encoded_text
attn_mask = qwen3_attn_mask.to(encoder_device)
else:
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Encode empty caption on-the-fly (text encoder still available)
uncond_tokens = tokenize_strategy.tokenize("")
uncond_ids = uncond_tokens[0].to(encoder_device)
uncond_mask = uncond_tokens[1].to(encoder_device)
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
uncond_pe = uncond_out.last_hidden_state[0]
uncond_pe[~uncond_mask[0].bool()] = 0
uncond_am = uncond_mask[0]
uncond_t5_ids = uncond_tokens[2][0]
uncond_t5_am = uncond_tokens[3][0]
seq_len = qwen3_input_ids.shape[1]
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
if len(non_drop_indices) > 0:
prompt_embeds[non_drop_indices] = nd_encoded_text
attn_mask[non_drop_indices] = nd_attn_mask
# Fill dropped items with unconditional embeddings
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
for i in drop_indices:
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs( def drop_cached_text_encoder_outputs(
self, self,
@@ -217,6 +112,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
attn_mask: torch.Tensor, attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor, t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor, t5_attn_mask: torch.Tensor,
caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs. """Apply dropout to cached text encoder outputs.
@@ -224,37 +120,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Replaces dropped items with pre-cached unconditional embeddings (from encoding "") Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior. to match diffusion-pipe-main behavior.
""" """
if prompt_embeds is not None and self.dropout_rate > 0.0: if caption_dropout_rates is None or all(caption_dropout_rates == 0.0):
# Clone to avoid in-place modification of cached tensors return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
for i in range(prompt_embeds.shape[0]): # Clone to avoid in-place modification of cached tensors
if random.random() < self.dropout_rate: prompt_embeds = prompt_embeds.clone()
if self._uncond_prompt_embeds is not None: if attn_mask is not None:
# Use pre-cached unconditional embeddings attn_mask = attn_mask.clone()
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) if t5_input_ids is not None:
if attn_mask is not None: t5_input_ids = t5_input_ids.clone()
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype) if t5_attn_mask is not None:
if t5_input_ids is not None: t5_attn_mask = t5_attn_mask.clone()
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None: for i in range(prompt_embeds.shape[0]):
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype) if random.random() < caption_dropout_rates[i].item():
else: # Use pre-cached unconditional embeddings
# Fallback: zero out (should not happen if cache_uncond_embeddings was called) prompt_embeds[i] = 0
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout") if attn_mask is not None:
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i]) attn_mask[i] = 0
if attn_mask is not None: if t5_input_ids is not None:
attn_mask[i] = torch.zeros_like(attn_mask[i]) t5_input_ids[i, 0] = 1 # Set to </s> token ID
if t5_input_ids is not None: t5_input_ids[i, 1:] = 0
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i]) if t5_attn_mask is not None:
if t5_attn_mask is not None: t5_attn_mask[i, 0] = 1
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) t5_attn_mask[i, 1:] = 0
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@@ -297,6 +186,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False return False
if "t5_attn_mask" not in npz: if "t5_attn_mask" not in npz:
return False return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e: except Exception as e:
logger.error(f"Error loading file: {npz_path}") logger.error(f"Error loading file: {npz_path}")
raise e raise e
@@ -309,7 +200,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask = data["attn_mask"] attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"] t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"] t5_attn_mask = data["t5_attn_mask"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] caption_dropout_rate = data["caption_dropout_rate"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs( def cache_batch_outputs(
self, self,
@@ -323,12 +215,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions) tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad(): with torch.no_grad():
# Always disable dropout during caching
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens( prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
tokenize_strategy, tokenize_strategy, models, tokens_and_masks
models,
tokens_and_masks,
enable_dropout=False,
) )
# Convert to numpy for caching # Convert to numpy for caching
@@ -344,6 +232,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask_i = attn_mask[i] attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i] t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i] t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk: if self.cache_to_disk:
np.savez( np.savez(
@@ -352,9 +241,10 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask=attn_mask_i, attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i, t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i, t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
) )
else: else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i) info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy): class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -374,18 +264,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self.ANIMA_LATENTS_NPZ_SUFFIX return self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return ( return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.ANIMA_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected( def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
):
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def load_latents_from_disk( def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int] self, npz_path: str, bucket_reso: Tuple[int, int]
@@ -393,32 +275,23 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
"""Cache batch of latents using WanVAE. """Cache batch of latents using Qwen Image VAE.
vae is expected to be the WanVAE_ model (not the wrapper). vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
The encoding function handles the mean/std normalization. The encoding function handles the mean/std normalization.
""" """
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
vae_device = vae.device
vae_device = next(vae.parameters()).device vae_dtype = vae.dtype
vae_dtype = next(vae.parameters()).dtype
# Create scale tensors on VAE device
mean = torch.tensor(ANIMA_VAE_MEAN, dtype=vae_dtype, device=vae_device)
std = torch.tensor(ANIMA_VAE_STD, dtype=vae_dtype, device=vae_device)
scale = [mean, 1.0 / std]
def encode_by_vae(img_tensor): def encode_by_vae(img_tensor):
"""Encode image tensor to latents. """Encode image tensor to latents.
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS) img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
Need to add temporal dim to get (B, C, T=1, H, W) for WanVAE Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
""" """
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W) latents = vae.encode_pixels_to_latents(img_tensor)
img_tensor = img_tensor.unsqueeze(2)
img_tensor = img_tensor.to(vae_device, dtype=vae_dtype)
latents = vae.encode(img_tensor, scale)
return latents.to("cpu") return latents.to("cpu")
self._default_cache_batch_latents( self._default_cache_batch_latents(

View File

@@ -179,12 +179,15 @@ def split_train_val(
class ImageInfo: class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: def __init__(
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
) -> None:
self.image_key: str = image_key self.image_key: str = image_key
self.num_repeats: int = num_repeats self.num_repeats: int = num_repeats
self.caption: str = caption self.caption: str = caption
self.is_reg: bool = is_reg self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path self.absolute_path: str = absolute_path
self.caption_dropout_rate: float = caption_dropout_rate
self.image_size: Tuple[int, int] = None self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None
@@ -197,7 +200,7 @@ class ImageInfo:
) )
self.cond_img_path: Optional[str] = None self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
# new # new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@@ -1096,11 +1099,11 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self): def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def is_text_encoder_output_cacheable(self): def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
return all( return all(
[ [
not ( not (
subset.caption_dropout_rate > 0 subset.caption_dropout_rate > 0 and not cache_supports_dropout
or subset.shuffle_caption or subset.shuffle_caption
or subset.token_warmup_step > 0 or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0 or subset.caption_tag_dropout_rate > 0
@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths) num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes): for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
info.resize_interpolation = ( info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
) )
@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
if caption is None: if caption is None:
caption = "" caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
image_info.resize_interpolation = ( image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
) )
@@ -2661,8 +2664,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets]) return all([dataset.is_latent_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self) -> bool: def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]) return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])
def set_current_strategies(self): def set_current_strategies(self):
for dataset in self.datasets: for dataset in self.datasets:

View File

@@ -1,18 +1,17 @@
# LoRA network module for Anima # LoRA network module for Anima
import math import ast
import os import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch import torch
from library.utils import setup_logging from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
setup_logging()
import logging import logging
setup_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
def create_network( def create_network(
multiplier: float, multiplier: float,
@@ -29,68 +28,28 @@ def create_network(
if network_alpha is None: if network_alpha is None:
network_alpha = 1.0 network_alpha = 1.0
# type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
self_attn_dim = kwargs.get("self_attn_dim", None)
cross_attn_dim = kwargs.get("cross_attn_dim", None)
mlp_dim = kwargs.get("mlp_dim", None)
mod_dim = kwargs.get("mod_dim", None)
llm_adapter_dim = kwargs.get("llm_adapter_dim", None)
if self_attn_dim is not None:
self_attn_dim = int(self_attn_dim)
if cross_attn_dim is not None:
cross_attn_dim = int(cross_attn_dim)
if mlp_dim is not None:
mlp_dim = int(mlp_dim)
if mod_dim is not None:
mod_dim = int(mod_dim)
if llm_adapter_dim is not None:
llm_adapter_dim = int(llm_adapter_dim)
type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims: [x_embedder, t_embedder, final_layer]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")]
assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"
# block selection
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start, end = int(start), int(end)
assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
train_block_indices = parse_block_selection(train_block_indices, num_blocks)
# train LLM adapter # train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", False) train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None: if train_llm_adapter is not None:
train_llm_adapter = True if train_llm_adapter == "True" else False train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
# regular expression for module selection: exclude and include
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout # rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None) rank_dropout = kwargs.get("rank_dropout", None)
@@ -101,9 +60,9 @@ def create_network(
module_dropout = float(module_dropout) module_dropout = float(module_dropout)
# verbose # verbose
verbose = kwargs.get("verbose", False) verbose = kwargs.get("verbose", "false")
if verbose is not None: if verbose is not None:
verbose = True if verbose == "True" else False verbose = True if verbose.lower() == "true" else False
network = LoRANetwork( network = LoRANetwork(
text_encoders, text_encoders,
@@ -115,9 +74,8 @@ def create_network(
rank_dropout=rank_dropout, rank_dropout=rank_dropout,
module_dropout=module_dropout, module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter, train_llm_adapter=train_llm_adapter,
type_dims=type_dims, exclude_patterns=exclude_patterns,
emb_dims=emb_dims, include_patterns=include_patterns,
train_block_indices=train_block_indices,
verbose=verbose, verbose=verbose,
) )
@@ -137,6 +95,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
if weights_sd is None: if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors": if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file
weights_sd = load_file(file) weights_sd = load_file(file)
else: else:
weights_sd = torch.load(file, map_location="cpu") weights_sd = torch.load(file, map_location="cpu")
@@ -173,8 +132,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
class LoRANetwork(torch.nn.Module): class LoRANetwork(torch.nn.Module):
# Target modules: DiT blocks # Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
ANIMA_TARGET_REPLACE_MODULE = ["Block"] ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks # Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"] ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3) # Target modules for text encoder (Qwen3)
@@ -197,9 +156,8 @@ class LoRANetwork(torch.nn.Module):
modules_dim: Optional[Dict[str, int]] = None, modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False, train_llm_adapter: bool = False,
type_dims: Optional[List[int]] = None, exclude_patterns: Optional[List[str]] = None,
emb_dims: Optional[List[int]] = None, include_patterns: Optional[List[str]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False, verbose: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
@@ -210,21 +168,36 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout self.rank_dropout = rank_dropout
self.module_dropout = module_dropout self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter self.train_llm_adapter = train_llm_adapter
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None: if modules_dim is not None:
logger.info(f"create LoRA network from weights") logger.info("create LoRA network from weights")
if self.emb_dims is None: if self.emb_dims is None:
self.emb_dims = [0] * 3 self.emb_dims = [0] * 3
else: else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expression if specified
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances # create module instances
def create_modules( def create_modules(
@@ -232,15 +205,9 @@ class LoRANetwork(torch.nn.Module):
text_encoder_idx: Optional[int], text_encoder_idx: Optional[int],
root_module: torch.nn.Module, root_module: torch.nn.Module,
target_replace_modules: List[str], target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None, default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> Tuple[List[LoRAModule], List[str]]: ) -> Tuple[List[LoRAModule], List[str]]:
prefix = ( prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
self.LORA_PREFIX_ANIMA
if is_unet
else self.LORA_PREFIX_TEXT_ENCODER
)
loras = [] loras = []
skipped = [] skipped = []
@@ -255,14 +222,16 @@ class LoRANetwork(torch.nn.Module):
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d: if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name original_name = (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_") lora_name = f"{prefix}.{original_name}".replace(".", "_")
force_incl_conv2d = False # exclude/include filter
if filter is not None: excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns)
if filter not in lora_name: included = any(pattern.match(original_name) for pattern in include_re_patterns)
continue if excluded and not included:
force_incl_conv2d = include_conv2d_if_filter if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None dim = None
alpha_val = None alpha_val = None
@@ -276,40 +245,6 @@ class LoRANetwork(torch.nn.Module):
dim = default_dim if default_dim is not None else self.lora_dim dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha alpha_val = self.alpha
if is_unet and type_dims is not None:
# type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
# Order matters: check most specific identifiers first to avoid mismatches.
identifier_order = [
(4, ("llm_adapter",)),
(3, ("adaln_modulation",)),
(0, ("self_attn",)),
(1, ("cross_attn",)),
(2, ("mlp",)),
]
for idx, ids in identifier_order:
d = type_dims[idx]
if d is not None and all(id_str in lora_name for id_str in ids):
dim = d # 0 means skip
break
# block index filtering
if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
# Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
parts = lora_name.split("_")
for pi, part in enumerate(parts):
if part == "blocks" and pi + 1 < len(parts):
try:
block_index = int(parts[pi + 1])
if not self.train_block_indices[block_index]:
dim = 0
except (ValueError, IndexError):
pass
break
elif force_incl_conv2d:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0: if dim is None or dim == 0:
if is_linear or is_conv2d_1x1: if is_linear or is_conv2d_1x1:
skipped.append(lora_name) skipped.append(lora_name)
@@ -339,9 +274,7 @@ class LoRANetwork(torch.nn.Module):
if text_encoder is None: if text_encoder is None:
continue continue
logger.info(f"create LoRA for Text Encoder {i+1}:") logger.info(f"create LoRA for Text Encoder {i+1}:")
te_loras, te_skipped = create_modules( te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.") logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras) self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped skipped_te += te_skipped
@@ -354,19 +287,6 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
# emb_dims: [x_embedder, t_embedder, final_layer]
if self.emb_dims:
for filter_name, in_dim in zip(
["x_embedder", "t_embedder", "final_layer"],
self.emb_dims,
):
loras, _ = create_modules(
True, None, unet, None,
filter=filter_name, default_dim=in_dim,
include_conv2d_if_filter=(filter_name == "x_embedder"),
)
self.unet_loras.extend(loras)
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose: if verbose:
for lora in self.unet_loras: for lora in self.unet_loras:
@@ -396,6 +316,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors": if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file from safetensors.torch import load_file
weights_sd = load_file(file) weights_sd = load_file(file)
else: else:
weights_sd = torch.load(file, map_location="cpu") weights_sd = torch.load(file, map_location="cpu")
@@ -443,10 +364,10 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora = {} sd_for_lora = {}
for key in weights_sd.keys(): for key in weights_sd.keys():
if key.startswith(lora.lora_name): if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key] sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device) lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged") logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio self.loraplus_lr_ratio = loraplus_lr_ratio
@@ -498,10 +419,7 @@ class LoRANetwork(torch.nn.Module):
if self.text_encoder_loras: if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
te1_loras = [ te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
lora for lora in self.text_encoder_loras
if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
]
if len(te1_loras) > 0: if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio) params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)

View File

@@ -2,7 +2,7 @@
Diagnostic script to test Anima latent & text encoder caching independently. Diagnostic script to test Anima latent & text encoder caching independently.
Usage: Usage:
python test_anima_cache.py \ python manual_test_anima_cache.py \
--image_dir /path/to/images \ --image_dir /path/to/images \
--qwen3_path /path/to/qwen3 \ --qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \ --vae_path /path/to/vae.safetensors \

View File

@@ -470,7 +470,7 @@ class NetworkTrainer:
loss = loss * weighting loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch) loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3]) loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights