This commit is contained in:
Kohya S.
2026-02-08 06:36:00 +00:00
committed by GitHub
15 changed files with 3336 additions and 763 deletions

View File

@@ -32,6 +32,7 @@ hime="hime"
OT="OT"
byt="byt"
tak="tak"
temperal="temperal"
[files]
extend-exclude = ["_typos.toml", "venv"]
extend-exclude = ["_typos.toml", "venv", "configs"]

1044
anima_minimal_inference.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -49,35 +49,32 @@ def train(args):
args.skip_cache_check = args.skip_latents_validity_check
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
if getattr(args, 'unsloth_offload_checkpointing', False):
if getattr(args, "unsloth_offload_checkpointing", False):
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not getattr(args, 'unsloth_offload_checkpointing', False), \
"blocks_to_swap is not supported with unsloth_offload_checkpointing"
assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr(
args, "unsloth_offload_checkpointing", False
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability
if getattr(args, 'flash_attn', False):
if getattr(args, "flash_attn", False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
@@ -104,9 +101,7 @@ def train(args):
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0}".format(", ".join(ignored))
)
logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
@@ -150,7 +145,7 @@ def train(args):
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so we save the rate and zero out subset-level
# caption_dropout_rate to allow text encoder output caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
for dataset in train_dataset_group.datasets:
@@ -175,9 +170,7 @@ def train(args):
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used"
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs:
assert (
@@ -193,7 +186,7 @@ def train(args):
# parse transformer_dtype
transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
@@ -203,12 +196,8 @@ def train(args):
# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu"
)
t5_tokenizer = anima_utils.load_t5_tokenizer(
getattr(args, 't5_tokenizer_path', None)
)
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu")
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@@ -220,7 +209,7 @@ def train(args):
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Set text encoding strategy
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
@@ -266,10 +255,8 @@ def train(args):
)
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0.0:
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
accelerator.wait_for_everyone()
@@ -299,17 +286,17 @@ def train(args):
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
llm_adapter_path=getattr(args, "llm_adapter_path", None),
disable_mmap=getattr(args, "disable_mmap_load_safetensors", False),
)
if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False),
)
if getattr(args, 'flash_attn', False):
if getattr(args, "flash_attn", False):
dit.set_flash_attn(True)
train_dit = args.learning_rate != 0
@@ -335,11 +322,11 @@ def train(args):
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
self_attn_lr=getattr(args, 'self_attn_lr', None),
cross_attn_lr=getattr(args, 'cross_attn_lr', None),
mlp_lr=getattr(args, 'mlp_lr', None),
mod_lr=getattr(args, 'mod_lr', None),
llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
self_attn_lr=getattr(args, "self_attn_lr", None),
cross_attn_lr=getattr(args, "cross_attn_lr", None),
mlp_lr=getattr(args, "mlp_lr", None),
mod_lr=getattr(args, "mod_lr", None),
llm_adapter_lr=getattr(args, "llm_adapter_lr", None),
)
else:
param_groups = []
@@ -366,8 +353,8 @@ def train(args):
# Build param_id → lr mapping from param_groups to propagate per-component LRs
param_lr_map = {}
for group in param_groups:
for p in group['params']:
param_lr_map[id(p)] = group['lr']
for p in group["params"]:
param_lr_map[id(p)] = group["lr"]
grouped_params = []
param_group = {}
@@ -557,9 +544,7 @@ def train(args):
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
accelerator.print(f" num epochs: {num_train_epochs}")
accelerator.print(
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps: {args.max_train_steps}")
@@ -580,6 +565,7 @@ def train(args):
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")
@@ -589,8 +575,16 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, 0, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
0,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
optimizer_train_fn()
@@ -600,7 +594,9 @@ def train(args):
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
logger.info(
f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}"
)
if qwen3_text_encoder is not None:
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
if vae is not None:
@@ -640,9 +636,7 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list
)
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
@@ -678,10 +672,7 @@ def train(args):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
if is_swapping_blocks:
@@ -708,9 +699,7 @@ def train(args):
# Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
loss = train_util.conditional_loss(
model_pred.float(), target.float(), args.loss_type, "none", huber_c
)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
@@ -748,8 +737,16 @@ def train(args):
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, None, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
None,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
@@ -773,8 +770,10 @@ def train(args):
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names(
logs, lr_scheduler, args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
logs,
lr_scheduler,
args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
)
accelerator.log(logs, step=global_step)
@@ -807,8 +806,16 @@ def train(args):
)
anima_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
epoch + 1,
global_step,
dit,
vae,
vae_scale,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)

View File

