diff --git a/anima_train.py b/anima_train.py
index 081d6963..ae3cf6a0 100644
--- a/anima_train.py
+++ b/anima_train.py
@@ -859,11 +859,11 @@ def setup_parser() -> argparse.ArgumentParser:
     anima_train_utils.add_anima_training_arguments(parser)
     sai_model_spec.add_model_spec_arguments(parser)

-    parser.add_argument(
-        "--blockwise_fused_optimizers",
-        action="store_true",
-        help="enable blockwise optimizers for fused backward pass and optimizer step",
-    )
+    # parser.add_argument(
+    #     "--blockwise_fused_optimizers",
+    #     action="store_true",
+    #     help="enable blockwise optimizers for fused backward pass and optimizer step",
+    # )
     parser.add_argument(
         "--cpu_offload_checkpointing",
         action="store_true",
diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py
index edac2fb7..2e696c43 100644
--- a/library/anima_train_utils.py
+++ b/library/anima_train_utils.py
@@ -15,6 +15,7 @@ from tqdm import tqdm
 from PIL import Image

 from library.device_utils import init_ipex, clean_memory_on_device
+from library import anima_models, anima_utils, strategy_base, train_util, qwen_image_autoencoder_kl

 init_ipex()

@@ -25,10 +26,6 @@ import logging

 logger = logging.getLogger(__name__)

-from library import anima_models, anima_utils, strategy_base, train_util
-
-from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
-

 # Anima-specific training arguments

@@ -36,19 +33,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
 def add_anima_training_arguments(parser: argparse.ArgumentParser):
     """Add Anima-specific training arguments to the parser."""
     parser.add_argument(
-        "--dit_path",
-        type=str,
-        default=None,
-        help="Path to Anima DiT model safetensors file",
-    )
-    parser.add_argument(
-        "--vae_path",
-        type=str,
-        default=None,
-        help="Path to WanVAE safetensors/pth file",
-    )
-    parser.add_argument(
-        "--qwen3_path",
+        "--qwen3",
         type=str,
         default=None,
         help="Path to Qwen3-0.6B model (safetensors file or directory)",
@@ -87,7 +72,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
         "--mod_lr",
         type=float,
         default=None,
-        help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
+        help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
     )
     parser.add_argument(
         "--t5_tokenizer_path",
@@ -114,34 +99,29 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
         help="Timestep distribution shift for rectified flow training (default: 1.0)",
     )
     parser.add_argument(
-        "--timestep_sample_method",
+        "--timestep_sampling",
         type=str,
-        default="logit_normal",
-        choices=["logit_normal", "uniform"],
-        help="Timestep sampling method (default: logit_normal)",
+        default="sigmoid",
+        choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
+        help="Timestep sampling method (default: sigmoid (logit normal))",
     )
     parser.add_argument(
         "--sigmoid_scale",
         type=float,
         default=1.0,
-        help="Scale factor for logit_normal timestep sampling (default: 1.0)",
+        help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
     )
-    # Note: --caption_dropout_rate is defined by base add_dataset_arguments().
-    # Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
-    # instead of dataset-level caption dropout, so the subset caption_dropout_rate
-    # is zeroed out in the training scripts to allow caching.
     parser.add_argument(
-        "--transformer_dtype",
-        type=str,
+        "--attn_mode",
+        choices=["torch", "xformers", "flash", "sageattn", "sdpa"],  # "sdpa" is for backward compatibility
         default=None,
-        choices=["float16", "bfloat16", "float32", None],
-        help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
+        help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. "
+        "sageattn does not support training (inference only). This option overrides --xformers or --sdpa.",
     )
     parser.add_argument(
-        "--flash_attn",
+        "--split_attn",
         action="store_true",
-        help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
-        "Falls back to PyTorch SDPA if flash-attn is not installed.",
+        help="split attention computation to reduce memory usage",
     )
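For orientation, the "sigmoid" mode is logit-normal timestep sampling: a scaled Gaussian squashed through a sigmoid, optionally warped by the discrete flow shift. A minimal sketch of that combination, assuming the usual rectified-flow conventions (the actual logic lives in get_noisy_model_input_and_timesteps; the function and argument names below are illustrative):

import torch

def sample_timesteps(batch_size: int, sigmoid_scale: float = 1.0, shift: float = 1.0) -> torch.Tensor:
    # "sigmoid" mode: logit-normal samples in (0, 1)
    t = torch.sigmoid(sigmoid_scale * torch.randn(batch_size))
    # discrete flow shift: t' = s*t / (1 + (s - 1)*t); shift > 1 biases
    # sampling toward higher noise levels
    if shift != 1.0:
        t = (shift * t) / (1 + (shift - 1) * t)
    return t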
@@ -218,7 +198,7 @@ def get_noisy_model_input_and_timesteps(
 def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
     """Compute loss weighting for Anima training.

-    Same schemes as SD3 but can add Anima-specific ones.
+    Same schemes as SD3, but Anima-specific ones can be added if needed in the future.
     """
     if weighting_scheme == "sigma_sqrt":
         weighting = (sigmas**-2.0).float()
@@ -382,6 +362,7 @@ def do_sample(
     dtype: torch.dtype,
     device: torch.device,
     guidance_scale: float = 1.0,
+    flow_shift: float = 3.0,
     neg_crossattn_emb: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """Generate a sample using Euler discrete sampling for rectified flow.
@@ -395,6 +376,7 @@ def do_sample(
         dtype: Compute dtype
         device: Compute device
         guidance_scale: CFG scale (1.0 = no guidance)
+        flow_shift: Flow shift parameter for rectified flow
         neg_crossattn_emb: Negative cross-attention embeddings for CFG

     Returns:
@@ -414,6 +396,9 @@ def do_sample(

     # Timestep schedule: linear from 1.0 to 0.0
     sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
+    flow_shift = float(flow_shift)
+    if flow_shift != 1.0:
+        sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)

     # Start from pure noise
     x = noise.clone()
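Combined with the shifted schedule just above, do_sample's Euler integration reduces to roughly the following sketch (illustrative only: model, noise, and steps are stand-ins, the velocity sign convention may differ, and the real function also handles CFG):

import torch

def euler_sample(model, noise: torch.Tensor, steps: int, flow_shift: float = 3.0) -> torch.Tensor:
    # sigma schedule from 1.0 (pure noise) to 0.0 (clean sample), warped by the flow shift
    sigmas = torch.linspace(1.0, 0.0, steps + 1)
    if flow_shift != 1.0:
        sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
    x = noise.clone()  # start from pure noise at sigma = 1.0
    for i in range(steps):
        v = model(x, sigmas[i])  # model predicts velocity at the current sigma
        x = x + (sigmas[i + 1] - sigmas[i]) * v  # Euler step toward sigma = 0
    return x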
@@ -461,7 +446,6 @@ def sample_images(
     steps,
     dit,
     vae,
-    vae_scale,
     text_encoder,
     tokenize_strategy,
     text_encoding_strategy,
@@ -515,7 +499,6 @@ def sample_images(
             dit,
             text_encoder,
             vae,
-            vae_scale,
             tokenize_strategy,
             text_encoding_strategy,
             save_dir,
@@ -539,8 +522,7 @@ def _sample_image_inference(
     args,
     dit,
     text_encoder,
-    vae,
-    vae_scale,
+    vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
     tokenize_strategy,
     text_encoding_strategy,
     save_dir,
@@ -558,6 +540,7 @@ def _sample_image_inference(
     height = prompt_dict.get("height", 512)
     scale = prompt_dict.get("scale", 7.5)
     seed = prompt_dict.get("seed")
+    flow_shift = prompt_dict.get("flow_shift", 3.0)

     if prompt_replacement is not None:
         prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
@@ -571,7 +554,9 @@ def _sample_image_inference(
     height = max(64, height - height % 16)
     width = max(64, width - width % 16)

-    logger.info(f"  prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
+    logger.info(
+        f"  prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
+    )

     # Encode prompt
     def encode_prompt(prpt):
@@ -597,14 +582,14 @@ def _sample_image_inference(
     t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
     t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)

-    prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
+    prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
     attn_mask = attn_mask.to(accelerator.device)
     t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
     t5_attn_mask = t5_attn_mask.to(accelerator.device)

     # Process through LLM adapter if available
-    if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
-        crossattn_emb = dit.llm_adapter(
+    if dit.net.use_llm_adapter:
+        crossattn_emb = dit.net.llm_adapter(
            source_hidden_states=prompt_embeds,
            target_input_ids=t5_input_ids,
            target_attention_mask=t5_attn_mask,
@@ -626,13 +611,13 @@ def _sample_image_inference(
         neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
         neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)

-        neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
+        neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
         neg_am = neg_am.to(accelerator.device)
         neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
         neg_t5_am = neg_t5_am.to(accelerator.device)

-        if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
-            neg_crossattn_emb = dit.llm_adapter(
+        if dit.net.use_llm_adapter:
+            neg_crossattn_emb = dit.net.llm_adapter(
                source_hidden_states=neg_pe,
                target_input_ids=neg_t5_ids,
                target_attention_mask=neg_t5_am,
@@ -645,23 +630,14 @@ def _sample_image_inference(
     # Generate sample
     clean_memory_on_device(accelerator.device)
     latents = do_sample(
-        height,
-        width,
-        seed,
-        dit,
-        crossattn_emb,
-        sample_steps,
-        dit.t_embedding_norm.weight.dtype,
-        accelerator.device,
-        scale,
-        neg_crossattn_emb,
+        height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
     )

     # Decode latents
     clean_memory_on_device(accelerator.device)
-    org_vae_device = next(vae.parameters()).device
+    org_vae_device = vae.device
     vae.to(accelerator.device)
-    decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
+    decoded = vae.decode_to_pixels(latents)
     vae.to(org_vae_device)
     clean_memory_on_device(accelerator.device)
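do_sample receives both scale and neg_crossattn_emb, which points to classifier-free guidance inside the sampling loop. A minimal sketch of the per-step mixing, assuming the standard CFG formulation (not shown in this diff; apply_cfg and the v_* names are hypothetical):

import torch

def apply_cfg(v_pos: torch.Tensor, v_neg: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # extrapolate from the negative prediction toward the positive one;
    # guidance_scale == 1.0 means no guidance, per do_sample's docstring
    if guidance_scale == 1.0:
        return v_pos
    return v_neg + guidance_scale * (v_pos - v_neg)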