diff --git a/_typos.toml b/_typos.toml
index 686da4af..fc33b6b3 100644
--- a/_typos.toml
+++ b/_typos.toml
@@ -32,6 +32,7 @@ hime="hime"
OT="OT"
byt="byt"
tak="tak"
+temperal="temperal"
[files]
-extend-exclude = ["_typos.toml", "venv"]
+extend-exclude = ["_typos.toml", "venv", "configs"]
diff --git a/anima_minimal_inference.py b/anima_minimal_inference.py
new file mode 100644
index 00000000..a0b2e494
--- /dev/null
+++ b/anima_minimal_inference.py
@@ -0,0 +1,1025 @@
+import argparse
+import datetime
+import gc
+from importlib.util import find_spec
+import random
+import os
+import re
+import time
+import copy
+from types import ModuleType, SimpleNamespace
+from typing import Tuple, Optional, List, Any, Dict, Union
+
+import numpy as np
+import torch
+from safetensors.torch import load_file, save_file
+from safetensors import safe_open
+from tqdm import tqdm
+from diffusers.utils.torch_utils import randn_tensor
+from PIL import Image
+
+from library import anima_models, anima_utils, hunyuan_image_utils, qwen_image_autoencoder_kl, strategy_anima, strategy_base
+from library.anima_vae import WanVAE_
+from library.device_utils import clean_memory_on_device, synchronize_device
+from library.safetensors_utils import mem_eff_save_file
+from networks import lora_hunyuan_image
+
+lycoris_available = find_spec("lycoris") is not None
+if lycoris_available:
+ from lycoris.kohya import create_network_from_weights
+
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GenerationSettings:
+ def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
+ self.device = device
+        self.dit_weight_dtype = dit_weight_dtype  # None means weights are loaded/kept as-is (e.g. fp8_scaled)
+
+
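+# Example invocation (paths below are placeholders):
+#   python anima_minimal_inference.py --dit <anima_dit.safetensors> --vae <qwen_image_vae.safetensors> \
+#       --text_encoder <qwen3_dir> --prompt "a cat on a sofa" --image_size 1024 1024 --save_path outputs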
+def parse_args() -> argparse.Namespace:
+ """parse command line arguments"""
+    parser = argparse.ArgumentParser(description="Anima minimal inference script")
+
+ parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
+ parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
+ parser.add_argument("--text_encoder", type=str, required=True, help="Text Encoder 1 (Qwen2.5-VL) directory or path")
+
+ # LoRA
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
+
+ # inference
+ parser.add_argument(
+ "--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier free guidance. Default is 3.5."
+ )
+ parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
+ parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string")
+ parser.add_argument("--image_size", type=int, nargs=2, default=[1024, 1024], help="image size, height and width")
+ parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps, default is 50")
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
+
+ # Flow Matching
+ parser.add_argument(
+ "--flow_shift",
+ type=float,
+ default=5.0,
+ help="Shift factor for flow matching schedulers. Default is 5.0.",
+ )
+
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
+
+ parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders")
+ parser.add_argument(
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
+ )
+ parser.add_argument(
+ "--attn_mode",
+ type=str,
+ default="torch",
+ choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility
+ help="attention mode",
+ )
+ parser.add_argument(
+ "--output_type",
+ type=str,
+ default="images",
+ choices=["images", "latent", "latent_images"],
+ help="output type",
+ )
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
+ parser.add_argument(
+ "--lycoris", action="store_true", help=f"use lycoris for inference{'' if lycoris_available else ' (not available)'}"
+ )
+
+ # arguments for batch and interactive modes
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
+
+ args = parser.parse_args()
+
+ # Validate arguments
+ if args.from_file and args.interactive:
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
+
+ if args.latent_path is None or len(args.latent_path) == 0:
+ if args.prompt is None and not args.from_file and not args.interactive:
+ raise ValueError("Either --prompt, --from_file or --interactive must be specified")
+
+ if args.lycoris and not lycoris_available:
+ raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
+
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch" # backward compatibility
+
+ return args
+
+
+def parse_prompt_line(line: str) -> Dict[str, Any]:
+ """Parse a prompt line into a dictionary of argument overrides
+
+ Args:
+ line: Prompt line with options
+
+ Returns:
+ Dict[str, Any]: Dictionary of argument overrides
+ """
+ parts = line.split(" --")
+ prompt = parts[0].strip()
+
+ # Create dictionary of overrides
+ overrides = {"prompt": prompt}
+
+ for part in parts[1:]:
+ if not part.strip():
+ continue
+ option_parts = part.split(" ", 1)
+ option = option_parts[0].strip()
+ value = option_parts[1].strip() if len(option_parts) > 1 else ""
+
+ # Map options to argument names
+ if option == "w":
+ overrides["image_size_width"] = int(value)
+ elif option == "h":
+ overrides["image_size_height"] = int(value)
+ elif option == "d":
+ overrides["seed"] = int(value)
+ elif option == "s":
+ overrides["infer_steps"] = int(value)
+ elif option == "g" or option == "l":
+ overrides["guidance_scale"] = float(value)
+ elif option == "fs":
+ overrides["flow_shift"] = float(value)
+ elif option == "n":
+ overrides["negative_prompt"] = value
+
+ return overrides
+
+
+def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
+ """Apply overrides to args
+
+ Args:
+ args: Original arguments
+ overrides: Dictionary of overrides
+
+ Returns:
+ argparse.Namespace: New arguments with overrides applied
+ """
+ args_copy = copy.deepcopy(args)
+
+ for key, value in overrides.items():
+ if key == "image_size_width":
+ args_copy.image_size[1] = value
+ elif key == "image_size_height":
+ args_copy.image_size[0] = value
+ else:
+ setattr(args_copy, key, value)
+
+ return args_copy
+
+
+def check_inputs(args: argparse.Namespace) -> Tuple[int, int]:
+ """Validate video size and length
+
+ Args:
+ args: command line arguments
+
+ Returns:
+ Tuple[int, int]: (height, width)
+ """
+ height = args.image_size[0]
+ width = args.image_size[1]
+
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ return height, width
+
+
+# region Model
+
+
+def load_dit_model(
+ args: argparse.Namespace, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None
+) -> anima_models.Anima:
+ """load DiT model
+
+ Args:
+ args: command line arguments
+ device: device to use
+ dit_weight_dtype: data type for the model weights. None for as-is
+
+ Returns:
+ anima_models.Anima: DiT model instance
+ """
+ # If LyCORIS is enabled, we will load the model to CPU and then merge LoRA weights (static method)
+
+ loading_device = "cpu"
+ if not args.lycoris:
+ loading_device = device
+
+ # load LoRA weights
+ if not args.lycoris and args.lora_weight is not None and len(args.lora_weight) > 0:
+ lora_weights_list = []
+ for lora_weight in args.lora_weight:
+ logger.info(f"Loading LoRA weight from: {lora_weight}")
+ lora_sd = load_file(lora_weight) # load on CPU, dtype is as is
+ # lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns)
+ lora_weights_list.append(lora_sd)
+ else:
+ lora_weights_list = None
+
+ loading_weight_dtype = dit_weight_dtype
+ if args.fp8_scaled and not args.lycoris:
+ loading_weight_dtype = None # we will load weights as-is and then optimize to fp8
+
+ model = anima_utils.load_anima_model(
+ device,
+ args.dit,
+ args.attn_mode,
+ True, # enable split_attn to trim masked tokens
+ loading_device,
+ loading_weight_dtype,
+ args.fp8_scaled and not args.lycoris,
+ lora_weights_list=lora_weights_list,
+ lora_multipliers=args.lora_multiplier,
+ )
+ if not args.fp8_scaled:
+ # simple cast to dit_weight_dtype
+ target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
+ if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled
+ logger.info(f"Convert model to {dit_weight_dtype}")
+ target_dtype = dit_weight_dtype
+
+ logger.info(f"Move model to device: {device}")
+ target_device = device
+
+ model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
+
+ # model.to(device)
+ model.to(device, dtype=torch.bfloat16) # ensure model is in bfloat16 for inference
+
+ model.eval().requires_grad_(False)
+ clean_memory_on_device(device)
+
+ return model
+
+
+# endregion
+
+
+def decode_latent(vae: WanVAE_, latent: torch.Tensor, device: torch.device) -> torch.Tensor:
+ logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}")
+
+ vae.to(device)
+ with torch.no_grad():
+ pixels = vae.decode_to_pixels(latent.to(device, dtype=vae.dtype))
+ # pixels = vae.decode(latent.to(device, dtype=torch.bfloat16), scale=vae_scale)
+ if pixels.ndim == 5: # remove frame dimension if exists, [B, C, F, H, W] -> [B, C, H, W]
+ pixels = pixels.squeeze(2)
+
+ pixels = pixels.to("cpu", dtype=torch.float32) # move to CPU and convert to float32 (bfloat16 is not supported by numpy)
+ vae.to("cpu")
+
+ logger.info(f"Decoded. Pixel shape {pixels.shape}")
+ return pixels[0] # remove batch dimension
+
+
+def prepare_text_inputs(
+ args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """Prepare text-related inputs for T2I: LLM encoding. Anima model is also needed for preprocessing"""
+
+ # load text encoder: conds_cache holds cached encodings for prompts without padding
+ conds_cache = {}
+ text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
+ if shared_models is not None:
+ text_encoder = shared_models.get("text_encoder")
+
+ if "conds_cache" in shared_models: # Use shared cache if available
+ conds_cache = shared_models["conds_cache"]
+
+ # text_encoder is on device (batched inference) or CPU (interactive inference)
+ else: # Load if not in shared_models
+ text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
+ text_encoder, _ = anima_utils.load_qwen3_text_encoder(
+ args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device
+ )
+ text_encoder.eval()
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+
+ # Store original devices to move back later if they were shared. This does nothing if shared_models is None
+ text_encoder_original_device = text_encoder.device if text_encoder else None
+
+ # Ensure text_encoder is not None before proceeding
+ if not text_encoder:
+ raise ValueError("Text encoder is not loaded properly.")
+
+ # Define a function to move models to device if needed
+ # This is to avoid moving models if not needed, especially in interactive mode
+ model_is_moved = False
+
+ def move_models_to_device_if_needed():
+ nonlocal model_is_moved
+ nonlocal shared_models
+
+ if model_is_moved:
+ return
+ model_is_moved = True
+
+ logger.info(f"Moving Text Encoder to appropriate device: {text_encoder_device}")
+ text_encoder.to(text_encoder_device) # If text_encoder_cpu is True, this will be CPU
+
+ logger.info("Encoding prompt with Text Encoder")
+
+ prompt = args.prompt
+ cache_key = prompt
+ if cache_key in conds_cache:
+ embed = conds_cache[cache_key]
+ else:
+ move_models_to_device_if_needed()
+
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+
+ with torch.no_grad():
+ # embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt)
+ tokens = tokenize_strategy.tokenize(prompt)
+ embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
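+        # embed layout (as consumed below): [0]=source hidden states, [1]=source attention mask,
+        # [2]=target input ids, [3]=target attention mask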
+ crossattn_emb = anima._preprocess_text_embeds(
+ source_hidden_states=embed[0].to(anima.device),
+ target_input_ids=embed[2].to(anima.device),
+ target_attention_mask=embed[3].to(anima.device),
+ source_attention_mask=embed[1].to(anima.device),
+ )
+ crossattn_emb[~embed[3].bool()] = 0
+ embed[0] = crossattn_emb
+ embed[0] = embed[0].cpu()
+
+ conds_cache[cache_key] = embed
+
+ negative_prompt = args.negative_prompt
+ cache_key = negative_prompt
+ if cache_key in conds_cache:
+ negative_embed = conds_cache[cache_key]
+ else:
+ move_models_to_device_if_needed()
+
+        # fetch strategies here as well, since the positive-prompt branch above may be skipped on a cache hit
+        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+        encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+
+        with torch.no_grad():
+ # negative_embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, negative_prompt)
+ tokens = tokenize_strategy.tokenize(negative_prompt)
+ negative_embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
+ crossattn_emb = anima._preprocess_text_embeds(
+ source_hidden_states=negative_embed[0].to(anima.device),
+ target_input_ids=negative_embed[2].to(anima.device),
+ target_attention_mask=negative_embed[3].to(anima.device),
+ source_attention_mask=negative_embed[1].to(anima.device),
+ )
+ crossattn_emb[~negative_embed[3].bool()] = 0
+ negative_embed[0] = crossattn_emb
+ negative_embed[0] = negative_embed[0].cpu()
+
+ conds_cache[cache_key] = negative_embed
+
+ if not (shared_models and "text_encoder" in shared_models): # if loaded locally
+        # There is a bug where the text encoder is not freed from GPU memory when it is fp8
+ del text_encoder
+ gc.collect() # This may force Text Encoder to be freed from GPU memory
+ else: # if shared, move back to original device (likely CPU)
+ if text_encoder:
+ text_encoder.to(text_encoder_original_device)
+
+ clean_memory_on_device(device)
+
+ arg_c = {"embed": embed, "prompt": prompt}
+ arg_null = {"embed": negative_embed, "prompt": negative_prompt}
+
+ return arg_c, arg_null
+
+
+def generate(
+ args: argparse.Namespace,
+ gen_settings: GenerationSettings,
+ shared_models: Optional[Dict] = None,
+ precomputed_text_data: Optional[Dict] = None,
+) -> torch.Tensor:
+ """main function for generation
+
+ Args:
+ args: command line arguments
+ shared_models: dictionary containing pre-loaded models (mainly for DiT)
+ precomputed_image_data: Optional dictionary with precomputed image data
+ precomputed_text_data: Optional dictionary with precomputed text data
+
+ Returns:
+ tuple: (HunyuanVAE2D model (vae) or None, torch.Tensor generated latent)
+ """
+ device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)
+
+ # prepare seed
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+ args.seed = seed # set seed to args for saving
+
+ if shared_models is None or "model" not in shared_models:
+ # load DiT model
+ anima = load_dit_model(args, device, dit_weight_dtype)
+
+ if shared_models is not None:
+ shared_models["model"] = anima
+ else:
+ # use shared model
+ logger.info("Using shared DiT model.")
+ anima: anima_models.Anima = shared_models["model"]
+
+ if precomputed_text_data is not None:
+ logger.info("Using precomputed text data.")
+ context = precomputed_text_data["context"]
+ context_null = precomputed_text_data["context_null"]
+
+ else:
+ logger.info("No precomputed data. Preparing image and text inputs.")
+ context, context_null = prepare_text_inputs(args, device, anima, shared_models)
+
+ return generate_body(args, anima, context, context_null, device, seed)
+
+
+def generate_body(
+ args: Union[argparse.Namespace, SimpleNamespace],
+ anima: anima_models.Anima,
+ context: Dict[str, Any],
+ context_null: Optional[Dict[str, Any]],
+ device: torch.device,
+ seed: int,
+) -> torch.Tensor:
+
+ # set random generator
+ seed_g = torch.Generator(device="cpu")
+ seed_g.manual_seed(seed)
+
+ height, width = check_inputs(args)
+ logger.info(f"Image size: {height}x{width} (HxW), infer_steps: {args.infer_steps}")
+
+ # image generation ######
+
+ logger.info(f"Prompt: {context['prompt']}")
+
+ embed = context["embed"][0].to(device, dtype=torch.bfloat16)
+ if context_null is None:
+ context_null = context # dummy for unconditional
+ negative_embed = context_null["embed"][0].to(device, dtype=torch.bfloat16)
+
+ # Prepare latent variables
+ num_channels_latents = anima_models.Anima.LATENT_CHANNELS
+ shape = (
+ 1,
+ num_channels_latents,
+ 1, # Frame dimension
+ height // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ width // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ )
+ latents = randn_tensor(shape, generator=seed_g, device=device, dtype=torch.bfloat16)
+
+ # Create padding mask
+ bs = latents.shape[0]
+ h_latent = latents.shape[-2]
+ w_latent = latents.shape[-1]
+ padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=torch.bfloat16, device=device)
+
+ logger.info(f"Embed: {embed.shape}, negative_embed: {negative_embed.shape}, latents: {latents.shape}")
+ embed = embed.to(torch.bfloat16)
+ negative_embed = negative_embed.to(torch.bfloat16)
+
+ # Prepare timesteps
+ timesteps, sigmas = hunyuan_image_utils.get_timesteps_sigmas(args.infer_steps, args.flow_shift, device)
+ timesteps /= 1000 # scale to [0,1] range
+ timesteps = timesteps.to(device, dtype=torch.bfloat16)
+
+ # Denoising loop
+ do_cfg = args.guidance_scale != 1.0
+ autocast_enabled = args.fp8
+
+ with tqdm(total=len(timesteps), desc="Denoising steps") as pbar:
+ for i, t in enumerate(timesteps):
+ t_expand = t.expand(latents.shape[0])
+
+ with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
+ noise_pred = anima(latents, t_expand, embed, padding_mask=padding_mask)
+
+ if do_cfg:
+ with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
+ uncond_noise_pred = anima(latents, t_expand, negative_embed, padding_mask=padding_mask)
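+                # classifier-free guidance: pred = uncond + guidance_scale * (cond - uncond)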
+ noise_pred = uncond_noise_pred + args.guidance_scale * (noise_pred - uncond_noise_pred)
+
+ # ensure latents dtype is consistent
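+                # hunyuan_image_utils.step is assumed to apply the flow-matching Euler update: x <- x + (sigma_{i+1} - sigma_i) * v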
+ latents = hunyuan_image_utils.step(latents, noise_pred, sigmas, i).to(latents.dtype)
+
+ pbar.update()
+
+ return latents
+
+
+def get_time_flag():
+ return datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S-%f")[:-3]
+
+
+def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
+ """Save latent to file
+
+ Args:
+ latent: Latent tensor
+ args: command line arguments
+        height: image height
+        width: image width
+
+ Returns:
+ str: Path to saved latent file
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = get_time_flag()
+
+ seed = args.seed
+
+ latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
+
+ if args.no_metadata:
+ metadata = None
+ else:
+ metadata = {
+ "seeds": f"{seed}",
+ "prompt": f"{args.prompt}",
+ "height": f"{height}",
+ "width": f"{width}",
+ "infer_steps": f"{args.infer_steps}",
+ # "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
+ "guidance_scale": f"{args.guidance_scale}",
+ }
+ if args.negative_prompt is not None:
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
+
+ sd = {"latent": latent.contiguous()}
+ save_file(sd, latent_path, metadata=metadata)
+ logger.info(f"Latent saved to: {latent_path}")
+
+ return latent_path
+
+
+def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
+ """Save images to directory
+
+ Args:
+ sample: Video tensor
+ args: command line arguments
+ original_base_name: Original base name (if latents are loaded from files)
+
+ Returns:
+ str: Path to saved images directory
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = get_time_flag()
+
+ seed = args.seed
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
+ image_name = f"{time_flag}_{seed}{original_name}"
+
+ x = torch.clamp(sample, -1.0, 1.0)
+ x = ((x + 1.0) * 127.5).to(torch.uint8).cpu().numpy()
+ x = x.transpose(1, 2, 0) # C, H, W -> H, W, C
+
+ image = Image.fromarray(x)
+ image.save(os.path.join(save_path, f"{image_name}.png"))
+
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
+
+ return f"{save_path}/{image_name}"
+
+
+def save_output(
+ args: argparse.Namespace,
+ vae: WanVAE_,
+ latent: torch.Tensor,
+ device: torch.device,
+ original_base_name: Optional[str] = None,
+) -> None:
+ """save output
+
+ Args:
+ args: command line arguments
+ vae: VAE model
+ latent: latent tensor
+ device: device to use
+ original_base_name: original base name (if latents are loaded from files)
+ """
+ height, width = latent.shape[-2], latent.shape[-1] # BCTHW
+ height *= 8 # qwen_image_autoencoder_kl.SCALE_FACTOR
+ width *= 8 # qwen_image_autoencoder_kl.SCALE_FACTOR
+ # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}")
+ if args.output_type == "latent" or args.output_type == "latent_images":
+ # save latent
+ save_latent(latent, args, height, width)
+ if args.output_type == "latent":
+ return
+
+ if vae is None:
+ logger.error("VAE is None, cannot decode latents for saving video/images.")
+ return
+
+ if latent.ndim == 2: # S,C. For packed latents from other inference scripts
+ latent = latent.unsqueeze(0)
+ height, width = check_inputs(args) # Get height/width from args
+ latent = latent.view(
+ 1,
+ vae.latent_channels,
+ 1, # Frame dimension
+ height // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ width // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ )
+
+ image = decode_latent(vae, latent, device)
+
+ if args.output_type == "images" or args.output_type == "latent_images":
+ # save images
+ if original_base_name is None:
+ original_name = ""
+ else:
+ original_name = f"_{original_base_name}"
+ save_images(image, args, original_name)
+
+
+def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
+ """Process multiple prompts for batch mode
+
+ Args:
+ prompt_lines: List of prompt lines
+ base_args: Base command line arguments
+
+ Returns:
+ List[Dict]: List of prompt data dictionaries
+ """
+ prompts_data = []
+
+ for line in prompt_lines:
+ line = line.strip()
+ if not line or line.startswith("#"): # Skip empty lines and comments
+ continue
+
+ # Parse prompt line and create override dictionary
+ prompt_data = parse_prompt_line(line)
+ logger.info(f"Parsed prompt data: {prompt_data}")
+ prompts_data.append(prompt_data)
+
+ return prompts_data
+
+
+def load_shared_models(args: argparse.Namespace) -> Dict:
+ """Load shared models for batch processing or interactive mode.
+ Models are loaded to CPU to save memory. VAE is NOT loaded here.
+ DiT model is also NOT loaded here, handled by process_batch_prompts or generate.
+
+ Args:
+ args: Base command line arguments
+
+ Returns:
+        Dict: Dictionary of shared models (text encoder)
+ """
+ shared_models = {}
+ # Load text encoders to CPU
+ text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
+ text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
+ shared_models["text_encoder"] = text_encoder
+ return shared_models
+
+
+def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
+ """Process multiple prompts with model reuse and batched precomputation
+
+ Args:
+ prompts_data: List of prompt data dictionaries
+ args: Base command line arguments
+ """
+ if not prompts_data:
+ logger.warning("No valid prompts found")
+ return
+
+ gen_settings = get_generation_settings(args)
+ dit_weight_dtype = gen_settings.dit_weight_dtype
+ device = gen_settings.device
+
+ # 1. Prepare VAE
+ logger.info("Loading VAE for batch generation...")
+ vae_for_batch = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
+ vae_for_batch.to(torch.bfloat16)
+ vae_for_batch.eval()
+
+ all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first
+ for prompt_args in all_prompt_args_list:
+ check_inputs(prompt_args) # Validate each prompt's height/width
+
+ # 2. Load DiT Model once
+ logger.info("Loading DiT model for batch generation...")
+ # Use args from the first prompt for DiT loading (LoRA etc. should be consistent for a batch)
+ first_prompt_args = all_prompt_args_list[0]
+ anima = load_dit_model(first_prompt_args, device, dit_weight_dtype) # Load directly to target device if possible
+
+ shared_models_for_generate = {"model": anima} # Pass DiT via shared_models
+
+ # 3. Precompute Text Data (Text Encoder)
+ logger.info("Loading Text Encoder for batch text preprocessing...")
+
+    # Text Encoder loaded to CPU by load_qwen3_text_encoder
+ text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
+ text_encoder_batch, _ = anima_utils.load_qwen3_text_encoder(args.text_encoder, dtype=text_encoder_dtype, device="cpu")
+
+ # Text Encoder to device for this phase
+ text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
+ text_encoder_batch.to(text_encoder_device) # Moved into prepare_text_inputs logic
+
+ all_precomputed_text_data = []
+ conds_cache_batch = {}
+
+ logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...")
+ temp_shared_models_txt = {
+ "text_encoder": text_encoder_batch, # on GPU if not text_encoder_cpu
+ "conds_cache": conds_cache_batch,
+ }
+
+ for i, prompt_args_item in enumerate(all_prompt_args_list):
+ logger.info(f"Text preprocessing for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}")
+
+ # prepare_text_inputs will move text_encoders to device temporarily
+ context, context_null = prepare_text_inputs(prompt_args_item, device, anima, temp_shared_models_txt)
+ text_data = {"context": context, "context_null": context_null}
+ all_precomputed_text_data.append(text_data)
+
+ # Models should be removed from device after prepare_text_inputs
+ del text_encoder_batch, temp_shared_models_txt, conds_cache_batch
+ gc.collect() # Force cleanup of Text Encoder from GPU memory
+ clean_memory_on_device(device)
+
+ all_latents = []
+
+ logger.info("Generating latents for all prompts...")
+ with torch.no_grad():
+ for i, prompt_args_item in enumerate(all_prompt_args_list):
+ current_text_data = all_precomputed_text_data[i]
+ height, width = check_inputs(prompt_args_item) # Get height/width for each prompt
+
+ logger.info(f"Generating latent for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}")
+ try:
+ # generate is called with precomputed data, so it won't load Text Encoders.
+ # It will use the DiT model from shared_models_for_generate.
+ latent = generate(prompt_args_item, gen_settings, shared_models_for_generate, current_text_data)
+
+ if latent is None: # and prompt_args_item.save_merged_model: # Should be caught earlier
+ continue
+
+                # Save latent if needed (using height/width from check_inputs)
+ if prompt_args_item.output_type in ["latent", "latent_images"]:
+ save_latent(latent, prompt_args_item, height, width)
+
+ all_latents.append(latent)
+ except Exception as e:
+ logger.error(f"Error generating latent for prompt: {prompt_args_item.prompt}. Error: {e}", exc_info=True)
+ all_latents.append(None) # Add placeholder for failed generations
+ continue
+
+ # Free DiT model
+ logger.info("Releasing DiT model from memory...")
+
+ del shared_models_for_generate["model"]
+ del anima
+ clean_memory_on_device(device)
+ synchronize_device(device) # Ensure memory is freed before loading VAE for decoding
+
+ # 4. Decode latents and save outputs (using vae_for_batch)
+ if args.output_type != "latent":
+ logger.info("Decoding latents to videos/images using batched VAE...")
+ vae_for_batch.to(device) # Move VAE to device for decoding
+
+ for i, latent in enumerate(all_latents):
+ if latent is None: # Skip failed generations
+ logger.warning(f"Skipping decoding for prompt {i+1} due to previous error.")
+ continue
+
+ current_args = all_prompt_args_list[i]
+ logger.info(f"Decoding output {i+1}/{len(all_latents)} for prompt: {current_args.prompt}")
+
+ # if args.output_type is "latent_images", we already saved latent above.
+ # so we skip saving latent here.
+ if current_args.output_type == "latent_images":
+ current_args.output_type = "images"
+
+ # save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1).
+ save_output(current_args, vae_for_batch, latent, device) # Pass vae_for_batch
+
+ vae_for_batch.to("cpu") # Move VAE back to CPU
+
+ del vae_for_batch
+ clean_memory_on_device(device)
+
+
+def process_interactive(args: argparse.Namespace) -> None:
+ """Process prompts in interactive mode
+
+ Args:
+ args: Base command line arguments
+ """
+ gen_settings = get_generation_settings(args)
+ device = gen_settings.device
+ shared_models = load_shared_models(args)
+ shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
+
+ vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
+ vae.to(torch.bfloat16)
+ vae.eval()
+
+ print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
+
+ try:
+ import prompt_toolkit
+ except ImportError:
+ logger.warning("prompt_toolkit not found. Using basic input instead.")
+ prompt_toolkit = None
+
+ if prompt_toolkit:
+ session = prompt_toolkit.PromptSession()
+
+ def input_line(prompt: str) -> str:
+ return session.prompt(prompt)
+
+ else:
+
+ def input_line(prompt: str) -> str:
+ return input(prompt)
+
+ try:
+ while True:
+ try:
+ line = input_line("> ")
+ if not line.strip():
+ continue
+ if len(line.strip()) == 1 and line.strip() in ["\x04", "\x1a"]: # Ctrl+D or Ctrl+Z with prompt_toolkit
+ raise EOFError # Exit on Ctrl+D or Ctrl+Z
+
+ # Parse prompt
+ prompt_data = parse_prompt_line(line)
+ prompt_args = apply_overrides(args, prompt_data)
+
+ # Generate latent
+ # For interactive, precomputed data is None. shared_models contains text encoders.
+ latent = generate(prompt_args, gen_settings, shared_models)
+
+ # # If not one_frame_inference, move DiT model to CPU after generation
+ # model = shared_models.get("model")
+ # model.to("cpu") # Move DiT model to CPU after generation
+
+                # Save latent and image
+                # the VAE loaded above is used for decoding here
+ save_output(prompt_args, vae, latent, device)
+
+ except KeyboardInterrupt:
+ print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
+ continue
+
+ except EOFError:
+ print("\nExiting interactive mode")
+
+
+def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
+ device = torch.device(args.device)
+
+ dit_weight_dtype = torch.bfloat16 # default
+ if args.fp8_scaled:
+ dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
+ elif args.fp8:
+ dit_weight_dtype = torch.float8_e4m3fn
+
+ logger.info(f"Using device: {device}, DiT weight weight precision: {dit_weight_dtype}")
+
+ gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
+ return gen_settings
+
+
+def main():
+ # Parse arguments
+ args = parse_args()
+
+ # Check if latents are provided
+ latents_mode = args.latent_path is not None and len(args.latent_path) > 0
+
+ # Set device
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ logger.info(f"Using device: {device}")
+ args.device = device
+
+ if latents_mode:
+ # Original latent decode mode
+ original_base_names = []
+ latents_list = []
+ seeds = []
+
+ # assert len(args.latent_path) == 1, "Only one latent path is supported for now"
+
+ for latent_path in args.latent_path:
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
+ seed = 0
+
+ if os.path.splitext(latent_path)[1] != ".safetensors":
+ latents = torch.load(latent_path, map_location="cpu")
+ else:
+ latents = load_file(latent_path)["latent"]
+ with safe_open(latent_path, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata is None:
+ metadata = {}
+ logger.info(f"Loaded metadata: {metadata}")
+
+ if "seeds" in metadata:
+ seed = int(metadata["seeds"])
+ if "height" in metadata and "width" in metadata:
+ height = int(metadata["height"])
+ width = int(metadata["width"])
+ args.image_size = [height, width]
+
+ seeds.append(seed)
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
+
+ if latents.ndim == 5: # [BCTHW]
+ latents = latents.squeeze(0) # [CTHW]
+
+ latents_list.append(latents)
+
+ # latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape
+
+ for i, latent in enumerate(latents_list):
+ args.seed = seeds[i]
+
+ vae = qwen_image_autoencoder_kl.load_vae(args.vae, device=device, disable_mmap=True)
+ vae.to(torch.bfloat16)
+ vae.eval()
+ save_output(args, vae, latent, device, original_base_names[i])
+
+ else:
+ tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
+ qwen3_path=args.text_encoder, t5_tokenizer_path=None, qwen3_max_length=512, t5_max_length=512
+ )
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+ encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
+ strategy_base.TextEncodingStrategy.set_strategy(encoding_strategy)
+
+ if args.from_file:
+ # Batch mode from file
+
+ # Read prompts from file
+ with open(args.from_file, "r", encoding="utf-8") as f:
+ prompt_lines = f.readlines()
+
+ # Process prompts
+ prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
+ process_batch_prompts(prompts_data, args)
+
+ elif args.interactive:
+ # Interactive mode
+ process_interactive(args)
+
+ else:
+ # Single prompt mode (original behavior)
+
+ # Generate latent
+ gen_settings = get_generation_settings(args)
+
+ # For single mode, precomputed data is None, shared_models is None.
+ # generate will load all necessary models (Text Encoders, DiT).
+ latent = generate(args, gen_settings)
+ # print(f"Generated latent shape: {latent.shape}")
+ # if args.save_merged_model:
+ # return
+
+ clean_memory_on_device(device)
+
+        # Save latent and image
+ vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
+ vae.to(torch.bfloat16)
+ vae.eval()
+ save_output(args, vae, latent, device)
+
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/anima_train.py b/anima_train.py
index a86c30c3..13c15f0c 100644
--- a/anima_train.py
+++ b/anima_train.py
@@ -3,6 +3,7 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
+import gc
import math
import os
from multiprocessing import Value
@@ -12,8 +13,9 @@ import toml
from tqdm import tqdm
import torch
-from library import utils
+from library import flux_train_utils, qwen_image_autoencoder_kl, utils
from library.device_utils import init_ipex, clean_memory_on_device
+from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
init_ipex()
@@ -49,21 +51,18 @@ def train(args):
args.skip_cache_check = args.skip_latents_validity_check
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
- logger.warning(
- "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
- )
+ logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
- if getattr(args, 'unsloth_offload_checkpointing', False):
+ if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
- assert not args.cpu_offload_checkpointing, \
- "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
+ assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
@@ -71,17 +70,17 @@ def train(args):
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
- ) or not getattr(args, 'unsloth_offload_checkpointing', False), \
- "blocks_to_swap is not supported with unsloth_offload_checkpointing"
+ ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
- # Flash attention: validate availability
- if getattr(args, 'flash_attn', False):
- try:
- import flash_attn # noqa: F401
- logger.info("Flash Attention enabled for DiT blocks")
- except ImportError:
- logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
- args.flash_attn = False
+ # # Flash attention: validate availability
+ # if args.flash_attn:
+ # try:
+ # import flash_attn # noqa: F401
+
+ # logger.info("Flash Attention enabled for DiT blocks")
+ # except ImportError:
+ # logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
+ # args.flash_attn = False
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -104,9 +103,7 @@ def train(args):
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
- logger.warning(
- "ignore following options because config file is found: {0}".format(", ".join(ignored))
- )
+ logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
@@ -145,26 +142,13 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
- train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
-
- # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
- # dataset-level caption dropout, so we save the rate and zero out subset-level
- # caption_dropout_rate to allow text encoder output caching.
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- if caption_dropout_rate > 0:
- logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
- for dataset in train_dataset_group.datasets:
- for subset in dataset.subsets:
- subset.caption_dropout_rate = 0.0
+    train_dataset_group.verify_bucket_reso_steps(16)  # Qwen-Image VAE spatial downscale (8) * patch size (2) = 16
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
- args.cache_text_encoder_outputs_to_disk,
- args.text_encoder_batch_size,
- False,
- False,
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
train_dataset_group.set_current_strategies()
@@ -175,13 +159,11 @@ def train(args):
return
if cache_latents:
- assert (
- train_dataset_group.is_latent_cacheable()
- ), "when caching latents, either color_aug or random_crop cannot be used"
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs:
- assert (
- train_dataset_group.is_text_encoder_output_cacheable()
+ assert train_dataset_group.is_text_encoder_output_cacheable(
+ cache_supports_dropout=True
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
# prepare accelerator
@@ -191,24 +173,10 @@ def train(args):
# mixed precision dtype
weight_dtype, save_dtype = train_util.prepare_dtype(args)
- # parse transformer_dtype
- transformer_dtype = None
- if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
- transformer_dtype_map = {
- "float16": torch.float16,
- "bfloat16": torch.bfloat16,
- "float32": torch.float32,
- }
- transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
-
# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
- qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
- args.qwen3_path, dtype=weight_dtype, device="cpu"
- )
- t5_tokenizer = anima_utils.load_t5_tokenizer(
- getattr(args, 't5_tokenizer_path', None)
- )
+ qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
+ t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@@ -219,11 +187,7 @@ def train(args):
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
- # Set text encoding strategy
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
- dropout_rate=caption_dropout_rate,
- )
+ text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# Prepare text encoder (always frozen for Anima)
@@ -237,10 +201,7 @@ def train(args):
qwen3_text_encoder.eval()
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
- args.cache_text_encoder_outputs_to_disk,
- args.text_encoder_batch_size,
- args.skip_cache_check,
- is_partial=False,
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@@ -259,27 +220,19 @@ def train(args):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- [qwen3_text_encoder],
- tokens_and_masks,
- enable_dropout=False,
+ tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
)
- # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- if caption_dropout_rate > 0.0:
- with accelerator.autocast():
- text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
-
accelerator.wait_for_everyone()
# free text encoder memory
qwen3_text_encoder = None
+ gc.collect() # Force garbage collection to free memory
clean_memory_on_device(accelerator.device)
# Load VAE and cache latents
logger.info("Loading Anima VAE...")
- vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
+ vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu")
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
@@ -294,24 +247,16 @@ def train(args):
# Load DiT (MiniTrainDIT + optional LLM Adapter)
logger.info("Loading Anima DiT...")
- dit = anima_utils.load_anima_dit(
- args.dit_path,
- dtype=weight_dtype,
- device="cpu",
- transformer_dtype=transformer_dtype,
- llm_adapter_path=getattr(args, 'llm_adapter_path', None),
- disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
+ dit = anima_utils.load_anima_model(
+ "cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
)
if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
- unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
+ unsloth_offload=args.unsloth_offload_checkpointing,
)
- if getattr(args, 'flash_attn', False):
- dit.set_flash_attn(True)
-
train_dit = args.learning_rate != 0
dit.requires_grad_(train_dit)
if not train_dit:
@@ -327,19 +272,17 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
- # Move scale tensors to same device as VAE for on-the-fly encoding
- vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
# Setup optimizer with parameter groups
if train_dit:
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
- self_attn_lr=getattr(args, 'self_attn_lr', None),
- cross_attn_lr=getattr(args, 'cross_attn_lr', None),
- mlp_lr=getattr(args, 'mlp_lr', None),
- mod_lr=getattr(args, 'mod_lr', None),
- llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
+ self_attn_lr=args.self_attn_lr,
+ cross_attn_lr=args.cross_attn_lr,
+ mlp_lr=args.mlp_lr,
+ mod_lr=args.mod_lr,
+ llm_adapter_lr=args.llm_adapter_lr,
)
else:
param_groups = []
@@ -361,57 +304,7 @@ def train(args):
# prepare optimizer
accelerator.print("prepare optimizer, data loader etc.")
- if args.blockwise_fused_optimizers:
- # Split params into per-block groups for blockwise fused optimizer
- # Build param_id → lr mapping from param_groups to propagate per-component LRs
- param_lr_map = {}
- for group in param_groups:
- for p in group['params']:
- param_lr_map[id(p)] = group['lr']
-
- grouped_params = []
- param_group = {}
- named_parameters = list(dit.named_parameters())
- for name, p in named_parameters:
- if not p.requires_grad:
- continue
- # Determine block type and index
- if name.startswith("blocks."):
- block_index = int(name.split(".")[1])
- block_type = "blocks"
- elif name.startswith("llm_adapter.blocks."):
- block_index = int(name.split(".")[2])
- block_type = "llm_adapter"
- else:
- block_index = -1
- block_type = "other"
-
- param_group_key = (block_type, block_index)
- if param_group_key not in param_group:
- param_group[param_group_key] = []
- param_group[param_group_key].append(p)
-
- for param_group_key, params in param_group.items():
- # Use per-component LR from param_groups if available
- lr = param_lr_map.get(id(params[0]), args.learning_rate)
- grouped_params.append({"params": params, "lr": lr})
- num_params = sum(p.numel() for p in params)
- accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")
-
- # Create per-group optimizers
- optimizers = []
- for group in grouped_params:
- _, _, opt = train_util.get_optimizer(args, trainable_params=[group])
- optimizers.append(opt)
- optimizer = optimizers[0] # avoid error in following code
-
- logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
-
- if train_util.is_schedulefree_optimizer(optimizers[0], args):
- raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
- optimizer_train_fn = lambda: None
- optimizer_eval_fn = lambda: None
- elif args.fused_backward_pass:
+ if args.fused_backward_pass:
# Pass per-component param_groups directly to preserve per-component LRs
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
@@ -442,21 +335,19 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr scheduler
- if args.blockwise_fused_optimizers:
- lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
- lr_scheduler = lr_schedulers[0] # avoid error in following code
- else:
- lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# full fp16/bf16 training
+ dit_weight_dtype = weight_dtype
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
accelerator.print("enable full fp16 training.")
- dit.to(weight_dtype)
elif args.full_bf16:
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
accelerator.print("enable full bf16 training.")
- dit.to(weight_dtype)
+ else:
+ dit_weight_dtype = torch.float32 # Default to float32
+ dit.to(dit_weight_dtype) # convert dit to target weight dtype
# move text encoder to GPU if not cached
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
@@ -498,6 +389,7 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
+ # use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
@@ -517,55 +409,28 @@ def train(args):
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
- elif args.blockwise_fused_optimizers:
- # Prepare additional optimizers and lr schedulers
- for i in range(1, len(optimizers)):
- optimizers[i] = accelerator.prepare(optimizers[i])
- lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
-
- # Counters for blockwise gradient hook
- optimizer_hooked_count = {}
- num_parameters_per_group = [0] * len(optimizers)
- parameter_optimizer_map = {}
-
- for opt_idx, opt in enumerate(optimizers):
- for param_group in opt.param_groups:
- for parameter in param_group["params"]:
- if parameter.requires_grad:
-
- def grad_hook(parameter: torch.Tensor):
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
- accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
-
- i = parameter_optimizer_map[parameter]
- optimizer_hooked_count[i] += 1
- if optimizer_hooked_count[i] == num_parameters_per_group[i]:
- optimizers[i].step()
- optimizers[i].zero_grad(set_to_none=True)
-
- parameter.register_post_accumulate_grad_hook(grad_hook)
- parameter_optimizer_map[parameter] = opt_idx
- num_parameters_per_group[opt_idx] += 1
-
# Training loop
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
- accelerator.print("running training")
- accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
- accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
- accelerator.print(f" num epochs: {num_train_epochs}")
+ accelerator.print("running training / 学習開始")
+ accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
- f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
- accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
- accelerator.print(f" total optimization steps: {args.max_train_steps}")
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
+ noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
@@ -580,6 +445,7 @@ def train(args):
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
+
wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")
@@ -589,8 +455,15 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
anima_train_utils.sample_images(
- accelerator, args, 0, global_step, dit, vae, vae_scale,
- qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
+ accelerator,
+ args,
+ 0,
+ global_step,
+ dit,
+ vae,
+ qwen3_text_encoder,
+ tokenize_strategy,
+ text_encoding_strategy,
sample_prompts_te_outputs,
)
optimizer_train_fn()
@@ -600,11 +473,11 @@ def train(args):
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
- logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
+ logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
if qwen3_text_encoder is not None:
- logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
+ logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
if vae is not None:
- logger.info(f"vae device: {next(vae.parameters()).device}")
+ logger.info(f"vae device: {vae.device}")
loss_recorder = train_util.LossRecorder()
epoch = 0
@@ -618,19 +491,17 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
- if args.blockwise_fused_optimizers:
- optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
-
with accelerator.accumulate(*training_models):
# Get latents
if "latents" in batch and batch["latents"] is not None:
- latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
+ latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
+ if latents.ndim == 5: # Fallback for 5D latents (old cache)
+ latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
else:
with torch.no_grad():
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
- images = images.unsqueeze(2) # (B, C, 1, H, W)
- latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
+ latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
@@ -640,23 +511,24 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
+ caption_dropout_rates = text_encoder_outputs_list[-1]
+ text_encoder_outputs_list = text_encoder_outputs_list[:-1]
+
+ # Apply caption dropout to cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
- *text_encoder_outputs_list
+ *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
input_ids_list = batch["input_ids_list"]
- qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- [qwen3_text_encoder],
- [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
+ tokenize_strategy, [qwen3_text_encoder], input_ids_list
)
# Move to device
- prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
+ prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
@@ -664,9 +536,11 @@ def train(args):
# Noise and timesteps
noise = torch.randn_like(latents)
- noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
- args, latents, noise, accelerator.device, weight_dtype
+ # Get noisy model input and timesteps
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+ args, noise_scheduler, latents, noise, accelerator.device, dit_weight_dtype
)
+ timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
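As context for the call above, a minimal sketch of the flow-matching interpolation assumed here (a sigmoid-style sampling variant with a discrete flow shift; the real `flux_train_utils.get_noisy_model_input_and_timesteps` supports several `--timestep_sampling` modes and returns timesteps in the 0-1000 range, hence the division by 1000):

```python
import torch

def noisy_input_sketch(latents, noise, shift: float = 3.0):
    # Hypothetical sketch of one sampling variant, not the exact helper implementation.
    bsz = latents.shape[0]
    t = torch.sigmoid(torch.randn(bsz, device=latents.device))        # t in (0, 1)
    t = (t * shift) / (1.0 + (shift - 1.0) * t)                       # discrete flow shift toward noise
    sigmas = t.view(bsz, 1, 1, 1)
    noisy = (1.0 - sigmas) * latents + sigmas * noise                 # linear interpolation
    timesteps = t * 1000.0                                            # helper-style 0..1000 scale
    return noisy, timesteps / 1000.0, sigmas                          # Anima DiT expects 0..1
```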
# NaN checks
if torch.any(torch.isnan(noisy_model_input)):
@@ -678,15 +552,10 @@ def train(args):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
- padding_mask = torch.zeros(
- bs, 1, h_latent, w_latent,
- dtype=weight_dtype, device=accelerator.device
- )
+ padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
- if is_swapping_blocks:
- accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
-
+ noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
with accelerator.autocast():
model_pred = dit(
noisy_model_input,
@@ -697,6 +566,7 @@ def train(args):
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
+ model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
# Compute loss (rectified flow: target = noise - latents)
target = noise - latents
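The target follows from the linear interpolation used above: with x_t = (1 - sigma) * latents + sigma * noise, the velocity d x_t / d sigma is constant and equals noise - latents, so the DiT is trained to predict that velocity. A quick numeric check of this identity:

```python
import torch

latents, noise = torch.randn(2, 16, 8, 8), torch.randn(2, 16, 8, 8)
s1, s2 = 0.3, 0.7
x_t1 = (1 - s1) * latents + s1 * noise
x_t2 = (1 - s2) * latents + s2 * noise
velocity = (x_t2 - x_t1) / (s2 - s1)                  # finite difference along sigma
assert torch.allclose(velocity, noise - latents, atol=1e-4)
```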
@@ -708,12 +578,10 @@ def train(args):
# Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
- loss = train_util.conditional_loss(
- model_pred.float(), target.float(), args.loss_type, "none", huber_c
- )
+ loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
- loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
+ loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
if weighting is not None:
loss = loss * weighting
@@ -724,7 +592,7 @@ def train(args):
accelerator.backward(loss)
- if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
+ if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
@@ -737,9 +605,6 @@ def train(args):
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
- if args.blockwise_fused_optimizers:
- for i in range(1, len(optimizers)):
- lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step
if accelerator.sync_gradients:
@@ -748,8 +613,15 @@ def train(args):
optimizer_eval_fn()
anima_train_utils.sample_images(
- accelerator, args, None, global_step, dit, vae, vae_scale,
- qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
+ accelerator,
+ args,
+ None,
+ global_step,
+ dit,
+ vae,
+ qwen3_text_encoder,
+ tokenize_strategy,
+ text_encoding_strategy,
sample_prompts_te_outputs,
)
@@ -773,8 +645,10 @@ def train(args):
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names(
- logs, lr_scheduler, args.optimizer_type,
- ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
+ logs,
+ lr_scheduler,
+ args.optimizer_type,
+ ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
)
accelerator.log(logs, step=global_step)
@@ -807,8 +681,15 @@ def train(args):
)
anima_train_utils.sample_images(
- accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
- qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
+ accelerator,
+ args,
+ epoch + 1,
+ global_step,
+ dit,
+ vae,
+ qwen3_text_encoder,
+ tokenize_strategy,
+ text_encoding_strategy,
sample_prompts_te_outputs,
)
@@ -852,11 +733,6 @@ def setup_parser() -> argparse.ArgumentParser:
anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
- parser.add_argument(
- "--blockwise_fused_optimizers",
- action="store_true",
- help="enable blockwise optimizers for fused backward pass and optimizer step",
- )
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
@@ -884,4 +760,7 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch" # backward compatibility
+
train(args)
diff --git a/anima_train_network.py b/anima_train_network.py
index 57ad1681..dd5f85e6 100644
--- a/anima_train_network.py
+++ b/anima_train_network.py
@@ -1,16 +1,26 @@
# Anima LoRA training script
import argparse
-import math
from typing import Any, Optional, Union
import torch
+import torch.nn as nn
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
-from library import anima_models, anima_train_utils, anima_utils, strategy_anima, strategy_base, train_util
+from library import (
+ anima_models,
+ anima_train_utils,
+ anima_utils,
+ flux_train_utils,
+ qwen_image_autoencoder_kl,
+ sd3_train_utils,
+ strategy_anima,
+ strategy_base,
+ train_util,
+)
import train_network
from library.utils import setup_logging
@@ -24,13 +34,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
- self.vae = None
- self.vae_scale = None
- self.qwen3_text_encoder = None
- self.qwen3_tokenizer = None
- self.t5_tokenizer = None
- self.tokenize_strategy = None
- self.text_encoding_strategy = None
def assert_extra_args(
self,
@@ -38,137 +41,110 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
+ if args.fp8_base or args.fp8_base_unet:
+ logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
+ args.fp8_base = False
+ args.fp8_base_unet = False
+ args.fp8_scaled = False # Anima DiT does not support fp8_scaled
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
- logger.warning(
- "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
- )
+ logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
- # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
- # dataset-level caption dropout, so zero out subset-level rates to allow caching.
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- if caption_dropout_rate > 0:
- logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
- if hasattr(train_dataset_group, 'datasets'):
- for dataset in train_dataset_group.datasets:
- for subset in dataset.subsets:
- subset.caption_dropout_rate = 0.0
-
if args.cache_text_encoder_outputs:
- assert (
- train_dataset_group.is_text_encoder_output_cacheable()
+ assert train_dataset_group.is_text_encoder_output_cacheable(
+ cache_supports_dropout=True
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
- if getattr(args, 'unsloth_offload_checkpointing', False):
+ if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
- assert not args.cpu_offload_checkpointing, \
- "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
+ assert (
+ not args.cpu_offload_checkpointing
+ ), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
- # Flash attention: validate availability
- if getattr(args, 'flash_attn', False):
- try:
- import flash_attn # noqa: F401
- logger.info("Flash Attention enabled for DiT blocks")
- except ImportError:
- logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
- args.flash_attn = False
-
- if getattr(args, 'blockwise_fused_optimizers', False):
- raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
-
- train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
+ train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
if val_dataset_group is not None:
- val_dataset_group.verify_bucket_reso_steps(8)
+ val_dataset_group.verify_bucket_reso_steps(16)
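The step of 16 above simply combines the two factors named in the comment: the VAE's 8x spatial downscale and the DiT patch size of 2, so every bucket resolution must be divisible by 16. A quick check, using hypothetical example resolutions:

```python
vae_downscale, patch_size = 8, 2              # values stated in the comment above
step = vae_downscale * patch_size             # 16
for height, width in [(1024, 1024), (1152, 896)]:
    assert height % step == 0 and width % step == 0
    latent_h, latent_w = height // vae_downscale, width // vae_downscale
    num_patches = (latent_h // patch_size) * (latent_w // patch_size)
    print(height, width, "->", num_patches, "DiT patches")
```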
def load_target_model(self, args, weight_dtype, accelerator):
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
logger.info("Loading Qwen3 text encoder...")
- self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(
- args.qwen3_path, dtype=weight_dtype, device="cpu"
- )
- self.qwen3_text_encoder.eval()
+ qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
+ qwen3_text_encoder.eval()
- # Parse transformer_dtype
- transformer_dtype = None
- if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
- transformer_dtype_map = {
- "float16": torch.float16,
- "bfloat16": torch.bfloat16,
- "float32": torch.float32,
- }
- transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
+ # Load VAE
+ logger.info("Loading Anima VAE...")
+ vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True)
+ vae.to(weight_dtype)
+ vae.eval()
+
+ # Return format: (model_type, text_encoders, vae, unet)
+ return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
+
+ def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
+ loading_dtype = None if args.fp8_scaled else weight_dtype
+ loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
+
+ attn_mode = "torch"
+ if args.xformers:
+ attn_mode = "xformers"
+ if args.attn_mode is not None:
+ attn_mode = args.attn_mode
# Load DiT
- logger.info("Loading Anima DiT...")
- dit = anima_utils.load_anima_dit(
- args.dit_path,
- dtype=weight_dtype,
- device="cpu",
- transformer_dtype=transformer_dtype,
- llm_adapter_path=getattr(args, 'llm_adapter_path', None),
- disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
+ logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
+ model = anima_utils.load_anima_model(
+ accelerator.device,
+ args.pretrained_model_name_or_path,
+ attn_mode,
+ args.split_attn,
+ loading_device,
+ loading_dtype,
+ args.fp8_scaled,
)
- # Flash attention
- if getattr(args, 'flash_attn', False):
- dit.set_flash_attn(True)
-
# Store unsloth preference so that when the base NetworkTrainer calls
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
# The base trainer only passes cpu_offload, so we store the flag on the model.
- self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False)
+ self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
# Block swap
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
- dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
- # Load VAE
- logger.info("Loading Anima VAE...")
- self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
- args.vae_path, dtype=weight_dtype, device="cpu"
- )
-
- # Return format: (model_type, text_encoders, vae, unet)
- return "anima", [self.qwen3_text_encoder], self.vae, dit
+ return model, text_encoders
def get_tokenize_strategy(self, args):
-        # Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
+        # Load tokenizers from paths (called before load_target_model)
- self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
- qwen3_path=args.qwen3_path,
- t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None),
+ tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
+ qwen3_path=args.qwen3,
+ t5_tokenizer_path=args.t5_tokenizer_path,
qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length,
)
- # Store references so load_target_model can reuse them
- self.qwen3_tokenizer = self.tokenize_strategy.qwen3_tokenizer
- self.t5_tokenizer = self.tokenize_strategy.t5_tokenizer
- return self.tokenize_strategy
+ return tokenize_strategy
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
return [tokenize_strategy.qwen3_tokenizer]
def get_latents_caching_strategy(self, args):
- return strategy_anima.AnimaLatentsCachingStrategy(
- args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
- )
+ return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
def get_text_encoding_strategy(self, args):
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
- dropout_rate=caption_dropout_rate,
- )
- return self.text_encoding_strategy
+ return strategy_anima.AnimaTextEncodingStrategy()
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# Qwen3 text encoder is always frozen for Anima
@@ -188,10 +164,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
- args.cache_text_encoder_outputs_to_disk,
- args.text_encoder_batch_size,
- args.skip_cache_check,
- is_partial=False,
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
return None
@@ -200,15 +173,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
):
if args.cache_text_encoder_outputs:
if not args.lowram:
- logger.info("move vae and unet to cpu to save memory")
- org_vae_device = next(vae.parameters()).device
- org_unet_device = unet.device
+ # We cannot move DiT to CPU because of block swap, so only move VAE
+ logger.info("move vae to cpu to save memory")
+ org_vae_device = vae.device
vae.to("cpu")
- unet.to("cpu")
clean_memory_on_device(accelerator.device)
logger.info("move text encoder to gpu")
- text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+ text_encoders[0].to(accelerator.device)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@@ -229,59 +201,52 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- text_encoders,
- tokens_and_masks,
- enable_dropout=False,
+ tokenize_strategy, text_encoders, tokens_and_masks
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
- # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
- if caption_dropout_rate > 0.0:
- tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
- with accelerator.autocast():
- text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
-
accelerator.wait_for_everyone()
# move text encoder back to cpu
logger.info("move text encoder back to cpu")
text_encoders[0].to("cpu")
- clean_memory_on_device(accelerator.device)
if not args.lowram:
- logger.info("move vae and unet back to original device")
+ logger.info("move vae back to original device")
vae.to(org_vae_device)
- unet.to(org_unet_device)
+
+ clean_memory_on_device(accelerator.device)
else:
- text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+ # move text encoder to device for encoding during training/validation
+ text_encoders[0].to(accelerator.device)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
qwen3_te = te[0] if te is not None else None
+ text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
anima_train_utils.sample_images(
- accelerator, args, epoch, global_step, unet, vae, self.vae_scale,
- qwen3_te, self.tokenize_strategy, self.text_encoding_strategy,
+ accelerator,
+ args,
+ epoch,
+ global_step,
+ unet,
+ vae,
+ qwen3_te,
+ tokenize_strategy,
+ text_encoding_strategy,
self.sample_prompts_te_outputs,
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
- noise_scheduler = anima_train_utils.FlowMatchEulerDiscreteScheduler(
- num_train_timesteps=1000, shift=args.discrete_flow_shift
- )
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
- # images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim
- images = images.unsqueeze(2) # (B, C, 1, H, W)
- # Ensure scale tensors are on the same device as images
- vae_device = images.device
- scale = [s.to(vae_device) if isinstance(s, torch.Tensor) else s for s in self.vae_scale]
- return vae.encode(images, scale)
+ vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
+ return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
def shift_scale_latents(self, args, latents):
-        # Latents already normalized by vae.encode with scale
+        # Latents are already normalized during VAE encoding
@@ -301,13 +266,18 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_unet,
is_train=True,
):
+ anima: anima_models.Anima = unet
+
# Sample noise
+ if latents.ndim == 5: # Fallback for 5D latents (old cache)
+ latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
noise = torch.randn_like(latents)
# Get noisy model input and timesteps
- noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
- args, latents, noise, accelerator.device, weight_dtype
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
+ timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# Gradient checkpointing support
if args.gradient_checkpointing:
@@ -329,147 +299,81 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
- padding_mask = torch.zeros(
- bs, 1, h_latent, w_latent,
- dtype=weight_dtype, device=accelerator.device
- )
+ padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
- # Prepare block swap
- if self.is_swapping_blocks:
- accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
-
- # Call model (LLM adapter runs inside forward for DDP gradient sync)
+ # Call model
+ noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
with torch.set_grad_enabled(is_train), accelerator.autocast():
- model_pred = unet(
+ model_pred = anima(
noisy_model_input,
timesteps,
prompt_embeds,
padding_mask=padding_mask,
+ target_input_ids=t5_input_ids,
+ target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
- t5_input_ids=t5_input_ids,
- t5_attn_mask=t5_attn_mask,
)
+ model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
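A small standalone illustration of the 4D-to-5D round-trip around the DiT call above: the Anima DiT keeps a temporal axis, so single-image latents are expanded to a length-1 frame dimension only for the forward pass and squeezed back for the 2D loss. The shapes below are just an example.

```python
import torch

latents = torch.randn(2, 16, 64, 64)      # (B, C, H, W) image latents
x5d = latents.unsqueeze(2)                 # (B, C, 1, H, W) as the DiT expects
model_pred = x5d                           # stand-in for the DiT output (same shape)
model_pred = model_pred.squeeze(2)         # back to (B, C, H, W) for the rectified-flow loss
assert model_pred.shape == latents.shape
```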
# Rectified flow target: noise - latents
target = noise - latents
# Loss weighting
- weighting = anima_train_utils.compute_loss_weighting_for_anima(
- weighting_scheme=args.weighting_scheme, sigmas=sigmas
- )
-
- # Differential output preservation
- if "custom_attributes" in batch:
- diff_output_pr_indices = []
- for i, custom_attributes in enumerate(batch["custom_attributes"]):
- if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
- diff_output_pr_indices.append(i)
-
- if len(diff_output_pr_indices) > 0:
- network.set_multiplier(0.0)
- with torch.no_grad(), accelerator.autocast():
- if self.is_swapping_blocks:
- accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
- model_pred_prior = unet(
- noisy_model_input[diff_output_pr_indices],
- timesteps[diff_output_pr_indices],
- prompt_embeds[diff_output_pr_indices],
- padding_mask=padding_mask[diff_output_pr_indices],
- source_attention_mask=attn_mask[diff_output_pr_indices],
- t5_input_ids=t5_input_ids[diff_output_pr_indices],
- t5_attn_mask=t5_attn_mask[diff_output_pr_indices],
- )
- network.set_multiplier(1.0)
-
- target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+ weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, target, timesteps, weighting
def process_batch(
- self, batch, text_encoders, unet, network, vae, noise_scheduler,
- vae_dtype, weight_dtype, accelerator, args,
- text_encoding_strategy, tokenize_strategy,
- is_train=True, train_text_encoder=True, train_unet=True,
+ self,
+ batch,
+ text_encoders,
+ unet,
+ network,
+ vae,
+ noise_scheduler,
+ vae_dtype,
+ weight_dtype,
+ accelerator,
+ args,
+ text_encoding_strategy,
+ tokenize_strategy,
+ is_train=True,
+ train_text_encoder=True,
+ train_unet=True,
) -> torch.Tensor:
- """Override base process_batch for 5D video latents (B, C, T, H, W).
-
- Base class assumes 4D (B, C, H, W) for loss.mean([1,2,3]) and weighting broadcast.
- """
- import typing
- from library.custom_train_functions import apply_masked_loss
-
- with torch.no_grad():
- if "latents" in batch and batch["latents"] is not None:
- latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
- else:
- if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
- latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
- else:
- chunks = [
- batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
- ]
- list_latents = []
- for chunk in chunks:
- with torch.no_grad():
- chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
- list_latents.append(chunk)
- latents = torch.cat(list_latents, dim=0)
-
- if torch.any(torch.isnan(latents)):
- accelerator.print("NaN found in latents, replacing with zeros")
- latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
-
- latents = self.shift_scale_latents(args, latents)
+ """Override base process_batch for caption dropout with cached text encoder outputs."""
# Text encoder conditions
- text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+ anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None:
- text_encoder_conds = text_encoder_outputs_list
+ caption_dropout_rates = text_encoder_outputs_list[-1]
+ text_encoder_outputs_list = text_encoder_outputs_list[:-1]
- if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
- with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
- input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
- encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- self.get_models_for_text_encoding(args, accelerator, text_encoders),
- input_ids,
- )
- if args.full_fp16:
- encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
+ # Apply caption dropout to cached outputs
+ text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
+ *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
+ )
+ batch["text_encoder_outputs_list"] = text_encoder_outputs_list
- if len(text_encoder_conds) == 0:
- text_encoder_conds = encoded_text_encoder_conds
- else:
- for i in range(len(encoded_text_encoder_conds)):
- if encoded_text_encoder_conds[i] is not None:
- text_encoder_conds[i] = encoded_text_encoder_conds[i]
-
- noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
- args, accelerator, noise_scheduler, latents, batch,
- text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train,
+ return super().process_batch(
+ batch,
+ text_encoders,
+ unet,
+ network,
+ vae,
+ noise_scheduler,
+ vae_dtype,
+ weight_dtype,
+ accelerator,
+ args,
+ text_encoding_strategy,
+ tokenize_strategy,
+ is_train,
+ train_text_encoder,
+ train_unet,
)
- huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
- loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
-
- if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
- loss = apply_masked_loss(loss, batch)
-
- # Reduce all non-batch dims: (B, C, T, H, W) -> (B,) for 5D, (B, C, H, W) -> (B,) for 4D
- reduce_dims = list(range(1, loss.ndim))
- loss = loss.mean(reduce_dims)
-
- # Apply weighting after reducing to (B,)
- if weighting is not None:
- loss = loss * weighting
-
- loss_weights = batch["loss_weights"]
- loss = loss * loss_weights
-
- loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
- return loss.mean()
-
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
@@ -478,12 +382,15 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
+ metadata["ss_logit_mean"] = args.logit_mean
+ metadata["ss_logit_std"] = args.logit_std
+ metadata["ss_mode_scale"] = args.mode_scale
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
- metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
- metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0)
def is_text_encoder_not_needed_for_training(self, args):
- return args.cache_text_encoder_outputs
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@@ -496,23 +403,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
- dit = unet
- dit = accelerator.prepare(dit, device_placement=[not self.is_swapping_blocks])
- accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
- accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
+ model = unet
+ model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
+ accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
+ accelerator.unwrap_model(model).prepare_block_swap_before_forward()
- return dit
-
- def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
- # Drop cached text encoder outputs for caption dropout
- text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
- if text_encoder_outputs_list is not None:
- text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
- text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
- batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+ return model
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
+            # prepare for the next forward pass: backward is not called during validation, so we do it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
@@ -520,6 +420,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser)
+ # parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--unsloth_offload_checkpointing",
action="store_true",
@@ -536,5 +437,8 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch" # backward compatibility
+
trainer = AnimaNetworkTrainer()
trainer.train(args)
diff --git a/docs/anima_train_network.md b/docs/anima_train_network.md
index fe6b2354..c88dba9a 100644
--- a/docs/anima_train_network.md
+++ b/docs/anima_train_network.md
@@ -37,14 +37,14 @@ This guide assumes you already understand the basics of LoRA training. For commo
## 2. Differences from `train_network.py` / `train_network.py` との違い
-`anima_train_network.py` is based on `train_network.py` but modified for Anima . Main differences are:
+`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
* **Target models:** Anima DiT models.
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
-* **Arguments:** Options exist to specify the Anima DiT model, Qwen3 text encoder, WanVAE, LLM adapter, and T5 tokenizer separately.
-* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
-* **Anima specific options:** Additional parameters for component-wise learning rates (self_attn, cross_attn, mlp, mod, llm_adapter), timestep sampling, discrete flow shift, and flash attention.
-* **6 Parameter Groups:** Independent learning rates for `base`, `self_attn`, `cross_attn`, `mlp`, `adaln_modulation`, and `llm_adapter` components.
+* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the WanVAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
+* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
+* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
+* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
日本語
@@ -53,10 +53,10 @@ This guide assumes you already understand the basics of LoRA training. For commo
* **対象モデル:** Anima DiTモデルを対象とします。
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
-* **引数:** Anima DiTモデル、Qwen3テキストエンコーダー、WanVAE、LLM Adapter、T5トークナイザーを個別に指定する引数があります。
-* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はAnimaの学習では使用されません。
-* **Anima特有の引数:** コンポーネント別学習率(self_attn, cross_attn, mlp, mod, llm_adapter)、タイムステップサンプリング、離散フローシフト、Flash Attentionに関する引数が追加されています。
-* **6パラメータグループ:** `base`、`self_attn`、`cross_attn`、`mlp`、`adaln_modulation`、`llm_adapter`の各コンポーネントに対して独立した学習率を設定できます。
+* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、WanVAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
+* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
+* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。
+* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。
## 3. Preparation / 準備
@@ -74,7 +74,6 @@ The following files are required before starting training:
**Notes:**
* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
-* Models are saved with a `net.` prefix on all keys for ComfyUI compatibility.
日本語
@@ -92,7 +91,6 @@ The following files are required before starting training:
**注意:**
* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json`、`tokenizer.json`、`tokenizer_config.json`、`vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。
* T5トークナイザーはトークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
-* モデルはComfyUI互換のため、すべてのキーに`net.`プレフィックスを付けて保存されます。
## 4. Running the Training / 学習の実行
@@ -103,9 +101,9 @@ Example command:
```bash
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
- --dit_path="" \
- --qwen3_path="" \
- --vae_path="" \
+ --pretrained_model_name_or_path="" \
+ --qwen3="" \
+ --vae="" \
--llm_adapter_path="" \
--dataset_config="my_anima_dataset_config.toml" \
--output_dir="