@@ -39,17 +39,15 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
val_dataset_group: Optional[train_util.DatasetGroup],
):
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so zero out subset-level rates to allow caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
if hasattr(train_dataset_group, 'datasets'):
if hasattr(train_dataset_group, "datasets"):
for dataset in train_dataset_group.datasets:
for subset in dataset.subsets:
subset.caption_dropout_rate = 0.0
@@ -63,26 +61,28 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
if getattr(args, 'unsloth_offload_checkpointing', False):
if getattr(args, "unsloth_offload_checkpointing", False):
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
not args.cpu_offload_checkpointing
), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability
if getattr(args, 'flash_attn', False):
if getattr(args, "flash_attn", False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
args.flash_attn = False
if getattr(args, 'blockwise_fused_optimizers', False):
if getattr(args, "blockwise_fused_optimizers", False):
raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
@@ -92,14 +92,12 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def load_target_model(self, args, weight_dtype, accelerator):
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
logger.info("Loading Qwen3 text encoder...")
self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu"
)
self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu")
self.qwen3_text_encoder.eval()
# Parse transformer_dtype
transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
@@ -114,18 +112,18 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
llm_adapter_path=getattr(args, "llm_adapter_path", None),
disable_mmap=getattr(args, "disable_mmap_load_safetensors", False),
)
# Flash attention
if getattr(args, 'flash_attn', False):
if getattr(args, "flash_attn", False):
dit.set_flash_attn(True)
# Store unsloth preference so that when the base NetworkTrainer calls
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
# The base trainer only passes cpu_offload, so we store the flag on the model.
self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False)
self._use_unsloth_offload_checkpointing = getattr(args, "unsloth_offload_checkpointing", False)
# Block swap
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
@@ -135,9 +133,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
# Load VAE
logger.info("Loading Anima VAE...")
self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
args.vae_path, dtype=weight_dtype, device="cpu"
)
self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [self.qwen3_text_encoder], self.vae, dit
@@ -146,7 +142,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.qwen3_path,
t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None),
t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None),
qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length,
)
@@ -159,12 +155,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
return [tokenize_strategy.qwen3_tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_anima.AnimaLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
def get_text_encoding_strategy(self, args):
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0)
self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
@@ -237,12 +231,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
self.sample_prompts_te_outputs = sample_prompts_te_outputs
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
if caption_dropout_rate > 0.0:
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
accelerator.wait_for_everyone()
@@ -264,8 +256,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
qwen3_te = te[0] if te is not None else None
anima_train_utils.sample_images(
accelerator, args, epoch, global_step, unet, vae, self.vae_scale,
qwen3_te, self.tokenize_strategy, self.text_encoding_strategy,
accelerator,
args,
epoch,
global_step,
unet,
vae,
self.vae_scale,
qwen3_te,
self.tokenize_strategy,
self.text_encoding_strategy,
self.sample_prompts_te_outputs,
)
@@ -329,10 +329,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
# Prepare block swap
if self.is_swapping_blocks:
@@ -354,9 +351,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
target = noise - latents
# Loss weighting
weighting = anima_train_utils.compute_loss_weighting_for_anima(
weighting_scheme=args.weighting_scheme, sigmas=sigmas
)
weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# Differential output preservation
if "custom_attributes" in batch:
@@ -386,10 +381,22 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
return model_pred, target, timesteps, weighting
def process_batch(
self, batch, text_encoders, unet, network, vae, noise_scheduler,
vae_dtype, weight_dtype, accelerator, args,
text_encoding_strategy, tokenize_strategy,
is_train=True, train_text_encoder=True, train_unet=True,
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""Override base process_batch for 5D video latents (B, C, T, H, W).
@@ -424,13 +431,21 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
# Text encoder conditions
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
# Apply caption dropout to cached outputs
text_encoder_conds = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
# TODO stop gradient for uncond embeddings when using caption dropout?
encoded_text_encoder_conds = anima_text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
@@ -441,13 +456,23 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# Fill in only missing parts (partial caching)
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args, accelerator, noise_scheduler, latents, batch,
text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=is_train,
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -479,8 +504,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0)
metadata["ss_timestep_sample_method"] = getattr(args, "timestep_sample_method", "logit_normal")
metadata["ss_sigmoid_scale"] = getattr(args, "sigmoid_scale", 1.0)
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs

View File

@@ -118,7 +118,7 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
--timestep_sample_method="logit_normal" \
--discrete_flow_shift=3.0 \
--discrete_flow_shift=1.0 \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
@@ -162,7 +162,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
* `--timestep_sample_method=<choice>`
- Timestep sampling method. Choose from `logit_normal` (default) or `uniform`.
* `--discrete_flow_shift=<float>`
- Shift for the timestep distribution in Rectified Flow training. Default `3.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`. 1.0 means no shift.
* `--sigmoid_scale=<float>`
- Scale factor for `logit_normal` timestep sampling. Default `1.0`.
* `--qwen3_max_token_length=<integer>`

View File

