Support Anima model (#2260)

* Support Anima model

* Update document and fix bug

* Fix latent normlization

* Fix typo

* Fix cache embedding

* fix typo in tests/test_anima_cache.py

* Remove redundant argument apply_t5_attn_mask

* Improving caching with argument caption_dropout_rate

* Fix W&B logging bugs

* Fix discrete_flow_shift default value
This commit is contained in:
duongve13112002
2026-02-08 08:18:55 +07:00
committed by GitHub
parent b996440c5f
commit e21a7736f8
21 changed files with 462100 additions and 3 deletions

View File

@@ -0,0 +1,665 @@
# Anima Training Utilities
import argparse
import math
import os
import time
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from tqdm import tqdm
from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import anima_models, anima_utils, strategy_base, train_util
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
# Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path to Anima DiT model safetensors file",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path to WanVAE safetensors/pth file",
)
parser.add_argument(
"--qwen3_path",
type=str,
default=None,
help="Path to Qwen3-0.6B model (safetensors file or directory)",
)
parser.add_argument(
"--llm_adapter_path",
type=str,
default=None,
help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present",
)
parser.add_argument(
"--llm_adapter_lr",
type=float,
default=None,
help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter",
)
parser.add_argument(
"--self_attn_lr",
type=float,
default=None,
help="Learning rate for self-attention layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--cross_attn_lr",
type=float,
default=None,
help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--mlp_lr",
type=float,
default=None,
help="Learning rate for MLP layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--mod_lr",
type=float,
default=None,
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--t5_tokenizer_path",
type=str,
default=None,
help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/",
)
parser.add_argument(
"--qwen3_max_token_length",
type=int,
default=512,
help="Maximum token length for Qwen3 tokenizer (default: 512)",
)
parser.add_argument(
"--t5_max_token_length",
type=int,
default=512,
help="Maximum token length for T5 tokenizer (default: 512)",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=1.0,
help="Timestep distribution shift for rectified flow training (default: 1.0)",
)
parser.add_argument(
"--timestep_sample_method",
type=str,
default="logit_normal",
choices=["logit_normal", "uniform"],
help="Timestep sampling method (default: logit_normal)",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help="Scale factor for logit_normal timestep sampling (default: 1.0)",
)
# Note: --caption_dropout_rate is defined by base add_dataset_arguments().
# Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
# instead of dataset-level caption dropout, so the subset caption_dropout_rate
# is zeroed out in the training scripts to allow caching.
parser.add_argument(
"--transformer_dtype",
type=str,
default=None,
choices=["float16", "bfloat16", "float32", None],
help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
)
parser.add_argument(
"--flash_attn",
action="store_true",
help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
"Falls back to PyTorch SDPA if flash-attn is not installed.",
)
# Noise & Timestep sampling (Rectified Flow)
def get_noisy_model_input_and_timesteps(
args,
latents: torch.Tensor,
noise: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate noisy model input and timesteps for rectified flow training.
Rectified flow: noisy_input = (1 - t) * latents + t * noise
Target: noise - latents
Args:
args: Training arguments with timestep_sample_method, sigmoid_scale, discrete_flow_shift
latents: Clean latent tensors
noise: Random noise tensors
device: Target device
dtype: Target dtype
Returns:
(noisy_model_input, timesteps, sigmas)
"""
bs = latents.shape[0]
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
shift = getattr(args, 'discrete_flow_shift', 1.0)
if timestep_sample_method == 'logit_normal':
dist = torch.distributions.normal.Normal(0, 1)
elif timestep_sample_method == 'uniform':
dist = torch.distributions.uniform.Uniform(0, 1)
else:
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
t = dist.sample((bs,)).to(device)
if timestep_sample_method == 'logit_normal':
t = t * sigmoid_scale
t = torch.sigmoid(t)
# Apply shift
if shift is not None and shift != 1.0:
t = (t * shift) / (1 + (shift - 1) * t)
# Clamp to avoid exact 0 or 1
t = t.clamp(1e-5, 1.0 - 1e-5)
# Create noisy input: (1 - t) * latents + t * noise
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
if ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if getattr(args, 'ip_noise_gamma_random_strength', False):
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1 - t_expanded) * latents + t_expanded * noise
# Sigmas for potential loss weighting
sigmas = t.view(-1, 1)
return noisy_model_input.to(dtype), t.to(dtype), sigmas.to(dtype)
# Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
Same schemes as SD3 but can add Anima-specific ones.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
elif weighting_scheme == "none" or weighting_scheme is None:
weighting = torch.ones_like(sigmas)
else:
weighting = torch.ones_like(sigmas)
return weighting
# Parameter groups (6 groups with separate LRs)
def get_anima_param_groups(
dit,
base_lr: float,
self_attn_lr: Optional[float] = None,
cross_attn_lr: Optional[float] = None,
mlp_lr: Optional[float] = None,
mod_lr: Optional[float] = None,
llm_adapter_lr: Optional[float] = None,
):
"""Create parameter groups for Anima training with separate learning rates.
Args:
dit: MiniTrainDIT model
base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers
mlp_lr: LR for MLP layers
mod_lr: LR for AdaLN modulation layers
llm_adapter_lr: LR for LLM adapter
Returns:
List of parameter group dicts for optimizer
"""
if self_attn_lr is None:
self_attn_lr = base_lr
if cross_attn_lr is None:
cross_attn_lr = base_lr
if mlp_lr is None:
mlp_lr = base_lr
if mod_lr is None:
mod_lr = base_lr
if llm_adapter_lr is None:
llm_adapter_lr = base_lr
base_params = []
self_attn_params = []
cross_attn_params = []
mlp_params = []
mod_params = []
llm_adapter_params = []
for name, p in dit.named_parameters():
# Store original name for debugging
p.original_name = name
if 'llm_adapter' in name:
llm_adapter_params.append(p)
elif '.self_attn' in name:
self_attn_params.append(p)
elif '.cross_attn' in name:
cross_attn_params.append(p)
elif '.mlp' in name:
mlp_params.append(p)
elif '.adaln_modulation' in name:
mod_params.append(p)
else:
base_params.append(p)
logger.info(f"Parameter groups:")
logger.info(f" base_params: {len(base_params)} (lr={base_lr})")
logger.info(f" self_attn_params: {len(self_attn_params)} (lr={self_attn_lr})")
logger.info(f" cross_attn_params: {len(cross_attn_params)} (lr={cross_attn_lr})")
logger.info(f" mlp_params: {len(mlp_params)} (lr={mlp_lr})")
logger.info(f" mod_params: {len(mod_params)} (lr={mod_lr})")
logger.info(f" llm_adapter_params: {len(llm_adapter_params)} (lr={llm_adapter_lr})")
param_groups = []
for lr, params, name in [
(base_lr, base_params, "base"),
(self_attn_lr, self_attn_params, "self_attn"),
(cross_attn_lr, cross_attn_params, "cross_attn"),
(mlp_lr, mlp_params, "mlp"),
(mod_lr, mod_params, "mod"),
(llm_adapter_lr, llm_adapter_params, "llm_adapter"),
]:
if lr == 0:
for p in params:
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
param_groups.append({'params': params, 'lr': lr})
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
# Save functions
def save_anima_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
):
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
def save_anima_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator: Accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
):
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
# Sampling (Euler discrete for rectified flow)
def do_sample(
height: int,
width: int,
seed: Optional[int],
dit: anima_models.MiniTrainDIT,
crossattn_emb: torch.Tensor,
steps: int,
dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Generate a sample using Euler discrete sampling for rectified flow.
Args:
height, width: Output image dimensions
seed: Random seed (None for random)
dit: MiniTrainDIT model
crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps
dtype: Compute dtype
device: Compute device
guidance_scale: CFG scale (1.0 = no guidance)
neg_crossattn_emb: Negative cross-attention embeddings for CFG
Returns:
Denoised latents
"""
# Latent shape: (1, 16, 1, H/8, W/8) for single image
latent_h = height // 8
latent_w = width // 8
latent = torch.zeros(1, 16, 1, latent_h, latent_w, device=device, dtype=dtype)
# Generate noise
if seed is not None:
generator = torch.manual_seed(seed)
else:
generator = None
noise = torch.randn(
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
).to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
# Start from pure noise
x = noise.clone()
# Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)
use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None
for i in tqdm(range(steps), desc="Sampling"):
sigma = sigmas[i]
t = sigma.unsqueeze(0) # (1,)
dit.prepare_block_swap_before_forward()
if use_cfg:
# CFG: concat positive and negative
x_input = torch.cat([x, x], dim=0)
t_input = torch.cat([t, t], dim=0)
crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
padding_input = torch.cat([padding_mask, padding_mask], dim=0)
model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
model_output = model_output.float()
pos_out, neg_out = model_output.chunk(2)
model_output = neg_out + guidance_scale * (pos_out - neg_out)
else:
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
model_output = model_output.float()
# Euler step: x_{t-1} = x_t - (sigma_t - sigma_{t-1}) * model_output
dt = sigmas[i + 1] - sigma
x = x + model_output * dt
x = x.to(dtype)
dit.prepare_block_swap_before_forward()
return x
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
dit,
vae,
vae_scale,
text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs=None,
prompt_replacement=None,
):
"""Generate sample images during training.
This is a simplified sampler for Anima - it generates images using the current model state.
"""
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None:
return
logger.info(f"Generating sample images at step {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file: {args.sample_prompts}")
return
# Unwrap models
dit = accelerator.unwrap_model(dit)
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = os.path.join(args.output_dir, "sample")
os.makedirs(save_dir, exist_ok=True)
# Save RNG state
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
_sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
)
# Restore RNG state
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
def _sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
negative_prompt = prompt_dict.get("negative_prompt", "")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # seed all CUDA devices for multi-GPU
height = max(64, height - height % 16)
width = max(64, width - width % 16)
logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
# Encode prompt
def encode_prompt(prpt):
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
return sample_prompts_te_outputs[prpt]
if text_encoder is not None:
tokens = tokenize_strategy.tokenize(prpt)
encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
return encoded
return None
encoded = encode_prompt(prompt)
if encoded is None:
logger.warning("Cannot encode prompt, skipping sample")
return
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded
# Convert to tensors if numpy
if isinstance(prompt_embeds, np.ndarray):
prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
)
crossattn_emb[~t5_attn_mask.bool()] = 0
else:
crossattn_emb = prompt_embeds
# Encode negative prompt for CFG
neg_crossattn_emb = None
if scale > 1.0 and negative_prompt is not None:
neg_encoded = encode_prompt(negative_prompt)
if neg_encoded is not None:
neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
if isinstance(neg_pe, np.ndarray):
neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
neg_am = torch.from_numpy(neg_am).unsqueeze(0)
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
neg_am = neg_am.to(accelerator.device)
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
target_attention_mask=neg_t5_am,
source_attention_mask=neg_am,
)
neg_crossattn_emb[~neg_t5_am.bool()] = 0
else:
neg_crossattn_emb = neg_pe
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
height, width, seed, dit, crossattn_emb,
sample_steps, dit.t_embedding_norm.weight.dtype,
accelerator.device, scale, neg_crossattn_emb,
)
# Decode latents
clean_memory_on_device(accelerator.device)
org_vae_device = next(vae.parameters()).device
vae.to(accelerator.device)
decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
# Convert to image
image = decoded.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
# Remove temporal dim if present
if image.ndim == 4:
image = image[:, 0, :, :]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
image = Image.fromarray(decoded_np)
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i = prompt_dict.get("enum", 0)
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# Log to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)