fix: update argument names in anima_train_utils to align with other architectures

kohya-ss
2026-02-09 12:46:04 +09:00
parent f320c1b964
commit 06dcb30016
2 changed files with 39 additions and 63 deletions

View File

@@ -859,11 +859,11 @@ def setup_parser() -> argparse.ArgumentParser:
     anima_train_utils.add_anima_training_arguments(parser)
     sai_model_spec.add_model_spec_arguments(parser)
-    parser.add_argument(
-        "--blockwise_fused_optimizers",
-        action="store_true",
-        help="enable blockwise optimizers for fused backward pass and optimizer step",
-    )
+    # parser.add_argument(
+    #     "--blockwise_fused_optimizers",
+    #     action="store_true",
+    #     help="enable blockwise optimizers for fused backward pass and optimizer step",
+    # )
     parser.add_argument(
         "--cpu_offload_checkpointing",
         action="store_true",

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
 from PIL import Image
 from library.device_utils import init_ipex, clean_memory_on_device
+from library import anima_models, anima_utils, strategy_base, train_util, qwen_image_autoencoder_kl
 init_ipex()
@@ -25,10 +26,6 @@ import logging
 logger = logging.getLogger(__name__)
-from library import anima_models, anima_utils, strategy_base, train_util
 from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
 # Anima-specific training arguments
@@ -36,19 +33,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
 def add_anima_training_arguments(parser: argparse.ArgumentParser):
     """Add Anima-specific training arguments to the parser."""
     parser.add_argument(
-        "--dit_path",
-        type=str,
-        default=None,
-        help="Path to Anima DiT model safetensors file",
-    )
-    parser.add_argument(
-        "--vae_path",
-        type=str,
-        default=None,
-        help="Path to WanVAE safetensors/pth file",
-    )
-    parser.add_argument(
-        "--qwen3_path",
+        "--qwen3",
         type=str,
         default=None,
         help="Path to Qwen3-0.6B model (safetensors file or directory)",
@@ -87,7 +72,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
         "--mod_lr",
         type=float,
         default=None,
-        help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
+        help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
     )
     parser.add_argument(
         "--t5_tokenizer_path",
@@ -114,34 +99,29 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
         help="Timestep distribution shift for rectified flow training (default: 1.0)",
     )
     parser.add_argument(
-        "--timestep_sample_method",
+        "--timestep_sampling",
         type=str,
-        default="logit_normal",
-        choices=["logit_normal", "uniform"],
-        help="Timestep sampling method (default: logit_normal)",
+        default="sigmoid",
+        choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
+        help="Timestep sampling method (default: sigmoid (logit normal))",
     )
     parser.add_argument(
         "--sigmoid_scale",
         type=float,
         default=1.0,
-        help="Scale factor for logit_normal timestep sampling (default: 1.0)",
+        help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
     )
+    # Note: --caption_dropout_rate is defined by base add_dataset_arguments().
+    # Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
+    # instead of dataset-level caption dropout, so the subset caption_dropout_rate
+    # is zeroed out in the training scripts to allow caching.
     parser.add_argument(
-        "--transformer_dtype",
-        type=str,
+        "--attn_mode",
+        choices=["torch", "xformers", "flash", "sageattn", "sdpa"],  # "sdpa" is for backward compatibility
         default=None,
-        choices=["float16", "bfloat16", "float32", None],
-        help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
+        help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
+        " / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
     )
     parser.add_argument(
-        "--flash_attn",
+        "--split_attn",
         action="store_true",
-        help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
-        "Falls back to PyTorch SDPA if flash-attn is not installed.",
+        help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
     )
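
Note: the new --timestep_sampling names match the other architectures in this repo. For the default, "sigmoid" is understood here as logit-normal sampling: draw a standard normal and squash it through a sigmoid so timesteps land in (0, 1) clustered around 0.5, with --sigmoid_scale controlling the spread. A minimal sketch under that assumption; the helper name is illustrative:

import torch

def sample_timesteps_sigmoid(batch_size: int, sigmoid_scale: float = 1.0) -> torch.Tensor:
    # Logit-normal: sigmoid of a scaled standard normal. scale > 1 pushes
    # mass toward 0 and 1; scale < 1 concentrates it around 0.5.
    return torch.sigmoid(sigmoid_scale * torch.randn(batch_size))
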
@@ -218,7 +198,7 @@ def get_noisy_model_input_and_timesteps(
 def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
     """Compute loss weighting for Anima training.
-    Same schemes as SD3 but can add Anima-specific ones.
+    Same schemes as SD3 but can add Anima-specific ones if needed in future.
     """
     if weighting_scheme == "sigma_sqrt":
         weighting = (sigmas**-2.0).float()
@@ -382,6 +362,7 @@ def do_sample(
     dtype: torch.dtype,
     device: torch.device,
     guidance_scale: float = 1.0,
+    flow_shift: float = 3.0,
     neg_crossattn_emb: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """Generate a sample using Euler discrete sampling for rectified flow.
@@ -395,6 +376,7 @@ def do_sample(
         dtype: Compute dtype
         device: Compute device
         guidance_scale: CFG scale (1.0 = no guidance)
+        flow_shift: Flow shift parameter for rectified flow
         neg_crossattn_emb: Negative cross-attention embeddings for CFG
     Returns:
@@ -414,6 +396,9 @@ def do_sample(
     # Timestep schedule: linear from 1.0 to 0.0
     sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
+    flow_shift = float(flow_shift)
+    if flow_shift != 1.0:
+        sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
     # Start from pure noise
     x = noise.clone()
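
Note: this is the usual rectified-flow shift, sigma' = s * sigma / (1 + (s - 1) * sigma). The mapping is monotonic and fixes both endpoints (sigma = 0 and sigma = 1), so only the spacing of intermediate steps changes; for the default s = 3.0, sigma = 0.5 maps to 1.5 / 2.0 = 0.75, spending more of the schedule at high noise. A standalone sketch of the same mapping:

import torch

def shift_sigmas(sigmas: torch.Tensor, flow_shift: float) -> torch.Tensor:
    # Identical to the lines added above; flow_shift == 1.0 is the identity.
    return (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)

print(shift_sigmas(torch.tensor([1.0, 0.5, 0.0]), 3.0))  # tensor([1.0000, 0.7500, 0.0000])
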
@@ -461,7 +446,6 @@ def sample_images(
     steps,
     dit,
     vae,
-    vae_scale,
     text_encoder,
     tokenize_strategy,
     text_encoding_strategy,
@@ -515,7 +499,6 @@ def sample_images(
         dit,
         text_encoder,
         vae,
-        vae_scale,
         tokenize_strategy,
         text_encoding_strategy,
         save_dir,
@@ -539,8 +522,7 @@ def _sample_image_inference(
     args,
     dit,
     text_encoder,
-    vae,
-    vae_scale,
+    vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
     tokenize_strategy,
     text_encoding_strategy,
     save_dir,
@@ -558,6 +540,7 @@ def _sample_image_inference(
     height = prompt_dict.get("height", 512)
     scale = prompt_dict.get("scale", 7.5)
     seed = prompt_dict.get("seed")
+    flow_shift = prompt_dict.get("flow_shift", 3.0)
     if prompt_replacement is not None:
         prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
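
Note: flow_shift is now a per-prompt setting read from prompt_dict, defaulting to 3.0. Assuming prompt_dict is the parsed form of one sample-prompt entry (the prompt-file syntax itself is not shown in this diff), a hypothetical entry would look like:

prompt_dict = {
    "prompt": "a cat sitting on a windowsill",
    "width": 768,
    "height": 512,
    "scale": 7.5,
    "seed": 42,
    "flow_shift": 5.0,  # hypothetical override; omitting it falls back to 3.0
}
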
@@ -571,7 +554,9 @@ def _sample_image_inference(
     height = max(64, height - height % 16)
     width = max(64, width - width % 16)
-    logger.info(f"  prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
+    logger.info(
+        f"  prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
+    )
     # Encode prompt
     def encode_prompt(prpt):
@@ -597,14 +582,14 @@ def _sample_image_inference(
         t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
         t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
-        prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
+        prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
         attn_mask = attn_mask.to(accelerator.device)
         t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
         t5_attn_mask = t5_attn_mask.to(accelerator.device)
         # Process through LLM adapter if available
-        if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
-            crossattn_emb = dit.llm_adapter(
+        if dit.net.use_llm_adapter:
+            crossattn_emb = dit.net.llm_adapter(
                 source_hidden_states=prompt_embeds,
                 target_input_ids=t5_input_ids,
                 target_attention_mask=t5_attn_mask,
@@ -626,13 +611,13 @@ def _sample_image_inference(
         neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
         neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
-        neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
+        neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
         neg_am = neg_am.to(accelerator.device)
         neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
         neg_t5_am = neg_t5_am.to(accelerator.device)
-        if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
-            neg_crossattn_emb = dit.llm_adapter(
+        if dit.net.use_llm_adapter:
+            neg_crossattn_emb = dit.net.llm_adapter(
                 source_hidden_states=neg_pe,
                 target_input_ids=neg_t5_ids,
                 target_attention_mask=neg_t5_am,
@@ -645,23 +630,14 @@ def _sample_image_inference(
     # Generate sample
     clean_memory_on_device(accelerator.device)
     latents = do_sample(
-        height,
-        width,
-        seed,
-        dit,
-        crossattn_emb,
-        sample_steps,
-        dit.t_embedding_norm.weight.dtype,
-        accelerator.device,
-        scale,
-        neg_crossattn_emb,
+        height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
     )
     # Decode latents
     clean_memory_on_device(accelerator.device)
-    org_vae_device = next(vae.parameters()).device
+    org_vae_device = vae.device
     vae.to(accelerator.device)
-    decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
+    decoded = vae.decode_to_pixels(latents)
     vae.to(org_vae_device)
     clean_memory_on_device(accelerator.device)
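
Note: with the typed AutoencoderKLQwenImage argument, the sampler no longer threads vae_scale through sample_images or probes next(vae.parameters()) for device and dtype; the wrapper exposes .device and folds scaling into decode_to_pixels. A minimal sketch of the interface this call site relies on (the real class in library/qwen_image_autoencoder_kl is more involved; the body below is an assumption, not its actual implementation):

import torch
import torch.nn as nn

class VaeWrapperSketch(nn.Module):
    def __init__(self, inner: nn.Module, scaling_factor: float):
        super().__init__()
        self.inner = inner  # the underlying autoencoder
        self.scaling_factor = scaling_factor

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    def decode_to_pixels(self, latents: torch.Tensor) -> torch.Tensor:
        # Handle placement and latent scaling internally, so callers
        # pass raw latents and receive pixel-space output.
        dtype = next(self.parameters()).dtype
        latents = latents.to(self.device, dtype=dtype) / self.scaling_factor
        return self.inner.decode(latents)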