@@ -13,11 +13,10 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from library import custom_offloading_utils
from library import custom_offloading_utils, attention
from library.device_utils import clean_memory_on_device
def to_device(x, device):
if isinstance(x, torch.Tensor):
return x.to(device)
@@ -39,11 +38,13 @@ def to_cpu(x):
else:
return x
# Unsloth Offloaded Gradient Checkpointing
# Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable
except ImportError:
def detach_variable(inputs, device=None):
"""Detach tensors from computation graph, optionally moving to a device.
@@ -80,11 +81,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
"""
@staticmethod
@torch.amp.custom_fwd(device_type='cuda')
@torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, forward_function, hidden_states, *args):
# Remember the original device for backward pass (multi-GPU support)
ctx.input_device = hidden_states.device
saved_hidden_states = hidden_states.to('cpu', non_blocking=True)
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
@@ -96,7 +97,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
return output
@staticmethod
@torch.amp.custom_bwd(device_type='cuda')
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, *grads):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach()
@@ -108,8 +109,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
output_tensors = []
grad_tensors = []
for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,),
grads if isinstance(grads, tuple) else (grads,)):
for out, grad in zip(
outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,)
):
if isinstance(out, torch.Tensor) and out.requires_grad:
output_tensors.append(out)
grad_tensors.append(grad)
@@ -123,24 +125,24 @@ def unsloth_checkpoint(function, *args):
return UnslothOffloadedGradientCheckpointer.apply(function, *args)
# Flash Attention support
try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
FLASH_ATTN_AVAILABLE = True
except ImportError:
_flash_attn_func = None
FLASH_ATTN_AVAILABLE = False
# # Flash Attention support
# try:
# from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
# FLASH_ATTN_AVAILABLE = True
# except ImportError:
# _flash_attn_func = None
# FLASH_ATTN_AVAILABLE = False
def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using Flash Attention.
# def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
# """Computes multi-head attention using Flash Attention.
Input format: (batch, seq_len, n_heads, head_dim)
Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
"""
# flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
return rearrange(out, "b s h d -> b s (h d)")
# Input format: (batch, seq_len, n_heads, head_dim)
# Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
# """
# # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
# out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
# return rearrange(out, "b s h d -> b s (h d)")
from .utils import setup_logging
@@ -174,14 +176,10 @@ def _apply_rotary_pos_emb_base(
if start_positions is not None:
max_offset = torch.max(start_positions)
assert (
max_offset + cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
@@ -205,13 +203,9 @@ def apply_rotary_pos_emb(
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
) -> torch.Tensor:
assert not (
cp_size > 1 and start_positions is not None
), "start_positions != None with CP SIZE > 1 is not supported!"
assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!"
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'."
assert fused == False
@@ -223,9 +217,7 @@ def apply_rotary_pos_emb(
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
freqs,
start_positions=(
start_positions[idx : idx + 1] if start_positions is not None else None
),
start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None),
interleaved=interleaved,
)
for idx, x in enumerate(torch.split(t, seqlens))
@@ -262,7 +254,7 @@ class RMSNorm(torch.nn.Module):
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.amp.autocast(device_type='cuda', dtype=torch.float32)
@torch.amp.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
@@ -308,9 +300,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)"
)
result_B_S_HD = rearrange(F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)")
return result_B_S_HD
@@ -399,18 +389,23 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
# def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# result = self.attn_op(q, k, v) # [B, S, H, D]
# return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
attn_params: attention.AttentionParams,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
# return self.compute_attention(q, k, v)
qkv = [q, k, v]
del q, k, v
result = attention.attention(qkv, attn_params=attn_params)
return self.output_dropout(self.output_proj(result))
# Positional Embeddings
@@ -484,12 +479,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
dim_t = self._dim_t
self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device)
self.dim_spatial_range = (
torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
)
self.dim_temporal_range = (
torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
)
self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
def generate_embeddings(
self,
@@ -679,9 +670,7 @@ class FourierFeatures(nn.Module):
def reset_parameters(self) -> None:
generator = torch.Generator()
generator.manual_seed(0)
self.freqs = (
2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
)
self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
@@ -713,9 +702,7 @@ class PatchEmbed(nn.Module):
m=spatial_patch_size,
n=spatial_patch_size,
),
nn.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False
),
nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
@@ -765,9 +752,7 @@ class FinalLayer(nn.Module):
nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
)
self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False))
self.init_weights()
@@ -790,9 +775,9 @@ class FinalLayer(nn.Module):
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
2, dim=-1
)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
@@ -833,7 +818,11 @@ class Block(nn.Module):
self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd",
x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_format="bshd",
)
self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
@@ -904,6 +893,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -919,13 +909,13 @@ class Block(nn.Module):
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
3, dim=-1
)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
3, dim=-1
)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
@@ -954,11 +944,14 @@ class Block(nn.Module):
result = rearrange(
self.self_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result
@@ -967,11 +960,14 @@ class Block(nn.Module):
result = rearrange(
self.cross_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -987,6 +983,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -996,8 +993,13 @@ class Block(nn.Module):
# Unsloth: async non-blocking CPU RAM offload (fastest offload method)
return unsloth_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
)
elif self.cpu_offload_checkpointing:
# Standard cpu offload: blocking transfers
@@ -1008,26 +1010,42 @@ class Block(nn.Module):
device_inputs = to_device(inputs, device)
outputs = func(*device_inputs)
return to_cpu(outputs)
return custom_forward
return torch_checkpoint(
create_custom_forward(self._forward),
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False,
)
else:
# Standard gradient checkpointing (no offload)
return torch_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False,
)
else:
return self._forward(
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
)
@@ -1069,6 +1087,8 @@ class MiniTrainDIT(nn.Module):
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
use_llm_adapter: bool = False,
attn_mode: str = "torch",
split_attn: bool = False,
) -> None:
super().__init__()
self.max_img_h = max_img_h
@@ -1097,6 +1117,9 @@ class MiniTrainDIT(nn.Module):
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.use_llm_adapter = use_llm_adapter
self.attn_mode = attn_mode
self.split_attn = split_attn
# Block swap support
self.blocks_to_swap = None
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
@@ -1156,7 +1179,6 @@ class MiniTrainDIT(nn.Module):
self.final_layer.init_weights()
self.t_embedding_norm.reset_parameters()
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
for block in self.blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload)
@@ -1169,18 +1191,17 @@ class MiniTrainDIT(nn.Module):
def device(self):
return next(self.parameters()).device
# def set_flash_attn(self, use_flash_attn: bool):
# """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
def set_flash_attn(self, use_flash_attn: bool):
"""Toggle flash attention for all DiT blocks (self-attn + cross-attn).
LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
"""
if use_flash_attn and not FLASH_ATTN_AVAILABLE:
raise ImportError("flash_attn package is required for --flash_attn but is not installed")
attn_op = flash_attention_op if use_flash_attn else torch_attention_op
for block in self.blocks:
block.self_attn.attn_op = attn_op
block.cross_attn.attn_op = attn_op
# LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
# """
# if use_flash_attn and not FLASH_ATTN_AVAILABLE:
# raise ImportError("flash_attn package is required for --flash_attn but is not installed")
# attn_op = flash_attention_op if use_flash_attn else torch_attention_op
# for block in self.blocks:
# block.self_attn.attn_op = attn_op
# block.cross_attn.attn_op = attn_op
def build_patch_embed(self) -> None:
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
@@ -1232,9 +1253,7 @@ class MiniTrainDIT(nn.Module):
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
@@ -1258,7 +1277,6 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
@@ -1266,9 +1284,7 @@ class MiniTrainDIT(nn.Module):
self.blocks_to_swap <= self.num_blocks - 2
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.blocks, self.blocks_to_swap, device
)
self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device)
logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
def move_to_device_except_swap_blocks(self, device: torch.device):
@@ -1310,7 +1326,7 @@ class MiniTrainDIT(nn.Module):
t5_attn_mask: Optional T5 attention mask
"""
# Run LLM adapter inside forward for correct DDP gradient synchronization
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'):
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"):
crossattn_emb = self.llm_adapter(
source_hidden_states=crossattn_emb,
target_input_ids=t5_input_ids,
@@ -1337,6 +1353,8 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb,
}
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
@@ -1345,6 +1363,7 @@ class MiniTrainDIT(nn.Module):
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
attn_params,
**block_kwargs,
)
@@ -1485,24 +1504,36 @@ class LLMAdapterTransformerBlock(nn.Module):
self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.mlp = nn.Sequential(
nn.Linear(model_dim, int(model_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(model_dim * mlp_ratio), model_dim)
nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None,
position_embeddings=None, position_embeddings_context=None):
def forward(
self,
x,
context,
target_attention_mask=None,
source_attention_mask=None,
position_embeddings=None,
position_embeddings_context=None,
):
if self.has_self_attn:
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings)
attn_out = self.self_attn(
normed,
mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings,
)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
attn_out = self.cross_attn(
normed,
mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
@@ -1518,8 +1549,9 @@ class LLMAdapter(nn.Module):
Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states.
"""
def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16,
embed=None, self_attn=False, layer_norm=False):
def __init__(
self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False
):
super().__init__()
if embed is not None:
self.embed = nn.Embedding.from_pretrained(embed.weight)
@@ -1530,11 +1562,12 @@ class LLMAdapter(nn.Module):
else:
self.in_proj = nn.Identity()
self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads)
self.blocks = nn.ModuleList([
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads,
self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers)
])
self.blocks = nn.ModuleList(
[
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers)
]
)
self.out_proj = nn.Linear(model_dim, target_dim)
self.norm = LLMAdapterRMSNorm(target_dim)
@@ -1556,41 +1589,119 @@ class LLMAdapter(nn.Module):
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
x = block(
x,
context,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
return self.norm(self.out_proj(x))
class Anima(nn.Module):
"""
Wrapper class for the MiniTrainDIT and LLM Adapter.
"""
LATENT_CHANNELS = 16
def __init__(self, dit_config: dict):
super().__init__()
self.net = MiniTrainDIT(**dit_config)
@property
def device(self):
return self.net.device
@property
def dtype(self):
return self.net.dtype
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: torch.Tensor,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
def preprocess_text_embeds(
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
if target_input_ids is not None:
return self.net.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
else:
return source_hidden_states
# VAE Wrapper
# VAE normalization constants
ANIMA_VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
]
ANIMA_VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
]
# DiT config detection from state_dict
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
def get_dit_config(state_dict, key_prefix=''):
def get_dit_config(state_dict, key_prefix=""):
"""Derive DiT configuration from state_dict weight shapes."""
dit_config = {}
dit_config["max_img_h"] = 512
dit_config["max_img_w"] = 512
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int(
concat_padding_mask
)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"

View File

@@ -32,6 +32,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
# Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
@@ -169,20 +170,20 @@ def get_noisy_model_input_and_timesteps(
"""
bs = latents.shape[0]
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
shift = getattr(args, 'discrete_flow_shift', 1.0)
timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal")
sigmoid_scale = getattr(args, "sigmoid_scale", 1.0)
shift = getattr(args, "discrete_flow_shift", 1.0)
if timestep_sample_method == 'logit_normal':
if timestep_sample_method == "logit_normal":
dist = torch.distributions.normal.Normal(0, 1)
elif timestep_sample_method == 'uniform':
elif timestep_sample_method == "uniform":
dist = torch.distributions.uniform.Uniform(0, 1)
else:
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
t = dist.sample((bs,)).to(device)
if timestep_sample_method == 'logit_normal':
if timestep_sample_method == "logit_normal":
t = t * sigmoid_scale
t = torch.sigmoid(t)
@@ -196,10 +197,10 @@ def get_noisy_model_input_and_timesteps(
# Create noisy input: (1 - t) * latents + t * noise
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
ip_noise_gamma = getattr(args, "ip_noise_gamma", None)
if ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if getattr(args, 'ip_noise_gamma_random_strength', False):
if getattr(args, "ip_noise_gamma_random_strength", False):
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
else:
@@ -213,6 +214,7 @@ def get_noisy_model_input_and_timesteps(
# Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
@@ -276,15 +278,15 @@ def get_anima_param_groups(
# Store original name for debugging
p.original_name = name
if 'llm_adapter' in name:
if "llm_adapter" in name:
llm_adapter_params.append(p)
elif '.self_attn' in name:
elif ".self_attn" in name:
self_attn_params.append(p)
elif '.cross_attn' in name:
elif ".cross_attn" in name:
cross_attn_params.append(p)
elif '.mlp' in name:
elif ".mlp" in name:
mlp_params.append(p)
elif '.adaln_modulation' in name:
elif ".adaln_modulation" in name:
mod_params.append(p)
else:
base_params.append(p)
@@ -311,9 +313,9 @@ def get_anima_param_groups(
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
param_groups.append({'params': params, 'lr': lr})
param_groups.append({"params": params, "lr": lr})
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
@@ -328,10 +330,9 @@ def save_anima_model_on_train_end(
dit: anima_models.MiniTrainDIT,
):
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -350,10 +351,9 @@ def save_anima_model_on_epoch_end_or_stepwise(
dit: anima_models.MiniTrainDIT,
):
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
@@ -410,9 +410,7 @@ def do_sample(
generator = torch.manual_seed(seed)
else:
generator = None
noise = torch.randn(
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
).to(dtype).to(device)
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
@@ -512,10 +510,20 @@ def sample_images(
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
_sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae,
vae_scale,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
# Restore RNG state
@@ -527,10 +535,20 @@ def sample_images(
def _sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae,
vae_scale,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
@@ -585,7 +603,7 @@ def _sample_image_inference(
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
@@ -613,7 +631,7 @@ def _sample_image_inference(
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
@@ -627,9 +645,16 @@ def _sample_image_inference(
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
height, width, seed, dit, crossattn_emb,
sample_steps, dit.t_embedding_norm.weight.dtype,
accelerator.device, scale, neg_crossattn_emb,
height,
width,
seed,
dit,
crossattn_emb,
sample_steps,
dit.t_embedding_norm.weight.dtype,
accelerator.device,
scale,
neg_crossattn_emb,
)
# Decode latents
@@ -662,4 +687,5 @@ def _sample_image_inference(
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)

View File

@@ -6,7 +6,10 @@ import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
from accelerate import init_empty_weights
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from .utils import setup_logging
setup_logging()
@@ -18,7 +21,7 @@ from library import anima_models
# Keys that should stay in high precision (float32/bfloat16, not quantized)
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
@@ -53,6 +56,7 @@ def load_anima_dit(
logger.info(f"Loading Anima DiT from {dit_path}")
if disable_mmap:
from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap
state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True)
else:
state_dict = load_file(dit_path, device="cpu")
@@ -60,8 +64,8 @@ def load_anima_dit(
# Remove 'net.' prefix if present
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('net.'):
k = k[len('net.'):]
if k.startswith("net."):
k = k[len("net.") :]
new_state_dict[k] = v
state_dict = new_state_dict
@@ -71,18 +75,20 @@ def load_anima_dit(
# Detect LLM adapter
if llm_adapter_path is not None:
use_llm_adapter = True
dit_config['use_llm_adapter'] = True
dit_config["use_llm_adapter"] = True
llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu")
elif 'llm_adapter.out_proj.weight' in state_dict:
elif "llm_adapter.out_proj.weight" in state_dict:
use_llm_adapter = True
dit_config['use_llm_adapter'] = True
dit_config["use_llm_adapter"] = True
llm_adapter_state_dict = None # Loaded as part of DiT
else:
use_llm_adapter = False
llm_adapter_state_dict = None
logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}")
logger.info(
f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}"
)
# Build model normally on CPU — buffers get proper values from __init__
dit = anima_models.MiniTrainDIT(**dit_config)
@@ -96,9 +102,11 @@ def load_anima_dit(
missing, unexpected = dit.load_state_dict(state_dict, strict=False)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [k for k in missing if not any(
buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq')
)]
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}")
if unexpected:
@@ -106,9 +114,7 @@ def load_anima_dit(
# Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest)
for name, p in dit.named_parameters():
dtype_to_use = dtype if (
any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1
) else transformer_dtype
dtype_to_use = dtype if (any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1) else transformer_dtype
p.data = p.data.to(dtype=dtype_to_use)
dit.to(device)
@@ -116,6 +122,127 @@ def load_anima_dit(
return dit
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer"]
def load_anima_model(
device: Union[str, torch.device],
dit_path: str,
attn_mode: str,
split_attn: bool,
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
"""
Load a HunyuanImage model from the specified checkpoint.
Args:
device (Union[str, torch.device]): Device for optimization or merging
dit_path (str): Path to the DiT model checkpoint.
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
split_attn (bool): Whether to use split attention.
loading_device (Union[str, torch.device]): Device to load the model weights on.
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
# dit_weight_dtype is None for fp8_scaled
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
device = torch.device(device)
loading_device = torch.device(loading_device)
# We currently support fixed DiT config for Anima models
dit_config = {
"max_img_h": 512,
"max_img_w": 512,
"max_frames": 128,
"in_channels": 16,
"out_channels": 16,
"patch_spatial": 2,
"patch_temporal": 1,
"model_channels": 2048,
"concat_padding_mask": True,
"crossattn_emb_channels": 1024,
"pos_emb_cls": "rope3d",
"pos_emb_learnable": True,
"pos_emb_interpolation": "crop",
"min_fps": 1,
"max_fps": 30,
"use_adaln_lora": True,
"adaln_lora_dim": 256,
"num_blocks": 28,
"num_heads": 16,
"extra_per_block_abs_pos_emb": False,
"rope_h_extrapolation_ratio": 4.0,
"rope_w_extrapolation_ratio": 4.0,
"rope_t_extrapolation_ratio": 1.0,
"extra_h_extrapolation_ratio": 1.0,
"extra_w_extrapolation_ratio": 1.0,
"extra_t_extrapolation_ratio": 1.0,
"rope_enable_fps_modulation": False,
"use_llm_adapter": True,
"attn_mode": attn_mode,
"split_attn": split_attn,
}
# model = create_model(attn_mode, split_attn, dit_weight_dtype)
with init_empty_weights():
model = anima_models.Anima(dit_config)
if dit_weight_dtype is not None:
model.to(dit_weight_dtype)
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu":
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
# Raise error to avoid silent failures
raise RuntimeError(
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
)
missing = {} # all missing keys were expected
if unexpected:
# Raise error to avoid silent failures
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
return model
def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"):
"""Load WanVAE from a safetensors/pth file.
@@ -139,14 +266,14 @@ def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: st
from library.anima_vae import WanVAE_
# Build model
with torch.device('meta'):
with torch.device("meta"):
vae = WanVAE_(**vae_config)
# Load state dict
if vae_path.endswith('.safetensors'):
vae_sd = load_file(vae_path, device='cpu')
if vae_path.endswith(".safetensors"):
vae_sd = load_file(vae_path, device="cpu")
else:
vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True)
vae_sd = torch.load(vae_path, map_location="cpu", weights_only=True)
vae.load_state_dict(vae_sd, assign=True)
vae = vae.eval().requires_grad_(False).to(device, dtype=dtype)
@@ -175,7 +302,7 @@ def load_qwen3_tokenizer(qwen3_path: str):
if os.path.isdir(qwen3_path):
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
else:
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@@ -209,12 +336,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
if os.path.isdir(qwen3_path):
# Directory with full model
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
qwen3_path, torch_dtype=dtype, local_files_only=True
).model
model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
else:
# Single safetensors file - use configs/qwen3_06b/ for config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@@ -227,16 +352,16 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
model = transformers.Qwen3ForCausalLM(qwen3_config).model
# Load weights
if qwen3_path.endswith('.safetensors'):
state_dict = load_file(qwen3_path, device='cpu')
if qwen3_path.endswith(".safetensors"):
state_dict = load_file(qwen3_path, device="cpu")
else:
state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True)
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present
new_sd = {}
for k, v in state_dict.items():
if k.startswith('model.'):
new_sd[k[len('model.'):]] = v
if k.startswith("model."):
new_sd[k[len("model.") :]] = v
else:
new_sd[k] = v
@@ -265,11 +390,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
# Use bundled config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
if os.path.exists(config_dir):
return T5TokenizerFast(
vocab_file=os.path.join(config_dir, 'spiece.model'),
tokenizer_file=os.path.join(config_dir, 'tokenizer.json'),
vocab_file=os.path.join(config_dir, "spiece.model"),
tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
)
raise FileNotFoundError(
@@ -291,9 +416,9 @@ def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dt
for k, v in dit_state_dict.items():
if dtype is not None:
v = v.to(dtype)
prefixed_sd['net.' + k] = v.contiguous()
prefixed_sd["net." + k] = v.contiguous()
save_file(prefixed_sd, save_path, metadata={'format': 'pt'})
save_file(prefixed_sd, save_path, metadata={"format": "pt"})
logger.info(f"Saved Anima model to {save_path}")

View File

@@ -16,8 +16,7 @@ class CausalConv3d(nn.Conv3d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
@@ -41,12 +40,10 @@ class RMS_norm(nn.Module):
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
@@ -61,65 +58,48 @@ class Upsample(nn.Upsample):
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d")
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
if mode == "upsample2d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
@@ -127,15 +107,14 @@ class Resample(nn.Module):
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
if self.mode == 'downsample3d':
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
@@ -144,8 +123,7 @@ class Resample(nn.Module):
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
@@ -166,8 +144,8 @@ class Resample(nn.Module):
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
@@ -181,12 +159,15 @@ class ResidualBlock(nn.Module):
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
RMS_norm(in_dim, images=False),
nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
RMS_norm(out_dim, images=False),
nn.SiLU(),
nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1),
)
self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
@@ -196,11 +177,7 @@ class ResidualBlock(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -229,13 +206,10 @@ class AttentionBlock(nn.Module):
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = rearrange(x, "b c t h w -> (b t) c h w")
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
-1).permute(0, 1, 3,
2).contiguous().chunk(
3, dim=-1)
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
@@ -247,20 +221,22 @@ class AttentionBlock(nn.Module):
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
@@ -288,21 +264,18 @@ class Encoder3d(nn.Module):
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout)
)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
@@ -310,11 +283,7 @@ class Encoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -342,11 +311,7 @@ class Encoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -357,14 +322,16 @@ class Encoder3d(nn.Module):
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
@@ -375,15 +342,15 @@ class Decoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout)
)
# upsample blocks
upsamples = []
@@ -399,15 +366,13 @@ class Decoder3d(nn.Module):
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
@@ -416,11 +381,7 @@ class Decoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -448,11 +409,7 @@ class Decoder3d(nn.Module):
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
@@ -471,14 +428,16 @@ def count_conv3d(model):
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
@@ -489,12 +448,10 @@ class WanVAE_(nn.Module):
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
@@ -510,20 +467,15 @@ class WanVAE_(nn.Module):
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
@@ -533,8 +485,7 @@ class WanVAE_(nn.Module):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
@@ -542,15 +493,9 @@ class WanVAE_(nn.Module):
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
@@ -571,7 +516,7 @@ class WanVAE_(nn.Module):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
# cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

File diff suppressed because it is too large Load Diff

View File

@@ -45,8 +45,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer
self.t5_tokenizer = t5_tokenizer
self.qwen3_max_length = qwen3_max_length
self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
@@ -54,22 +54,14 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
# Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.qwen3_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
)
qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.t5_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
)
t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"]
@@ -84,23 +76,17 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
T5 tokens are passed through unchanged (only used by LLM Adapter).
"""
def __init__(
self,
dropout_rate: float = 0.0,
) -> None:
self.dropout_rate = dropout_rate
def __init__(self) -> None:
super().__init__()
# Cached unconditional embeddings (from encoding empty caption "")
# Must be initialized via cache_uncond_embeddings() before text encoder is deleted
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
def cache_uncond_embeddings(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
) -> None:
def cache_uncond_embeddings(self, tokenize_strategy: TokenizeStrategy, models: List[Any]) -> None:
"""Pre-encode empty caption "" and cache the unconditional embeddings.
Must be called before the text encoder is deleted from GPU.
@@ -110,7 +96,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens)
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
self._uncond_attn_mask = uncond_outputs[1].cpu()
@@ -119,11 +105,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
logger.info(" Unconditional embeddings cached successfully")
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
enable_dropout: bool = True,
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -134,82 +116,19 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
batch_size = qwen3_input_ids.shape[0]
non_drop_indices = []
for i in range(batch_size):
drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
if not drop:
non_drop_indices.append(i)
encoder_device = qwen3_text_encoder.device
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
prompt_embeds = outputs.last_hidden_state
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
# Only encode non-dropped items to save compute
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
elif len(non_drop_indices) == batch_size:
nd_input_ids = qwen3_input_ids.to(encoder_device)
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
else:
nd_input_ids = None
nd_attn_mask = None
if nd_input_ids is not None:
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
# Build full batch: fill non-dropped with encoded, dropped with unconditional
if len(non_drop_indices) == batch_size:
prompt_embeds = nd_encoded_text
attn_mask = qwen3_attn_mask.to(encoder_device)
else:
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Encode empty caption on-the-fly (text encoder still available)
uncond_tokens = tokenize_strategy.tokenize("")
uncond_ids = uncond_tokens[0].to(encoder_device)
uncond_mask = uncond_tokens[1].to(encoder_device)
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
uncond_pe = uncond_out.last_hidden_state[0]
uncond_pe[~uncond_mask[0].bool()] = 0
uncond_am = uncond_mask[0]
uncond_t5_ids = uncond_tokens[2][0]
uncond_t5_am = uncond_tokens[3][0]
seq_len = qwen3_input_ids.shape[1]
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
if len(non_drop_indices) > 0:
prompt_embeds[non_drop_indices] = nd_encoded_text
attn_mask[non_drop_indices] = nd_attn_mask
# Fill dropped items with unconditional embeddings
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
for i in drop_indices:
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
@@ -217,6 +136,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor,
caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs.
@@ -224,37 +144,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
if prompt_embeds is not None and self.dropout_rate > 0.0:
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
if caption_dropout_rates is None or all(caption_dropout_rates == 0.0):
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
for i in range(prompt_embeds.shape[0]):
if random.random() < self.dropout_rate:
if self._uncond_prompt_embeds is not None:
# Use pre-cached unconditional embeddings
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if attn_mask is not None:
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
if t5_input_ids is not None:
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None:
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
else:
# Fallback: zero out (should not happen if cache_uncond_embeddings was called)
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
if attn_mask is not None:
attn_mask[i] = torch.zeros_like(attn_mask[i])
if t5_input_ids is not None:
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
assert self._uncond_prompt_embeds is not None, "Unconditional embeddings not cached, cannot apply caption dropout"
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
for i in range(prompt_embeds.shape[0]):
if random.random() < caption_dropout_rates[i].item():
# Use pre-cached unconditional embeddings
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if attn_mask is not None:
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
if t5_input_ids is not None:
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None:
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@@ -297,6 +210,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -309,7 +224,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
caption_dropout_rate = data["caption_dropout_rate"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs(
self,
@@ -344,6 +260,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
@@ -352,9 +269,16 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i)
info.text_encoder_outputs = (
prompt_embeds_i,
attn_mask_i,
t5_input_ids_i,
t5_attn_mask_i,
caption_dropout_rate,
)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -374,18 +298,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.ANIMA_LATENTS_NPZ_SUFFIX
)
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
):
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]

View File

@@ -179,12 +179,15 @@ def split_train_val(
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
def __init__(
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.caption_dropout_rate: float = caption_dropout_rate
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
@@ -197,7 +200,7 @@ class ImageInfo:
)
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)

View File

@@ -1,18 +1,17 @@
# LoRA network module for Anima
import math
# LoRA network module for Anima
import ast
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
setup_logging()
import logging
setup_logging()
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
def create_network(
multiplier: float,
@@ -29,68 +28,28 @@ def create_network(
if network_alpha is None:
network_alpha = 1.0
# type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
self_attn_dim = kwargs.get("self_attn_dim", None)
cross_attn_dim = kwargs.get("cross_attn_dim", None)
mlp_dim = kwargs.get("mlp_dim", None)
mod_dim = kwargs.get("mod_dim", None)
llm_adapter_dim = kwargs.get("llm_adapter_dim", None)
if self_attn_dim is not None:
self_attn_dim = int(self_attn_dim)
if cross_attn_dim is not None:
cross_attn_dim = int(cross_attn_dim)
if mlp_dim is not None:
mlp_dim = int(mlp_dim)
if mod_dim is not None:
mod_dim = int(mod_dim)
if llm_adapter_dim is not None:
llm_adapter_dim = int(llm_adapter_dim)
type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims: [x_embedder, t_embedder, final_layer]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")]
assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"
# block selection
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start, end = int(start), int(end)
assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
train_block_indices = parse_block_selection(train_block_indices, num_blocks)
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", False)
if train_llm_adapter is not None:
train_llm_adapter = True if train_llm_adapter == "True" else False
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
# regular expression for module selection: exclude and include
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
@@ -103,7 +62,7 @@ def create_network(
# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False
verbose = True if verbose.lower() == "true" else False
network = LoRANetwork(
text_encoders,
@@ -115,9 +74,8 @@ def create_network(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
verbose=verbose,
)
@@ -137,6 +95,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -173,8 +132,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
class LoRANetwork(torch.nn.Module):
# Target modules: DiT blocks
ANIMA_TARGET_REPLACE_MODULE = ["Block"]
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3)
@@ -197,9 +156,8 @@ class LoRANetwork(torch.nn.Module):
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -210,21 +168,36 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
if self.emb_dims is None:
self.emb_dims = [0] * 3
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expression if specified
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
@@ -232,15 +205,9 @@ class LoRANetwork(torch.nn.Module):
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> Tuple[List[LoRAModule], List[str]]:
prefix = (
self.LORA_PREFIX_ANIMA
if is_unet
else self.LORA_PREFIX_TEXT_ENCODER
)
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
@@ -255,14 +222,16 @@ class LoRANetwork(torch.nn.Module):
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
force_incl_conv2d = False
if filter is not None:
if filter not in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter
# exclude/include filter
excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns)
included = any(pattern.match(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
@@ -276,40 +245,6 @@ class LoRANetwork(torch.nn.Module):
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if is_unet and type_dims is not None:
# type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
# Order matters: check most specific identifiers first to avoid mismatches.
identifier_order = [
(4, ("llm_adapter",)),
(3, ("adaln_modulation",)),
(0, ("self_attn",)),
(1, ("cross_attn",)),
(2, ("mlp",)),
]
for idx, ids in identifier_order:
d = type_dims[idx]
if d is not None and all(id_str in lora_name for id_str in ids):
dim = d # 0 means skip
break
# block index filtering
if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
# Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
parts = lora_name.split("_")
for pi, part in enumerate(parts):
if part == "blocks" and pi + 1 < len(parts):
try:
block_index = int(parts[pi + 1])
if not self.train_block_indices[block_index]:
dim = 0
except (ValueError, IndexError):
pass
break
elif force_incl_conv2d:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
skipped.append(lora_name)
@@ -339,9 +274,7 @@ class LoRANetwork(torch.nn.Module):
if text_encoder is None:
continue
logger.info(f"create LoRA for Text Encoder {i+1}:")
te_loras, te_skipped = create_modules(
False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
@@ -354,19 +287,6 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
# emb_dims: [x_embedder, t_embedder, final_layer]
if self.emb_dims:
for filter_name, in_dim in zip(
["x_embedder", "t_embedder", "final_layer"],
self.emb_dims,
):
loras, _ = create_modules(
True, None, unet, None,
filter=filter_name, default_dim=in_dim,
include_conv2d_if_filter=(filter_name == "x_embedder"),
)
self.unet_loras.extend(loras)
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
@@ -396,6 +316,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -443,10 +364,10 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key]
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
@@ -498,10 +419,7 @@ class LoRANetwork(torch.nn.Module):
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
te1_loras = [
lora for lora in self.text_encoder_loras
if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
]
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)

View File

@@ -2,7 +2,7 @@
Diagnostic script to test Anima latent & text encoder caching independently.
Usage:
python test_anima_cache.py \
python manual_test_anima_cache.py \
--image_dir /path/to/images \
--qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \