This commit is contained in:
Kohya S.
2026-02-10 08:25:27 +09:00
committed by GitHub
20 changed files with 3727 additions and 1119 deletions

View File

@@ -32,6 +32,7 @@ hime="hime"
OT="OT"
byt="byt"
tak="tak"
temperal="temperal"
[files]
extend-exclude = ["_typos.toml", "venv"]
extend-exclude = ["_typos.toml", "venv", "configs"]

1020
anima_minimal_inference.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -49,35 +49,32 @@ def train(args):
args.skip_cache_check = args.skip_latents_validity_check
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
if getattr(args, 'unsloth_offload_checkpointing', False):
if getattr(args, "unsloth_offload_checkpointing", False):
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not getattr(args, 'unsloth_offload_checkpointing', False), \
"blocks_to_swap is not supported with unsloth_offload_checkpointing"
assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr(
args, "unsloth_offload_checkpointing", False
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability
if getattr(args, 'flash_attn', False):
if getattr(args, "flash_attn", False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
@@ -104,9 +101,7 @@ def train(args):
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0}".format(", ".join(ignored))
)
logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
@@ -150,7 +145,7 @@ def train(args):
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so we save the rate and zero out subset-level
# caption_dropout_rate to allow text encoder output caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
for dataset in train_dataset_group.datasets:
@@ -175,9 +170,7 @@ def train(args):
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used"
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs:
assert (
@@ -193,7 +186,7 @@ def train(args):
# parse transformer_dtype
transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
@@ -203,12 +196,8 @@ def train(args):
# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu"
)
t5_tokenizer = anima_utils.load_t5_tokenizer(
getattr(args, 't5_tokenizer_path', None)
)
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu")
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@@ -220,7 +209,7 @@ def train(args):
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Set text encoding strategy
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
@@ -266,10 +255,8 @@ def train(args):
)
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0.0:
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
accelerator.wait_for_everyone()
@@ -299,17 +286,17 @@ def train(args):
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
llm_adapter_path=getattr(args, "llm_adapter_path", None),
disable_mmap=getattr(args, "disable_mmap_load_safetensors", False),
)
if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False),
)
if getattr(args, 'flash_attn', False):
if getattr(args, "flash_attn", False):
dit.set_flash_attn(True)
train_dit = args.learning_rate != 0
@@ -335,11 +322,11 @@ def train(args):
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
self_attn_lr=getattr(args, 'self_attn_lr', None),
cross_attn_lr=getattr(args, 'cross_attn_lr', None),
mlp_lr=getattr(args, 'mlp_lr', None),
mod_lr=getattr(args, 'mod_lr', None),
llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
self_attn_lr=getattr(args, "self_attn_lr", None),
cross_attn_lr=getattr(args, "cross_attn_lr", None),
mlp_lr=getattr(args, "mlp_lr", None),
mod_lr=getattr(args, "mod_lr", None),
llm_adapter_lr=getattr(args, "llm_adapter_lr", None),
)
else:
param_groups = []
@@ -366,8 +353,8 @@ def train(args):
# Build param_id → lr mapping from param_groups to propagate per-component LRs
param_lr_map = {}
for group in param_groups:
for p in group['params']:
param_lr_map[id(p)] = group['lr']
for p in group["params"]:
param_lr_map[id(p)] = group["lr"]
grouped_params = []
param_group = {}
@@ -557,9 +544,7 @@ def train(args):
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
accelerator.print(f" num epochs: {num_train_epochs}")
accelerator.print(
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps: {args.max_train_steps}")
@@ -580,6 +565,7 @@ def train(args):
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")
@@ -589,8 +575,16 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, 0, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
0,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
optimizer_train_fn()
@@ -600,7 +594,9 @@ def train(args):
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
logger.info(
f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}"
)
if qwen3_text_encoder is not None:
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
if vae is not None:
@@ -640,9 +636,7 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list
)
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
@@ -678,10 +672,7 @@ def train(args):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
if is_swapping_blocks:
@@ -708,9 +699,7 @@ def train(args):
# Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
loss = train_util.conditional_loss(
model_pred.float(), target.float(), args.loss_type, "none", huber_c
)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
@@ -748,8 +737,16 @@ def train(args):
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, None, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
None,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
@@ -773,8 +770,10 @@ def train(args):
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names(
logs, lr_scheduler, args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
logs,
lr_scheduler,
args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
)
accelerator.log(logs, step=global_step)
@@ -807,8 +806,16 @@ def train(args):
)
anima_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
epoch + 1,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
@@ -852,11 +859,11 @@ def setup_parser() -> argparse.ArgumentParser:
anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
parser.add_argument(
"--blockwise_fused_optimizers",
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step",
)
# parser.add_argument(
# "--blockwise_fused_optimizers",
# action="store_true",
# help="enable blockwise optimizers for fused backward pass and optimizer step",
# )
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",

View File

@@ -1,16 +1,26 @@
# Anima LoRA training script
import argparse
import math
from typing import Any, Optional, Union
import torch
import torch.nn as nn
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import anima_models, anima_train_utils, anima_utils, strategy_anima, strategy_base, train_util
from library import (
anima_models,
anima_train_utils,
anima_utils,
flux_train_utils,
qwen_image_autoencoder_kl,
sd3_train_utils,
strategy_anima,
strategy_base,
train_util,
)
import train_network
from library.utils import setup_logging
@@ -24,13 +34,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.vae = None
self.vae_scale = None
self.qwen3_text_encoder = None
self.qwen3_tokenizer = None
self.t5_tokenizer = None
self.tokenize_strategy = None
self.text_encoding_strategy = None
def assert_extra_args(
self,
@@ -38,137 +41,113 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
if (args.fp8_base or args.fp8_base_unet) and not args.fp8_scaled:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
"fp8_base and fp8_base_unet are not supported. Use fp8_scaled instead / fp8_baseとfp8_base_unetはサポートされていません。代わりにfp8_scaledを使用してください"
)
if args.fp8_scaled and (args.fp8_base or args.fp8_base_unet):
logger.info(
"fp8_scaled is used, so fp8_base and fp8_base_unet are ignored / fp8_scaledが使われているので、fp8_baseとfp8_base_unetは無視されます"
)
args.fp8_base = False
args.fp8_base_unet = False
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so zero out subset-level rates to allow caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
if hasattr(train_dataset_group, 'datasets'):
for dataset in train_dataset_group.datasets:
for subset in dataset.subsets:
subset.caption_dropout_rate = 0.0
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
assert train_dataset_group.is_text_encoder_output_cacheable(
cache_supports_dropout=True
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
if getattr(args, 'unsloth_offload_checkpointing', False):
if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
not args.cpu_offload_checkpointing
), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability
if getattr(args, 'flash_attn', False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
args.flash_attn = False
if getattr(args, 'blockwise_fused_optimizers', False):
raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(8)
val_dataset_group.verify_bucket_reso_steps(16)
def load_target_model(self, args, weight_dtype, accelerator):
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
logger.info("Loading Qwen3 text encoder...")
self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu"
)
self.qwen3_text_encoder.eval()
qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
qwen3_text_encoder.eval()
# Parse transformer_dtype
transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
# Load VAE
logger.info("Loading Anima VAE...")
vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
attn_mode = "torch"
if args.xformers:
attn_mode = "xformers"
if args.attn_mode is not None:
attn_mode = args.attn_mode
# Load DiT
logger.info("Loading Anima DiT...")
dit = anima_utils.load_anima_dit(
args.dit_path,
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
model = anima_utils.load_anima_model(
accelerator.device,
args.pretrained_model_name_or_path,
attn_mode,
args.split_attn,
loading_device,
loading_dtype,
args.fp8_scaled,
)
# Flash attention
if getattr(args, 'flash_attn', False):
dit.set_flash_attn(True)
# Store unsloth preference so that when the base NetworkTrainer calls
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
# The base trainer only passes cpu_offload, so we store the flag on the model.
self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False)
self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
# Block swap
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
# Load VAE
logger.info("Loading Anima VAE...")
self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
args.vae_path, dtype=weight_dtype, device="cpu"
)
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [self.qwen3_text_encoder], self.vae, dit
return model, text_encoders
def get_tokenize_strategy(self, args):
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.qwen3_path,
t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None),
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.qwen3,
t5_tokenizer_path=args.t5_tokenizer_path,
qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length,
)
# Store references so load_target_model can reuse them
self.qwen3_tokenizer = self.tokenize_strategy.qwen3_tokenizer
self.t5_tokenizer = self.tokenize_strategy.t5_tokenizer
return self.tokenize_strategy
return tokenize_strategy
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
return [tokenize_strategy.qwen3_tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_anima.AnimaLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
def get_text_encoding_strategy(self, args):
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
return self.text_encoding_strategy
return strategy_anima.AnimaTextEncodingStrategy()
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# Qwen3 text encoder is always frozen for Anima
@@ -188,10 +167,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
return None
@@ -200,15 +176,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
):
if args.cache_text_encoder_outputs:
if not args.lowram:
logger.info("move vae and unet to cpu to save memory")
org_vae_device = next(vae.parameters()).device
org_unet_device = unet.device
# We cannot move DiT to CPU because of block swap, so only move VAE
logger.info("move vae to cpu to save memory")
org_vae_device = vae.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
logger.info("move text encoder to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[0].to(accelerator.device)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@@ -229,59 +204,52 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, text_encoders, tokens_and_masks
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
if caption_dropout_rate > 0.0:
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
accelerator.wait_for_everyone()
# move text encoder back to cpu
logger.info("move text encoder back to cpu")
text_encoders[0].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
logger.info("move vae back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
clean_memory_on_device(accelerator.device)
else:
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
# move text encoder to device for encoding during training/validation
text_encoders[0].to(accelerator.device)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
qwen3_te = te[0] if te is not None else None
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
anima_train_utils.sample_images(
accelerator, args, epoch, global_step, unet, vae, self.vae_scale,
qwen3_te, self.tokenize_strategy, self.text_encoding_strategy,
accelerator,
args,
epoch,
global_step,
unet,
vae,
qwen3_te,
tokenize_strategy,
text_encoding_strategy,
self.sample_prompts_te_outputs,
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = anima_train_utils.FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000, shift=args.discrete_flow_shift
)
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
# images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim
images = images.unsqueeze(2) # (B, C, 1, H, W)
# Ensure scale tensors are on the same device as images
vae_device = images.device
scale = [s.to(vae_device) if isinstance(s, torch.Tensor) else s for s in self.vae_scale]
return vae.encode(images, scale)
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
return vae.encode_pixels_to_latents(images)
def shift_scale_latents(self, args, latents):
# Latents already normalized by vae.encode with scale
@@ -301,13 +269,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_unet,
is_train=True,
):
anima: anima_models.Anima = unet
# Sample noise
noise = torch.randn_like(latents)
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# Gradient checkpointing support
if args.gradient_checkpointing:
@@ -329,147 +300,86 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
# Prepare block swap
if self.is_swapping_blocks:
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
accelerator.unwrap_model(anima).prepare_block_swap_before_forward()
# Call model (LLM adapter runs inside forward for DDP gradient sync)
# Call model
with torch.set_grad_enabled(is_train), accelerator.autocast():
model_pred = unet(
model_pred = anima(
noisy_model_input,
timesteps,
prompt_embeds,
padding_mask=padding_mask,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
# Rectified flow target: noise - latents
target = noise - latents
# Loss weighting
weighting = anima_train_utils.compute_loss_weighting_for_anima(
weighting_scheme=args.weighting_scheme, sigmas=sigmas
)
# Differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
if self.is_swapping_blocks:
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
model_pred_prior = unet(
noisy_model_input[diff_output_pr_indices],
timesteps[diff_output_pr_indices],
prompt_embeds[diff_output_pr_indices],
padding_mask=padding_mask[diff_output_pr_indices],
source_attention_mask=attn_mask[diff_output_pr_indices],
t5_input_ids=t5_input_ids[diff_output_pr_indices],
t5_attn_mask=t5_attn_mask[diff_output_pr_indices],
)
network.set_multiplier(1.0)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, target, timesteps, weighting
def process_batch(
self, batch, text_encoders, unet, network, vae, noise_scheduler,
vae_dtype, weight_dtype, accelerator, args,
text_encoding_strategy, tokenize_strategy,
is_train=True, train_text_encoder=True, train_unet=True,
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""Override base process_batch for 5D video latents (B, C, T, H, W).
"""Override base process_batch for caption dropout with cached text encoder outputs.
Base class assumes 4D (B, C, H, W) for loss.mean([1,2,3]) and weighting broadcast.
Base class now supports 4D and 5D latents, so we only need to handle caption dropout here.
"""
import typing
from library.custom_train_functions import apply_masked_loss
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
latents = self.shift_scale_latents(args, latents)
# Text encoder conditions
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# Apply caption dropout to cached outputs
text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args, accelerator, noise_scheduler, latents, batch,
text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train,
return super().process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train,
train_text_encoder,
train_unet,
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
# Reduce all non-batch dims: (B, C, T, H, W) -> (B,) for 5D, (B, C, H, W) -> (B,) for 4D
reduce_dims = list(range(1, loss.ndim))
loss = loss.mean(reduce_dims)
# Apply weighting after reducing to (B,)
if weighting is not None:
loss = loss * weighting
loss_weights = batch["loss_weights"]
loss = loss * loss_weights
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
@@ -478,12 +388,15 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0)
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@@ -496,20 +409,12 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
dit = unet
dit = accelerator.prepare(dit, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
model = unet
model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
accelerator.unwrap_model(model).prepare_block_swap_before_forward()
return dit
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
# Drop cached text encoder outputs for caption dropout
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
return model
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
@@ -520,6 +425,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser)
parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--unsloth_offload_checkpointing",
action="store_true",
@@ -536,5 +442,8 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
trainer = AnimaNetworkTrainer()
trainer.train(args)

View File

@@ -118,7 +118,7 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
--timestep_sample_method="logit_normal" \
--discrete_flow_shift=3.0 \
--discrete_flow_shift=1.0 \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
@@ -162,7 +162,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
* `--timestep_sample_method=<choice>`
- Timestep sampling method. Choose from `logit_normal` (default) or `uniform`.
* `--discrete_flow_shift=<float>`
- Shift for the timestep distribution in Rectified Flow training. Default `3.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`. 1.0 means no shift.
* `--sigmoid_scale=<float>`
- Scale factor for `logit_normal` timestep sampling. Default `1.0`.
* `--qwen3_max_token_length=<integer>`

View File

@@ -13,11 +13,10 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from library import custom_offloading_utils
from library import custom_offloading_utils, attention
from library.device_utils import clean_memory_on_device
def to_device(x, device):
if isinstance(x, torch.Tensor):
return x.to(device)
@@ -39,11 +38,13 @@ def to_cpu(x):
else:
return x
# Unsloth Offloaded Gradient Checkpointing
# Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable
except ImportError:
def detach_variable(inputs, device=None):
"""Detach tensors from computation graph, optionally moving to a device.
@@ -80,11 +81,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
"""
@staticmethod
@torch.amp.custom_fwd(device_type='cuda')
@torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, forward_function, hidden_states, *args):
# Remember the original device for backward pass (multi-GPU support)
ctx.input_device = hidden_states.device
saved_hidden_states = hidden_states.to('cpu', non_blocking=True)
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
@@ -96,7 +97,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
return output
@staticmethod
@torch.amp.custom_bwd(device_type='cuda')
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, *grads):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach()
@@ -108,8 +109,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
output_tensors = []
grad_tensors = []
for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,),
grads if isinstance(grads, tuple) else (grads,)):
for out, grad in zip(
outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,)
):
if isinstance(out, torch.Tensor) and out.requires_grad:
output_tensors.append(out)
grad_tensors.append(grad)
@@ -123,24 +125,24 @@ def unsloth_checkpoint(function, *args):
return UnslothOffloadedGradientCheckpointer.apply(function, *args)
# Flash Attention support
try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
FLASH_ATTN_AVAILABLE = True
except ImportError:
_flash_attn_func = None
FLASH_ATTN_AVAILABLE = False
# # Flash Attention support
# try:
# from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
# FLASH_ATTN_AVAILABLE = True
# except ImportError:
# _flash_attn_func = None
# FLASH_ATTN_AVAILABLE = False
def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using Flash Attention.
# def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
# """Computes multi-head attention using Flash Attention.
Input format: (batch, seq_len, n_heads, head_dim)
Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
"""
# flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
return rearrange(out, "b s h d -> b s (h d)")
# Input format: (batch, seq_len, n_heads, head_dim)
# Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
# """
# # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
# out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
# return rearrange(out, "b s h d -> b s (h d)")
from .utils import setup_logging
@@ -174,14 +176,10 @@ def _apply_rotary_pos_emb_base(
if start_positions is not None:
max_offset = torch.max(start_positions)
assert (
max_offset + cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
@@ -205,13 +203,9 @@ def apply_rotary_pos_emb(
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
) -> torch.Tensor:
assert not (
cp_size > 1 and start_positions is not None
), "start_positions != None with CP SIZE > 1 is not supported!"
assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!"
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'."
assert fused == False
@@ -223,9 +217,7 @@ def apply_rotary_pos_emb(
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
freqs,
start_positions=(
start_positions[idx : idx + 1] if start_positions is not None else None
),
start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None),
interleaved=interleaved,
)
for idx, x in enumerate(torch.split(t, seqlens))
@@ -262,7 +254,7 @@ class RMSNorm(torch.nn.Module):
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.amp.autocast(device_type='cuda', dtype=torch.float32)
@torch.amp.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
@@ -308,9 +300,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)"
)
result_B_S_HD = rearrange(F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)")
return result_B_S_HD
@@ -399,18 +389,23 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
# def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# result = self.attn_op(q, k, v) # [B, S, H, D]
# return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
attn_params: attention.AttentionParams,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
# return self.compute_attention(q, k, v)
qkv = [q, k, v]
del q, k, v
result = attention.attention(qkv, attn_params=attn_params)
return self.output_dropout(self.output_proj(result))
# Positional Embeddings
@@ -484,12 +479,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
dim_t = self._dim_t
self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device)
self.dim_spatial_range = (
torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
)
self.dim_temporal_range = (
torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
)
self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
def generate_embeddings(
self,
@@ -679,9 +670,7 @@ class FourierFeatures(nn.Module):
def reset_parameters(self) -> None:
generator = torch.Generator()
generator.manual_seed(0)
self.freqs = (
2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
)
self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
@@ -713,9 +702,7 @@ class PatchEmbed(nn.Module):
m=spatial_patch_size,
n=spatial_patch_size,
),
nn.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False
),
nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
@@ -765,9 +752,7 @@ class FinalLayer(nn.Module):
nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
)
self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False))
self.init_weights()
@@ -790,9 +775,9 @@ class FinalLayer(nn.Module):
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
2, dim=-1
)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
@@ -833,7 +818,11 @@ class Block(nn.Module):
self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd",
x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_format="bshd",
)
self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
@@ -904,6 +893,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -919,13 +909,13 @@ class Block(nn.Module):
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
3, dim=-1
)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
3, dim=-1
)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
@@ -954,11 +944,14 @@ class Block(nn.Module):
result = rearrange(
self.self_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result
@@ -967,11 +960,14 @@ class Block(nn.Module):
result = rearrange(
self.cross_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -987,6 +983,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -996,8 +993,13 @@ class Block(nn.Module):
# Unsloth: async non-blocking CPU RAM offload (fastest offload method)
return unsloth_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
)
elif self.cpu_offload_checkpointing:
# Standard cpu offload: blocking transfers
@@ -1008,36 +1010,54 @@ class Block(nn.Module):
device_inputs = to_device(inputs, device)
outputs = func(*device_inputs)
return to_cpu(outputs)
return custom_forward
return torch_checkpoint(
create_custom_forward(self._forward),
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False,
)
else:
# Standard gradient checkpointing (no offload)
return torch_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False,
)
else:
return self._forward(
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
)
# Main DiT Model: MiniTrainDIT
class MiniTrainDIT(nn.Module):
# Main DiT Model: MiniTrainDIT (renamed to Anima)
class Anima(nn.Module):
"""Cosmos-Predict2 DiT model for image/video generation.
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
"""
LATENT_CHANNELS = 16
def __init__(
self,
max_img_h: int,
@@ -1069,6 +1089,8 @@ class MiniTrainDIT(nn.Module):
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
use_llm_adapter: bool = False,
attn_mode: str = "torch",
split_attn: bool = False,
) -> None:
super().__init__()
self.max_img_h = max_img_h
@@ -1097,6 +1119,9 @@ class MiniTrainDIT(nn.Module):
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.use_llm_adapter = use_llm_adapter
self.attn_mode = attn_mode
self.split_attn = split_attn
# Block swap support
self.blocks_to_swap = None
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
@@ -1156,7 +1181,6 @@ class MiniTrainDIT(nn.Module):
self.final_layer.init_weights()
self.t_embedding_norm.reset_parameters()
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
for block in self.blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload)
@@ -1169,18 +1193,21 @@ class MiniTrainDIT(nn.Module):
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
def set_flash_attn(self, use_flash_attn: bool):
"""Toggle flash attention for all DiT blocks (self-attn + cross-attn).
# def set_flash_attn(self, use_flash_attn: bool):
# """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
"""
if use_flash_attn and not FLASH_ATTN_AVAILABLE:
raise ImportError("flash_attn package is required for --flash_attn but is not installed")
attn_op = flash_attention_op if use_flash_attn else torch_attention_op
for block in self.blocks:
block.self_attn.attn_op = attn_op
block.cross_attn.attn_op = attn_op
# LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
# """
# if use_flash_attn and not FLASH_ATTN_AVAILABLE:
# raise ImportError("flash_attn package is required for --flash_attn but is not installed")
# attn_op = flash_attention_op if use_flash_attn else torch_attention_op
# for block in self.blocks:
# block.self_attn.attn_op = attn_op
# block.cross_attn.attn_op = attn_op
def build_patch_embed(self) -> None:
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
@@ -1232,9 +1259,7 @@ class MiniTrainDIT(nn.Module):
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
@@ -1258,7 +1283,6 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
@@ -1266,9 +1290,7 @@ class MiniTrainDIT(nn.Module):
self.blocks_to_swap <= self.num_blocks - 2
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.blocks, self.blocks_to_swap, device
)
self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device)
logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
def move_to_device_except_swap_blocks(self, device: torch.device):
@@ -1287,7 +1309,7 @@ class MiniTrainDIT(nn.Module):
return
self.offloader.prepare_block_devices_before_forward(self.blocks)
def forward(
def forward_mini_train_dit(
self,
x_B_C_T_H_W: torch.Tensor,
timesteps_B_T: torch.Tensor,
@@ -1310,7 +1332,7 @@ class MiniTrainDIT(nn.Module):
t5_attn_mask: Optional T5 attention mask
"""
# Run LLM adapter inside forward for correct DDP gradient synchronization
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'):
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"):
crossattn_emb = self.llm_adapter(
source_hidden_states=crossattn_emb,
target_input_ids=t5_input_ids,
@@ -1337,6 +1359,8 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb,
}
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
@@ -1345,6 +1369,7 @@ class MiniTrainDIT(nn.Module):
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
attn_params,
**block_kwargs,
)
@@ -1355,6 +1380,41 @@ class MiniTrainDIT(nn.Module):
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
target_input_ids: Optional[torch.Tensor] = None,
target_attention_mask: Optional[torch.Tensor] = None,
source_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
return self.forward_mini_train_dit(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
def _preprocess_text_embeds(
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
if target_input_ids is not None:
# print(
# f"Source hidden states shape: {source_hidden_states.shape},sum of attention mask: {torch.sum(source_attention_mask)}"
# )
# print(f"non zero source_hidden_states before LLM Adapter: {torch.sum(source_hidden_states != 0)}")
context = self.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
# print(f"LLM Adapter output context: {context.shape}, {torch.isnan(context).sum()}")
return context
else:
return source_hidden_states
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
class LLMAdapterRMSNorm(nn.Module):
@@ -1485,24 +1545,37 @@ class LLMAdapterTransformerBlock(nn.Module):
self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.mlp = nn.Sequential(
nn.Linear(model_dim, int(model_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(model_dim * mlp_ratio), model_dim)
nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None,
position_embeddings=None, position_embeddings_context=None):
def forward(
self,
x,
context,
target_attention_mask=None,
source_attention_mask=None,
position_embeddings=None,
position_embeddings_context=None,
):
if self.has_self_attn:
# Self-attention: target_attention_mask is not expected to be all zeros
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings)
attn_out = self.self_attn(
normed,
mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings,
)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
attn_out = self.cross_attn(
normed,
mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
@@ -1518,8 +1591,9 @@ class LLMAdapter(nn.Module):
Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states.
"""
def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16,
embed=None, self_attn=False, layer_norm=False):
def __init__(
self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False
):
super().__init__()
if embed is not None:
self.embed = nn.Embedding.from_pretrained(embed.weight)
@@ -1530,11 +1604,12 @@ class LLMAdapter(nn.Module):
else:
self.in_proj = nn.Identity()
self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads)
self.blocks = nn.ModuleList([
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads,
self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers)
])
self.blocks = nn.ModuleList(
[
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers)
]
)
self.out_proj = nn.Linear(model_dim, target_dim)
self.norm = LLMAdapterRMSNorm(target_dim)
@@ -1556,10 +1631,14 @@ class LLMAdapter(nn.Module):
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
x = block(
x,
context,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
return self.norm(self.out_proj(x))
@@ -1567,30 +1646,60 @@ class LLMAdapter(nn.Module):
# VAE normalization constants
ANIMA_VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
]
ANIMA_VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
]
# DiT config detection from state_dict
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
def get_dit_config(state_dict, key_prefix=''):
def get_dit_config(state_dict, key_prefix=""):
"""Derive DiT configuration from state_dict weight shapes."""
dit_config = {}
dit_config["max_img_h"] = 512
dit_config["max_img_w"] = 512
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int(
concat_padding_mask
)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device
from library import anima_models, anima_utils, strategy_base, train_util, qwen_image_autoencoder_kl
init_ipex()
@@ -25,29 +26,14 @@ import logging
logger = logging.getLogger(__name__)
from library import anima_models, anima_utils, strategy_base, train_util
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
# Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path to Anima DiT model safetensors file",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path to WanVAE safetensors/pth file",
)
parser.add_argument(
"--qwen3_path",
"--qwen3",
type=str,
default=None,
help="Path to Qwen3-0.6B model (safetensors file or directory)",
@@ -86,7 +72,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
"--mod_lr",
type=float,
default=None,
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
)
parser.add_argument(
"--t5_tokenizer_path",
@@ -113,34 +99,29 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
help="Timestep distribution shift for rectified flow training (default: 1.0)",
)
parser.add_argument(
"--timestep_sample_method",
"--timestep_sampling",
type=str,
default="logit_normal",
choices=["logit_normal", "uniform"],
help="Timestep sampling method (default: logit_normal)",
default="sigmoid",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
help="Timestep sampling method (default: sigmoid (logit normal))",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help="Scale factor for logit_normal timestep sampling (default: 1.0)",
help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
)
# Note: --caption_dropout_rate is defined by base add_dataset_arguments().
# Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
# instead of dataset-level caption dropout, so the subset caption_dropout_rate
# is zeroed out in the training scripts to allow caching.
parser.add_argument(
"--transformer_dtype",
type=str,
"--attn_mode",
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None,
choices=["float16", "bfloat16", "float32", None],
help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
" / 使用するAttentionの実装。デフォルトはNonetorchです。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません推論のみ。このオプションは--xformersまたは--sdpaを上書きします。",
)
parser.add_argument(
"--flash_attn",
"--split_attn",
action="store_true",
help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
"Falls back to PyTorch SDPA if flash-attn is not installed.",
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
)
@@ -169,20 +150,20 @@ def get_noisy_model_input_and_timesteps(
"""
bs = latents.shape[0]
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
shift = getattr(args, 'discrete_flow_shift', 1.0)
timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal")
sigmoid_scale = getattr(args, "sigmoid_scale", 1.0)
shift = getattr(args, "discrete_flow_shift", 1.0)
if timestep_sample_method == 'logit_normal':
if timestep_sample_method == "logit_normal":
dist = torch.distributions.normal.Normal(0, 1)
elif timestep_sample_method == 'uniform':
elif timestep_sample_method == "uniform":
dist = torch.distributions.uniform.Uniform(0, 1)
else:
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
t = dist.sample((bs,)).to(device)
if timestep_sample_method == 'logit_normal':
if timestep_sample_method == "logit_normal":
t = t * sigmoid_scale
t = torch.sigmoid(t)
@@ -196,10 +177,10 @@ def get_noisy_model_input_and_timesteps(
# Create noisy input: (1 - t) * latents + t * noise
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
ip_noise_gamma = getattr(args, "ip_noise_gamma", None)
if ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if getattr(args, 'ip_noise_gamma_random_strength', False):
if getattr(args, "ip_noise_gamma_random_strength", False):
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
else:
@@ -213,10 +194,11 @@ def get_noisy_model_input_and_timesteps(
# Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
Same schemes as SD3 but can add Anima-specific ones.
Same schemes as SD3 but can add Anima-specific ones if needed in future.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
@@ -243,7 +225,7 @@ def get_anima_param_groups(
"""Create parameter groups for Anima training with separate learning rates.
Args:
dit: MiniTrainDIT model
dit: Anima model
base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers
@@ -276,15 +258,15 @@ def get_anima_param_groups(
# Store original name for debugging
p.original_name = name
if 'llm_adapter' in name:
if "llm_adapter" in name:
llm_adapter_params.append(p)
elif '.self_attn' in name:
elif ".self_attn" in name:
self_attn_params.append(p)
elif '.cross_attn' in name:
elif ".cross_attn" in name:
cross_attn_params.append(p)
elif '.mlp' in name:
elif ".mlp" in name:
mlp_params.append(p)
elif '.adaln_modulation' in name:
elif ".adaln_modulation" in name:
mod_params.append(p)
else:
base_params.append(p)
@@ -311,9 +293,9 @@ def get_anima_param_groups(
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
param_groups.append({'params': params, 'lr': lr})
param_groups.append({"params": params, "lr": lr})
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
@@ -325,13 +307,12 @@ def save_anima_model_on_train_end(
save_dtype: torch.dtype,
epoch: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
):
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -347,13 +328,12 @@ def save_anima_model_on_epoch_end_or_stepwise(
epoch: int,
num_train_epochs: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
):
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -376,12 +356,13 @@ def do_sample(
height: int,
width: int,
seed: Optional[int],
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
crossattn_emb: torch.Tensor,
steps: int,
dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
flow_shift: float = 3.0,
neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Generate a sample using Euler discrete sampling for rectified flow.
@@ -389,12 +370,13 @@ def do_sample(
Args:
height, width: Output image dimensions
seed: Random seed (None for random)
dit: MiniTrainDIT model
dit: Anima model
crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps
dtype: Compute dtype
device: Compute device
guidance_scale: CFG scale (1.0 = no guidance)
flow_shift: Flow shift parameter for rectified flow
neg_crossattn_emb: Negative cross-attention embeddings for CFG
Returns:
@@ -410,12 +392,13 @@ def do_sample(
generator = torch.manual_seed(seed)
else:
generator = None
noise = torch.randn(
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
).to(dtype).to(device)
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
flow_shift = float(flow_shift)
if flow_shift != 1.0:
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
# Start from pure noise
x = noise.clone()
@@ -463,7 +446,6 @@ def sample_images(
steps,
dit,
vae,
vae_scale,
text_encoder,
tokenize_strategy,
text_encoding_strategy,
@@ -512,10 +494,19 @@ def sample_images(
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
_sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
# Restore RNG state
@@ -527,10 +518,19 @@ def sample_images(
def _sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
@@ -540,6 +540,7 @@ def _sample_image_inference(
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
flow_shift = prompt_dict.get("flow_shift", 3.0)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
@@ -553,7 +554,9 @@ def _sample_image_inference(
height = max(64, height - height % 16)
width = max(64, width - width % 16)
logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
logger.info(
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
)
# Encode prompt
def encode_prompt(prpt):
@@ -579,13 +582,13 @@ def _sample_image_inference(
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter:
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
@@ -608,12 +611,12 @@ def _sample_image_inference(
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
neg_am = neg_am.to(accelerator.device)
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter:
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
@@ -627,16 +630,14 @@ def _sample_image_inference(
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
height, width, seed, dit, crossattn_emb,
sample_steps, dit.t_embedding_norm.weight.dtype,
accelerator.device, scale, neg_crossattn_emb,
height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
)
# Decode latents
clean_memory_on_device(accelerator.device)
org_vae_device = next(vae.parameters()).device
org_vae_device = vae.device
vae.to(accelerator.device)
decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
decoded = vae.decode_to_pixels(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
@@ -662,4 +663,5 @@ def _sample_image_inference(
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)

View File

@@ -6,7 +6,12 @@ import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
from accelerate import init_empty_weights
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library import anima_models
from library.safetensors_utils import WeightTransformHooks
from .utils import setup_logging
setup_logging()
@@ -14,11 +19,9 @@ import logging
logger = logging.getLogger(__name__)
from library import anima_models
# Keys that should stay in high precision (float32/bfloat16, not quantized)
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
@@ -36,8 +39,8 @@ def load_anima_dit(
transformer_dtype: Optional[torch.dtype] = None,
llm_adapter_path: Optional[str] = None,
disable_mmap: bool = False,
) -> anima_models.MiniTrainDIT:
"""Load the MiniTrainDIT model from safetensors.
) -> anima_models.Anima:
"""Load the Anima model from safetensors.
Args:
dit_path: Path to DiT safetensors file
@@ -53,6 +56,7 @@ def load_anima_dit(
logger.info(f"Loading Anima DiT from {dit_path}")
if disable_mmap:
from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap
state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True)
else:
state_dict = load_file(dit_path, device="cpu")
@@ -60,8 +64,8 @@ def load_anima_dit(
# Remove 'net.' prefix if present
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('net.'):
k = k[len('net.'):]
if k.startswith("net."):
k = k[len("net.") :]
new_state_dict[k] = v
state_dict = new_state_dict
@@ -71,21 +75,23 @@ def load_anima_dit(
# Detect LLM adapter
if llm_adapter_path is not None:
use_llm_adapter = True
dit_config['use_llm_adapter'] = True
dit_config["use_llm_adapter"] = True
llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu")
elif 'llm_adapter.out_proj.weight' in state_dict:
elif "llm_adapter.out_proj.weight" in state_dict:
use_llm_adapter = True
dit_config['use_llm_adapter'] = True
dit_config["use_llm_adapter"] = True
llm_adapter_state_dict = None # Loaded as part of DiT
else:
use_llm_adapter = False
llm_adapter_state_dict = None
logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}")
logger.info(
f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}"
)
# Build model normally on CPU — buffers get proper values from __init__
dit = anima_models.MiniTrainDIT(**dit_config)
dit = anima_models.Anima(**dit_config)
# Merge LLM adapter weights into state_dict if loaded separately
if use_llm_adapter and llm_adapter_state_dict is not None:
@@ -96,9 +102,11 @@ def load_anima_dit(
missing, unexpected = dit.load_state_dict(state_dict, strict=False)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [k for k in missing if not any(
buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq')
)]
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}")
if unexpected:
@@ -106,9 +114,7 @@ def load_anima_dit(
# Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest)
for name, p in dit.named_parameters():
dtype_to_use = dtype if (
any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1
) else transformer_dtype
dtype_to_use = dtype if (any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1) else transformer_dtype
p.data = p.data.to(dtype=dtype_to_use)
dit.to(device)
@@ -116,6 +122,128 @@ def load_anima_dit(
return dit
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
# ".embed." excludes Embedding in LLMAdapter
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
def load_anima_model(
device: Union[str, torch.device],
dit_path: str,
attn_mode: str,
split_attn: bool,
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
"""
Load a HunyuanImage model from the specified checkpoint.
Args:
device (Union[str, torch.device]): Device for optimization or merging
dit_path (str): Path to the DiT model checkpoint.
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
split_attn (bool): Whether to use split attention.
loading_device (Union[str, torch.device]): Device to load the model weights on.
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
# dit_weight_dtype is None for fp8_scaled
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
device = torch.device(device)
loading_device = torch.device(loading_device)
# We currently support fixed DiT config for Anima models
dit_config = {
"max_img_h": 512,
"max_img_w": 512,
"max_frames": 128,
"in_channels": 16,
"out_channels": 16,
"patch_spatial": 2,
"patch_temporal": 1,
"model_channels": 2048,
"concat_padding_mask": True,
"crossattn_emb_channels": 1024,
"pos_emb_cls": "rope3d",
"pos_emb_learnable": True,
"pos_emb_interpolation": "crop",
"min_fps": 1,
"max_fps": 30,
"use_adaln_lora": True,
"adaln_lora_dim": 256,
"num_blocks": 28,
"num_heads": 16,
"extra_per_block_abs_pos_emb": False,
"rope_h_extrapolation_ratio": 4.0,
"rope_w_extrapolation_ratio": 4.0,
"rope_t_extrapolation_ratio": 1.0,
"extra_h_extrapolation_ratio": 1.0,
"extra_w_extrapolation_ratio": 1.0,
"extra_t_extrapolation_ratio": 1.0,
"rope_enable_fps_modulation": False,
"use_llm_adapter": True,
"attn_mode": attn_mode,
"split_attn": split_attn,
}
with init_empty_weights():
model = anima_models.Anima(**dit_config)
if dit_weight_dtype is not None:
model.to(dit_weight_dtype)
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
weight_transform_hooks=rename_hooks,
)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu":
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
# Raise error to avoid silent failures
raise RuntimeError(
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
)
missing = {} # all missing keys were expected
if unexpected:
# Raise error to avoid silent failures
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
return model
def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"):
"""Load WanVAE from a safetensors/pth file.
@@ -139,14 +267,14 @@ def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: st
from library.anima_vae import WanVAE_
# Build model
with torch.device('meta'):
with torch.device("meta"):
vae = WanVAE_(**vae_config)
# Load state dict
if vae_path.endswith('.safetensors'):
vae_sd = load_file(vae_path, device='cpu')
if vae_path.endswith(".safetensors"):
vae_sd = load_file(vae_path, device="cpu")
else:
vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True)
vae_sd = torch.load(vae_path, map_location="cpu", weights_only=True)
vae.load_state_dict(vae_sd, assign=True)
vae = vae.eval().requires_grad_(False).to(device, dtype=dtype)
@@ -175,7 +303,7 @@ def load_qwen3_tokenizer(qwen3_path: str):
if os.path.isdir(qwen3_path):
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
else:
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@@ -209,12 +337,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
if os.path.isdir(qwen3_path):
# Directory with full model
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
qwen3_path, torch_dtype=dtype, local_files_only=True
).model
model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
else:
# Single safetensors file - use configs/qwen3_06b/ for config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@@ -227,16 +353,16 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
model = transformers.Qwen3ForCausalLM(qwen3_config).model
# Load weights
if qwen3_path.endswith('.safetensors'):
state_dict = load_file(qwen3_path, device='cpu')
if qwen3_path.endswith(".safetensors"):
state_dict = load_file(qwen3_path, device="cpu")
else:
state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True)
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present
new_sd = {}
for k, v in state_dict.items():
if k.startswith('model.'):
new_sd[k[len('model.'):]] = v
if k.startswith("model."):
new_sd[k[len("model.") :]] = v
else:
new_sd[k] = v
@@ -265,11 +391,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
# Use bundled config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
if os.path.exists(config_dir):
return T5TokenizerFast(
vocab_file=os.path.join(config_dir, 'spiece.model'),
tokenizer_file=os.path.join(config_dir, 'tokenizer.json'),
vocab_file=os.path.join(config_dir, "spiece.model"),
tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
)
raise FileNotFoundError(
@@ -291,9 +417,9 @@ def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dt
for k, v in dit_state_dict.items():
if dtype is not None:
v = v.to(dtype)
prefixed_sd['net.' + k] = v.contiguous()
prefixed_sd["net." + k] = v.contiguous()
save_file(prefixed_sd, save_path, metadata={'format': 'pt'})
save_file(prefixed_sd, save_path, metadata={"format": "pt"})
logger.info(f"Saved Anima model to {save_path}")

View File

@@ -16,8 +16,7 @@ class CausalConv3d(nn.Conv3d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
@@ -41,12 +40,10 @@ class RMS_norm(nn.Module):
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
@@ -61,65 +58,48 @@ class Upsample(nn.Upsample):
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
if mode == "upsample2d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
@@ -127,15 +107,14 @@ class Resample(nn.Module):
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == 'downsample3d':
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
@@ -144,8 +123,7 @@ class Resample(nn.Module):
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
@@ -166,8 +144,8 @@ class Resample(nn.Module):
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
@@ -181,12 +159,15 @@ class ResidualBlock(nn.Module):
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
@@ -196,11 +177,7 @@ class ResidualBlock(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -229,13 +206,10 @@ class AttentionBlock(nn.Module):
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
-1).permute(0, 1, 3,
2).contiguous().chunk(
3, dim=-1)
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
@@ -247,20 +221,22 @@ class AttentionBlock(nn.Module):
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
@@ -288,21 +264,18 @@ class Encoder3d(nn.Module):
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
@@ -310,11 +283,7 @@ class Encoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -342,11 +311,7 @@ class Encoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -357,14 +322,16 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
@@ -375,15 +342,15 @@ class Decoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
@@ -399,15 +366,13 @@ class Decoder3d(nn.Module):
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
@@ -416,11 +381,7 @@ class Decoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -448,11 +409,7 @@ class Decoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -471,14 +428,16 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
@@ -489,12 +448,10 @@ class WanVAE_(nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
@@ -510,20 +467,15 @@ class WanVAE_(nn.Module):
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
@@ -533,8 +485,7 @@ class WanVAE_(nn.Module):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
@@ -542,15 +493,9 @@ class WanVAE_(nn.Module):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
@@ -571,7 +516,7 @@ class WanVAE_(nn.Module):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)

View File

@@ -9,7 +9,7 @@ import logging
from tqdm import tqdm
from library.device_utils import clean_memory_on_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
from library.utils import setup_logging
setup_logging()
@@ -220,6 +220,8 @@ def quantize_weight(
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
# print(f"Optimizing {key} with scale: {scale}")
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook=None,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
Returns:
dict: FP8 optimized state dict
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
# Process each file
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
value = value.to(calc_device)
original_dtype = value.dtype
if original_dtype.itemsize == 1:
raise ValueError(
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
)
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:

View File

@@ -5,7 +5,7 @@ import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging
setup_logging()
@@ -44,7 +44,7 @@ def filter_lora_state_dict(
def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]],
lora_weights_list: Optional[Dict[str, torch.Tensor]],
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
lora_multipliers: Optional[List[float]],
fp8_optimization: bool,
calc_device: torch.device,
@@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
"""
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
@@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
extended_model_files = []
for model_file in model_files:
basename = os.path.basename(model_file)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
split_filenames = get_split_weight_filenames(model_file)
if split_filenames is not None:
extended_model_files.extend(split_filenames)
else:
extended_model_files.append(model_file)
model_files = extended_model_files
@@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
@@ -145,6 +139,13 @@ def load_safetensors_with_lora_and_fp8(
down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device)
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
# temporarily convert to float16 for calculation
model_weight = model_weight.to(torch.float16)
down_weight = down_weight.to(torch.float16)
up_weight = up_weight.to(torch.float16)
# W <- W + U * D
if len(model_weight.size()) == 2:
# linear
@@ -166,6 +167,9 @@ def load_safetensors_with_lora_and_fp8(
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(original_dtype) # convert back to original dtype
# remove LoRA keys from set
lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key)
@@ -187,6 +191,8 @@ def load_safetensors_with_lora_and_fp8(
target_keys,
exclude_keys,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
for lora_weight_keys in list_of_lora_weight_keys:
@@ -208,6 +214,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
weight_hook: callable = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
@@ -218,7 +226,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
# dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization(
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
model_files,
calc_device,
target_keys,
exclude_keys,
move_to_device=move_to_device,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
else:
logger.info(
@@ -226,7 +241,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
import os
import re
import numpy as np
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
validated[key] = value
return validated
# print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
by using memory mapping for large tensors and avoiding unnecessary copies.
"""
def __init__(self, filename):
def __init__(self, filename, disable_numpy_memmap=False):
"""Initialize the SafeTensor reader.
Args:
filename (str): Path to the safetensors file to read.
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
self.disable_numpy_memmap = disable_numpy_memmap
def __enter__(self):
"""Enter context manager."""
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
path: str,
device: Union[str, torch.device],
disable_mmap: bool = False,
dtype: Optional[torch.dtype] = None,
disable_numpy_memmap: bool = False,
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
@@ -293,7 +302,7 @@ def load_safetensors(
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f:
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
@@ -309,6 +318,29 @@ def load_safetensors(
return state_dict
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
"""
Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
Returns None if the file is not split.
"""
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
filenames = []
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
filenames.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
return filenames
else:
return None
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
@@ -319,19 +351,11 @@ def load_split_weights(
device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
split_filenames = get_split_weight_filenames(file_path)
if split_filenames is not None:
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
for filename in split_filenames:
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
@@ -349,3 +373,107 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None
@dataclass
class WeightTransformHooks:
split_hook: Optional[callable] = None
concat_hook: Optional[callable] = None
rename_hook: Optional[callable] = None
class TensorWeightAdapter:
"""
A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
when loading tensors.
split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
rename_hook: A callable that takes (original_key: str) and returns (new_key: str).
If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.
**concat_hook is not tested yet.**
"""
def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
self.original_f = original_f
self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
{}
) # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
self.concat_key_set = set() # set of concatenated keys
self.split_key_set = set() # set of split keys
self.new_keys = []
self.tensor_cache = {} # cache for split tensors
self.split_hook = weight_convert_hook.split_hook
self.concat_hook = weight_convert_hook.concat_hook
self.rename_hook = weight_convert_hook.rename_hook
for key in self.original_f.keys():
if self.split_hook is not None:
converted_keys, _ = self.split_hook(key, None) # get new keys only
if converted_keys is not None:
for converted_key in converted_keys:
self.new_key_to_original_key_map[converted_key] = key
self.split_key_set.add(converted_key)
self.new_keys.extend(converted_keys)
continue # skip concat_hook if split_hook is applied
if self.concat_hook is not None:
converted_key, _ = self.concat_hook(key, None) # get new key only
if converted_key is not None:
if converted_key not in self.concat_key_set: # first time seeing this concatenated key
self.concat_key_set.add(converted_key)
self.new_key_to_original_key_map[converted_key] = []
# multiple original keys map to the same concatenated key
self.new_key_to_original_key_map[converted_key].append(key)
self.new_keys.append(converted_key)
continue # skip to next key
# direct mapping
if self.rename_hook is not None:
new_key = self.rename_hook(key)
self.new_key_to_original_key_map[new_key] = key
else:
new_key = key
self.new_keys.append(new_key)
def keys(self) -> list[str]:
return self.new_keys
def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# load tensor by new_key, applying split or concat hooks as needed
if new_key not in self.new_key_to_original_key_map:
# direct mapping
return self.original_f.get_tensor(new_key, device=device, dtype=dtype)
elif new_key in self.split_key_set:
# split hook: split key is requested multiple times, so we cache the result
original_key = self.new_key_to_original_key_map[new_key]
if original_key not in self.tensor_cache: # not yet split
original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
new_keys, new_tensors = self.split_hook(original_key, original_tensor) # apply split hook
for k, t in zip(new_keys, new_tensors):
self.tensor_cache[k] = t
return self.tensor_cache.pop(new_key) # return and remove from cache
elif new_key in self.concat_key_set:
# concat hook: concatenated key is requested only once, so we do not cache the result
tensors = {}
for original_key in self.new_key_to_original_key_map[new_key]:
tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
tensors[original_key] = tensor
_, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors) # apply concat hook
return concatenated_tensors
else:
# direct mapping
original_key = self.new_key_to_original_key_map[new_key]
return self.original_f.get_tensor(original_key, device=device, dtype=dtype)

View File

@@ -9,6 +9,7 @@ import torch
from library import anima_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library import qwen_image_autoencoder_kl
from library.utils import setup_logging
@@ -45,8 +46,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer
self.t5_tokenizer = t5_tokenizer
self.qwen3_max_length = qwen3_max_length
self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
@@ -54,26 +55,17 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
# Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.qwen3_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
)
qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.t5_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
)
t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"]
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
@@ -84,46 +76,11 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
T5 tokens are passed through unchanged (only used by LLM Adapter).
"""
def __init__(
self,
dropout_rate: float = 0.0,
) -> None:
self.dropout_rate = dropout_rate
# Cached unconditional embeddings (from encoding empty caption "")
# Must be initialized via cache_uncond_embeddings() before text encoder is deleted
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
def cache_uncond_embeddings(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
) -> None:
"""Pre-encode empty caption "" and cache the unconditional embeddings.
Must be called before the text encoder is deleted from GPU.
This matches diffusion-pipe-main behavior where empty caption embeddings
are pre-cached and swapped in during caption dropout.
"""
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
self._uncond_attn_mask = uncond_outputs[1].cpu()
self._uncond_t5_input_ids = uncond_outputs[2].cpu()
self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
logger.info(" Unconditional embeddings cached successfully")
def __init__(self) -> None:
super().__init__()
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
enable_dropout: bool = True,
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -134,82 +91,20 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
batch_size = qwen3_input_ids.shape[0]
non_drop_indices = []
for i in range(batch_size):
drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
if not drop:
non_drop_indices.append(i)
encoder_device = qwen3_text_encoder.device
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
prompt_embeds = outputs.last_hidden_state
prompt_embeds[~qwen3_attn_mask.bool()] = 0
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
# Only encode non-dropped items to save compute
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
elif len(non_drop_indices) == batch_size:
nd_input_ids = qwen3_input_ids.to(encoder_device)
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
else:
nd_input_ids = None
nd_attn_mask = None
if nd_input_ids is not None:
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
# Build full batch: fill non-dropped with encoded, dropped with unconditional
if len(non_drop_indices) == batch_size:
prompt_embeds = nd_encoded_text
attn_mask = qwen3_attn_mask.to(encoder_device)
else:
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Encode empty caption on-the-fly (text encoder still available)
uncond_tokens = tokenize_strategy.tokenize("")
uncond_ids = uncond_tokens[0].to(encoder_device)
uncond_mask = uncond_tokens[1].to(encoder_device)
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
uncond_pe = uncond_out.last_hidden_state[0]
uncond_pe[~uncond_mask[0].bool()] = 0
uncond_am = uncond_mask[0]
uncond_t5_ids = uncond_tokens[2][0]
uncond_t5_am = uncond_tokens[3][0]
seq_len = qwen3_input_ids.shape[1]
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
if len(non_drop_indices) > 0:
prompt_embeds[non_drop_indices] = nd_encoded_text
attn_mask[non_drop_indices] = nd_attn_mask
# Fill dropped items with unconditional embeddings
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
for i in drop_indices:
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
@@ -217,6 +112,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor,
caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs.
@@ -224,37 +120,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
if prompt_embeds is not None and self.dropout_rate > 0.0:
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
if caption_dropout_rates is None or all(caption_dropout_rates == 0.0):
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
for i in range(prompt_embeds.shape[0]):
if random.random() < self.dropout_rate:
if self._uncond_prompt_embeds is not None:
# Use pre-cached unconditional embeddings
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if attn_mask is not None:
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
if t5_input_ids is not None:
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None:
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
else:
# Fallback: zero out (should not happen if cache_uncond_embeddings was called)
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
if attn_mask is not None:
attn_mask[i] = torch.zeros_like(attn_mask[i])
if t5_input_ids is not None:
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
for i in range(prompt_embeds.shape[0]):
if random.random() < caption_dropout_rates[i].item():
# Use pre-cached unconditional embeddings
prompt_embeds[i] = 0
if attn_mask is not None:
attn_mask[i] = 0
if t5_input_ids is not None:
t5_input_ids[i, 0] = 1 # Set to </s> token ID
t5_input_ids[i, 1:] = 0
if t5_attn_mask is not None:
t5_attn_mask[i, 0] = 1
t5_attn_mask[i, 1:] = 0
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@@ -297,6 +186,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -309,7 +200,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
caption_dropout_rate = data["caption_dropout_rate"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs(
self,
@@ -323,12 +215,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# Always disable dropout during caching
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
tokenize_strategy,
models,
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, models, tokens_and_masks
)
# Convert to numpy for caching
@@ -344,6 +232,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
@@ -352,9 +241,10 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i)
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -374,18 +264,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.ANIMA_LATENTS_NPZ_SUFFIX
)
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
):
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
@@ -393,32 +275,23 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
"""Cache batch of latents using WanVAE.
"""Cache batch of latents using Qwen Image VAE.
vae is expected to be the WanVAE_ model (not the wrapper).
vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
The encoding function handles the mean/std normalization.
"""
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
vae_device = next(vae.parameters()).device
vae_dtype = next(vae.parameters()).dtype
# Create scale tensors on VAE device
mean = torch.tensor(ANIMA_VAE_MEAN, dtype=vae_dtype, device=vae_device)
std = torch.tensor(ANIMA_VAE_STD, dtype=vae_dtype, device=vae_device)
scale = [mean, 1.0 / std]
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
vae_device = vae.device
vae_dtype = vae.dtype
def encode_by_vae(img_tensor):
"""Encode image tensor to latents.
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
Need to add temporal dim to get (B, C, T=1, H, W) for WanVAE
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
"""
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
img_tensor = img_tensor.unsqueeze(2)
img_tensor = img_tensor.to(vae_device, dtype=vae_dtype)
latents = vae.encode(img_tensor, scale)
latents = vae.encode_pixels_to_latents(img_tensor)
return latents.to("cpu")
self._default_cache_batch_latents(

View File

@@ -179,12 +179,15 @@ def split_train_val(
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
def __init__(
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.caption_dropout_rate: float = caption_dropout_rate
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
@@ -197,7 +200,7 @@ class ImageInfo:
)
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@@ -1096,11 +1099,11 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def is_text_encoder_output_cacheable(self):
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
return all(
[
not (
subset.caption_dropout_rate > 0
subset.caption_dropout_rate > 0 and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2661,8 +2664,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self) -> bool:
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])
def set_current_strategies(self):
for dataset in self.datasets:

View File

@@ -1,18 +1,17 @@
# LoRA network module for Anima
import math
# LoRA network module for Anima
import ast
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
setup_logging()
import logging
setup_logging()
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
def create_network(
multiplier: float,
@@ -29,68 +28,28 @@ def create_network(
if network_alpha is None:
network_alpha = 1.0
# type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
self_attn_dim = kwargs.get("self_attn_dim", None)
cross_attn_dim = kwargs.get("cross_attn_dim", None)
mlp_dim = kwargs.get("mlp_dim", None)
mod_dim = kwargs.get("mod_dim", None)
llm_adapter_dim = kwargs.get("llm_adapter_dim", None)
if self_attn_dim is not None:
self_attn_dim = int(self_attn_dim)
if cross_attn_dim is not None:
cross_attn_dim = int(cross_attn_dim)
if mlp_dim is not None:
mlp_dim = int(mlp_dim)
if mod_dim is not None:
mod_dim = int(mod_dim)
if llm_adapter_dim is not None:
llm_adapter_dim = int(llm_adapter_dim)
type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims: [x_embedder, t_embedder, final_layer]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")]
assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"
# block selection
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start, end = int(start), int(end)
assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
train_block_indices = parse_block_selection(train_block_indices, num_blocks)
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", False)
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None:
train_llm_adapter = True if train_llm_adapter == "True" else False
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
# regular expression for module selection: exclude and include
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
@@ -101,9 +60,43 @@ def create_network(
module_dropout = float(module_dropout)
# verbose
verbose = kwargs.get("verbose", False)
verbose = kwargs.get("verbose", "false")
if verbose is not None:
verbose = True if verbose == "True" else False
verbose = True if verbose.lower() == "true" else False
# regex-specific learning rates / dimensions
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None
network = LoRANetwork(
text_encoders,
@@ -115,9 +108,10 @@ def create_network(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
@@ -137,6 +131,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -173,8 +168,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
class LoRANetwork(torch.nn.Module):
# Target modules: DiT blocks
ANIMA_TARGET_REPLACE_MODULE = ["Block"]
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3)
@@ -197,9 +192,10 @@ class LoRANetwork(torch.nn.Module):
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
reg_dims: Optional[Dict[str, int]] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -210,21 +206,36 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
if self.emb_dims is None:
self.emb_dims = [0] * 3
logger.info("create LoRA network from weights")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expression if specified
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
@@ -232,15 +243,9 @@ class LoRANetwork(torch.nn.Module):
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> Tuple[List[LoRAModule], List[str]]:
prefix = (
self.LORA_PREFIX_ANIMA
if is_unet
else self.LORA_PREFIX_TEXT_ENCODER
)
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
@@ -255,14 +260,16 @@ class LoRANetwork(torch.nn.Module):
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
force_incl_conv2d = False
if filter is not None:
if filter not in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter
# exclude/include filter (fullmatch: pattern must match the entire original_name)
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
@@ -272,43 +279,18 @@ class LoRANetwork(torch.nn.Module):
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
else:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if is_unet and type_dims is not None:
# type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
# Order matters: check most specific identifiers first to avoid mismatches.
identifier_order = [
(4, ("llm_adapter",)),
(3, ("adaln_modulation",)),
(0, ("self_attn",)),
(1, ("cross_attn",)),
(2, ("mlp",)),
]
for idx, ids in identifier_order:
d = type_dims[idx]
if d is not None and all(id_str in lora_name for id_str in ids):
dim = d # 0 means skip
break
# block index filtering
if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
# Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
parts = lora_name.split("_")
for pi, part in enumerate(parts):
if part == "blocks" and pi + 1 < len(parts):
try:
block_index = int(parts[pi + 1])
if not self.train_block_indices[block_index]:
dim = 0
except (ValueError, IndexError):
pass
break
elif force_incl_conv2d:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.fullmatch(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"LoRA {original_name} matched with regex {reg}, using dim: {dim}")
break
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
@@ -325,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
@@ -339,9 +322,7 @@ class LoRANetwork(torch.nn.Module):
if text_encoder is None:
continue
logger.info(f"create LoRA for Text Encoder {i+1}:")
te_loras, te_skipped = create_modules(
False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
@@ -354,19 +335,6 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
# emb_dims: [x_embedder, t_embedder, final_layer]
if self.emb_dims:
for filter_name, in_dim in zip(
["x_embedder", "t_embedder", "final_layer"],
self.emb_dims,
):
loras, _ = create_modules(
True, None, unet, None,
filter=filter_name, default_dim=in_dim,
include_conv2d_if_filter=(filter_name == "x_embedder"),
)
self.unet_loras.extend(loras)
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
@@ -396,6 +364,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -443,10 +412,10 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key]
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
@@ -471,8 +440,29 @@ class LoRANetwork(torch.nn.Module):
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
reg_groups = {}
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.fullmatch(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
for name, param in lora.named_parameters():
if matched_reg_lr is not None:
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
if loraplus_ratio is not None and "lora_up" in name:
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
else:
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
continue
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
@@ -480,6 +470,23 @@ class LoRANetwork(torch.nn.Module):
params = []
descriptions = []
for group_key, group in reg_groups.items():
reg_lr = group["lr"]
for key in ("lora", "plus"):
param_data = {"params": group[key].values()}
if len(param_data["params"]) == 0:
continue
if key == "plus":
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
descriptions.append(desc + (" plus" if key == "plus" else ""))
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
@@ -498,10 +505,7 @@ class LoRANetwork(torch.nn.Module):
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
te1_loras = [
lora for lora in self.text_encoder_loras
if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
]
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)

View File

@@ -2,7 +2,7 @@
Diagnostic script to test Anima latent & text encoder caching independently.
Usage:
python test_anima_cache.py \
python manual_test_anima_cache.py \
--image_dir /path/to/images \
--qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \

View File

@@ -470,7 +470,7 @@ class NetworkTrainer:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights