diff --git a/_typos.toml b/_typos.toml index 686da4af..fc33b6b3 100644 --- a/_typos.toml +++ b/_typos.toml @@ -32,6 +32,7 @@ hime="hime" OT="OT" byt="byt" tak="tak" +temperal="temperal" [files] -extend-exclude = ["_typos.toml", "venv"] +extend-exclude = ["_typos.toml", "venv", "configs"] diff --git a/anima_minimal_inference.py b/anima_minimal_inference.py new file mode 100644 index 00000000..67af2876 --- /dev/null +++ b/anima_minimal_inference.py @@ -0,0 +1,1044 @@ +import argparse +import datetime +import gc +from importlib.util import find_spec +import random +import os +import re +import time +import copy +from types import ModuleType, SimpleNamespace +from typing import Tuple, Optional, List, Any, Dict, Union + +import numpy as np +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from diffusers.utils.torch_utils import randn_tensor +from PIL import Image + +from library import anima_models, anima_utils, hunyuan_image_utils, qwen_image_autoencoder_kl, strategy_anima, strategy_base +from library.anima_vae import WanVAE_ +from library.device_utils import clean_memory_on_device, synchronize_device +from library.safetensors_utils import mem_eff_save_file +from networks import lora_hunyuan_image + +lycoris_available = find_spec("lycoris") is not None +if lycoris_available: + from lycoris.kohya import create_network_from_weights + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class GenerationSettings: + def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None): + self.device = device + self.dit_weight_dtype = dit_weight_dtype # not used currently because model may be optimized + + +def parse_args() -> argparse.Namespace: + """parse command line arguments""" + parser = argparse.ArgumentParser(description="HunyuanImage inference script") + + parser.add_argument("--dit", type=str, 
default=None, help="DiT directory or path") + parser.add_argument("--vae", type=str, default=None, help="VAE directory or path") + parser.add_argument("--text_encoder", type=str, required=True, help="Text Encoder 1 (Qwen2.5-VL) directory or path") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") + + # inference + parser.add_argument( + "--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier free guidance. Default is 3.5." + ) + parser.add_argument("--prompt", type=str, default=None, help="prompt for generation") + parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string") + parser.add_argument("--image_size", type=int, nargs=2, default=[1024, 1024], help="image size, height and width") + parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps, default is 50") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + + # Flow Matching + parser.add_argument( + "--flow_shift", + type=float, + default=5.0, + help="Shift factor for flow matching schedulers. 
Default is 5.0.", + ) + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") + + parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", + type=str, + default="torch", + choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility + help="attention mode", + ) + parser.add_argument( + "--output_type", + type=str, + default="images", + choices=["images", "latent", "latent_images"], + help="output type", + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument( + "--lycoris", action="store_true", help=f"use lycoris for inference{'' if lycoris_available else ' (not available)'}" + ) + + # arguments for batch and interactive modes + parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file") + parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console") + + args = parser.parse_args() + + # Validate arguments + if args.from_file and args.interactive: + raise ValueError("Cannot use both --from_file and --interactive at the same time") + + if args.latent_path is None or len(args.latent_path) == 0: + if args.prompt is None and not args.from_file and not args.interactive: + raise ValueError("Either --prompt, --from_file or --interactive must be specified") + + if args.lycoris and not lycoris_available: + raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS") + + if args.attn_mode == "sdpa": + args.attn_mode = "torch" # backward compatibility + + return args + + +def parse_prompt_line(line: str) -> Dict[str, Any]: + """Parse a prompt line into a dictionary of argument overrides + + Args: + line: Prompt line with options + + Returns: + Dict[str, Any]: Dictionary of argument overrides + """ + parts = line.split(" --") + prompt = parts[0].strip() + + # Create dictionary of overrides + overrides = {"prompt": prompt} + + for part in parts[1:]: + if not part.strip(): + continue + option_parts = part.split(" ", 1) + option = option_parts[0].strip() + value = option_parts[1].strip() if len(option_parts) > 1 else "" + + # Map options to argument names + if option == "w": + overrides["image_size_width"] = int(value) + elif option == "h": + overrides["image_size_height"] = int(value) + elif option == "d": + overrides["seed"] = int(value) + elif option == "s": + overrides["infer_steps"] = int(value) + elif option == "g" or option == "l": + overrides["guidance_scale"] = float(value) + elif option == 
"fs": + overrides["flow_shift"] = float(value) + elif option == "n": + overrides["negative_prompt"] = value + + return overrides + + +def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace: + """Apply overrides to args + + Args: + args: Original arguments + overrides: Dictionary of overrides + + Returns: + argparse.Namespace: New arguments with overrides applied + """ + args_copy = copy.deepcopy(args) + + for key, value in overrides.items(): + if key == "image_size_width": + args_copy.image_size[1] = value + elif key == "image_size_height": + args_copy.image_size[0] = value + else: + setattr(args_copy, key, value) + + return args_copy + + +def check_inputs(args: argparse.Namespace) -> Tuple[int, int]: + """Validate video size and length + + Args: + args: command line arguments + + Returns: + Tuple[int, int]: (height, width) + """ + height = args.image_size[0] + width = args.image_size[1] + + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + return height, width + + +# region Model + + +def load_dit_model( + args: argparse.Namespace, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None +) -> anima_models.MiniTrainDIT: + """load DiT model + + Args: + args: command line arguments + device: device to use + dit_weight_dtype: data type for the model weights. 
None for as-is + + Returns: + anima_models.MiniTrainDIT: DiT model instance + """ + # If LyCORIS is enabled, we will load the model to CPU and then merge LoRA weights (static method) + + loading_device = "cpu" + if not args.lycoris: + loading_device = device + + # load LoRA weights + if not args.lycoris and args.lora_weight is not None and len(args.lora_weight) > 0: + lora_weights_list = [] + for lora_weight in args.lora_weight: + logger.info(f"Loading LoRA weight from: {lora_weight}") + lora_sd = load_file(lora_weight) # load on CPU, dtype is as is + # lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns) + lora_weights_list.append(lora_sd) + else: + lora_weights_list = None + + loading_weight_dtype = dit_weight_dtype + if args.fp8_scaled and not args.lycoris: + loading_weight_dtype = None # we will load weights as-is and then optimize to fp8 + + model = anima_utils.load_anima_model( + device, + args.dit, + args.attn_mode, + True, # enable split_attn to trim masked tokens + loading_device, + loading_weight_dtype, + args.fp8_scaled and not args.lycoris, + lora_weights_list=lora_weights_list, + lora_multipliers=args.lora_multiplier, + ) + # model = anima_utils.load_anima_dit( + # args.dit, + # dtype=loading_weight_dtype, + # device=loading_device, + # transformer_dtype=loading_weight_dtype, + # llm_adapter_path=None, # getattr(args, "llm_adapter_path", None), + # disable_mmap=False, # getattr(args, "disable_mmap_load_safetensors", False), + # ) + if not args.fp8_scaled: + # simple cast to dit_weight_dtype + target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict) + if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled + logger.info(f"Convert model to {dit_weight_dtype}") + target_dtype = dit_weight_dtype + + logger.info(f"Move model to device: {device}") + target_device = device + + model.to(target_device, target_dtype) # move and cast at the same time. 
this reduces redundant copy operations + + # model.to(device) + model.to(device, dtype=torch.bfloat16) # ensure model is in bfloat16 for inference + + model.eval().requires_grad_(False) + clean_memory_on_device(device) + + return model + + +# endregion + + +def decode_latent(vae: WanVAE_, latent: torch.Tensor, device: torch.device) -> torch.Tensor: + logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}") + + vae.to(device) + with torch.no_grad(): + pixels = vae.decode_to_pixels(latent.to(device, dtype=vae.dtype)) + # pixels = vae.decode(latent.to(device, dtype=torch.bfloat16), scale=vae_scale) + + pixels = pixels.to("cpu", dtype=torch.float32) # move to CPU and convert to float32 (bfloat16 is not supported by numpy) + vae.to("cpu") + + logger.info(f"Decoded. Pixel shape {pixels.shape}") + return pixels[0] # remove batch dimension + + +def prepare_text_inputs( + args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Prepare text-related inputs for T2I: LLM encoding. 
Anima model is also needed for preprocessing""" + + # load text encoder: conds_cache holds cached encodings for prompts without padding + conds_cache = {} + text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device + if shared_models is not None: + tokenizer = shared_models.get("tokenizer") + text_encoder = shared_models.get("text_encoder") + t5xxl_tokenizer = shared_models.get("t5xxl_tokenizer") + + if "conds_cache" in shared_models: # Use shared cache if available + conds_cache = shared_models["conds_cache"] + + # text_encoder is on device (batched inference) or CPU (interactive inference) + else: # Load if not in shared_models + text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder + # tokenizer, text_encoder = anima_text_encoder.load_qwen3( + # args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device, disable_mmap=True + # ) + # t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer() + text_encoder, _ = anima_utils.load_qwen3_text_encoder( + args.text_encoder, dtype=text_encoder_dtype, device=text_encoder_device + ) + text_encoder.eval() + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + if tokenize_strategy is None: + tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( + qwen3_path=args.text_encoder, + t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None), + qwen3_max_length=512, # args.qwen3_max_token_length, + t5_max_length=512, # args.t5_max_token_length, + ) + strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) + # Store references so load_target_model can reuse them + tokenizer = tokenize_strategy.qwen3_tokenizer + t5xxl_tokenizer = tokenize_strategy.t5_tokenizer + # Store original devices to move back later if they were shared. 
This does nothing if shared_models is None + text_encoder_original_device = text_encoder.device if text_encoder else None + + # Ensure text_encoder is not None before proceeding + if not text_encoder or not tokenizer or not t5xxl_tokenizer: + raise ValueError("Text encoder or tokenizer is not loaded properly.") + + # Define a function to move models to device if needed + # This is to avoid moving models if not needed, especially in interactive mode + model_is_moved = False + + def move_models_to_device_if_needed(): + nonlocal model_is_moved + nonlocal shared_models + + if model_is_moved: + return + model_is_moved = True + + logger.info(f"Moving Text Encoder to appropriate device: {text_encoder_device}") + text_encoder.to(text_encoder_device) # If text_encoder_cpu is True, this will be CPU + + logger.info("Encoding prompt with Text Encoder") + + prompt = args.prompt + cache_key = prompt + if cache_key in conds_cache: + embed = conds_cache[cache_key] + else: + move_models_to_device_if_needed() + + encoding_strategy = strategy_anima.AnimaTextEncodingStrategy() + tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy() + + with torch.no_grad(): + # embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt) + tokens = tokenize_strategy.tokenize(prompt) + embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens) + crossattn_emb = anima.preprocess_text_embeds( + source_hidden_states=embed[0].to(anima.device), + target_input_ids=embed[2].to(anima.device), + target_attention_mask=embed[3].to(anima.device), + source_attention_mask=embed[1].to(anima.device), + ) + crossattn_emb[~embed[3].bool()] = 0 + embed[0] = crossattn_emb + embed[0] = embed[0].cpu() + + conds_cache[cache_key] = embed + + negative_prompt = args.negative_prompt + cache_key = negative_prompt + if cache_key in conds_cache: + negative_embed = conds_cache[cache_key] + else: + move_models_to_device_if_needed() + + with torch.no_grad(): + 
# negative_embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, negative_prompt) + tokens = tokenize_strategy.tokenize(negative_prompt) + negative_embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens) + crossattn_emb = anima.preprocess_text_embeds( + source_hidden_states=negative_embed[0].to(anima.device), + target_input_ids=negative_embed[2].to(anima.device), + target_attention_mask=negative_embed[3].to(anima.device), + source_attention_mask=negative_embed[1].to(anima.device), + ) + crossattn_emb[~negative_embed[3].bool()] = 0 + negative_embed[0] = crossattn_emb + negative_embed[0] = negative_embed[0].cpu() + + conds_cache[cache_key] = negative_embed + + if not (shared_models and "text_encoder" in shared_models): # if loaded locally + # There is a bug text_encoder is not freed from GPU memory when text encoder is fp8 + del tokenizer, text_encoder, t5xxl_tokenizer + gc.collect() # This may force Text Encoder to be freed from GPU memory + else: # if shared, move back to original device (likely CPU) + if text_encoder: + text_encoder.to(text_encoder_original_device) + + clean_memory_on_device(device) + + arg_c = {"embed": embed, "prompt": prompt} + arg_null = {"embed": negative_embed, "prompt": negative_prompt} + + return arg_c, arg_null + + +def generate( + args: argparse.Namespace, + gen_settings: GenerationSettings, + shared_models: Optional[Dict] = None, + precomputed_text_data: Optional[Dict] = None, +) -> torch.Tensor: + """main function for generation + + Args: + args: command line arguments + shared_models: dictionary containing pre-loaded models (mainly for DiT) + precomputed_image_data: Optional dictionary with precomputed image data + precomputed_text_data: Optional dictionary with precomputed text data + + Returns: + tuple: (HunyuanVAE2D model (vae) or None, torch.Tensor generated latent) + """ + device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype) + + # prepare 
seed + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + args.seed = seed # set seed to args for saving + + if shared_models is None or "model" not in shared_models: + # load DiT model + anima = load_dit_model(args, device, dit_weight_dtype) + + if shared_models is not None: + shared_models["model"] = anima + else: + # use shared model + logger.info("Using shared DiT model.") + anima: anima_models.MiniTrainDIT = shared_models["model"] + + if precomputed_text_data is not None: + logger.info("Using precomputed text data.") + context = precomputed_text_data["context"] + context_null = precomputed_text_data["context_null"] + + else: + logger.info("No precomputed data. Preparing image and text inputs.") + context, context_null = prepare_text_inputs(args, device, anima, shared_models) + + return generate_body(args, anima, context, context_null, device, seed) + + +def generate_body( + args: Union[argparse.Namespace, SimpleNamespace], + anima: anima_models.MiniTrainDIT, + context: Dict[str, Any], + context_null: Optional[Dict[str, Any]], + device: torch.device, + seed: int, +) -> torch.Tensor: + + # set random generator + seed_g = torch.Generator(device="cpu") + seed_g.manual_seed(seed) + + height, width = check_inputs(args) + logger.info(f"Image size: {height}x{width} (HxW), infer_steps: {args.infer_steps}") + + # image generation ###### + + logger.info(f"Prompt: {context['prompt']}") + + embed = context["embed"][0].to(device, dtype=torch.bfloat16) + if context_null is None: + context_null = context # dummy for unconditional + negative_embed = context_null["embed"][0].to(device, dtype=torch.bfloat16) + + # Prepare latent variables + num_channels_latents = 16 # anima_models.MiniTrainDIT.LATENT_CHANNELS + shape = ( + 1, + num_channels_latents, + 1, # Frame dimension + height // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR, + width // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR, + ) + latents = randn_tensor(shape, generator=seed_g, device=device, 
dtype=torch.bfloat16) + + # Create padding mask + bs = latents.shape[0] + h_latent = latents.shape[-2] + w_latent = latents.shape[-1] + padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=torch.bfloat16, device=device) + + logger.info(f"Embed: {embed.shape}, negative_embed: {negative_embed.shape}, latents: {latents.shape}") + embed = embed.to(torch.bfloat16) + negative_embed = negative_embed.to(torch.bfloat16) + + # Prepare timesteps + timesteps, sigmas = hunyuan_image_utils.get_timesteps_sigmas(args.infer_steps, args.flow_shift, device) + timesteps /= 1000 # scale to [0,1] range + timesteps = timesteps.to(device, dtype=torch.bfloat16) + + # Denoising loop + do_cfg = args.guidance_scale != 1.0 + autocast_enabled = args.fp8 + + with tqdm(total=len(timesteps), desc="Denoising steps") as pbar: + for i, t in enumerate(timesteps): + t_expand = t.expand(latents.shape[0]) + + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled): + noise_pred = anima(latents, t_expand, embed, padding_mask=padding_mask) + + if do_cfg: + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled): + uncond_noise_pred = anima(latents, t_expand, negative_embed, padding_mask=padding_mask) + noise_pred = uncond_noise_pred + args.guidance_scale * (noise_pred - uncond_noise_pred) + + # ensure latents dtype is consistent + latents = hunyuan_image_utils.step(latents, noise_pred, sigmas, i).to(latents.dtype) + + pbar.update() + + return latents + + +def get_time_flag(): + return datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S-%f")[:-3] + + +def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str: + """Save latent to file + + Args: + latent: Latent tensor + args: command line arguments + height: height of frame + width: width of frame + + Returns: + str: Path to saved latent file + """ + save_path = args.save_path + 
os.makedirs(save_path, exist_ok=True) + time_flag = get_time_flag() + + seed = args.seed + + latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seed}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "infer_steps": f"{args.infer_steps}", + # "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + "guidance_scale": f"{args.guidance_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + + sd = {"latent": latent.contiguous()} + save_file(sd, latent_path, metadata=metadata) + logger.info(f"Latent saved to: {latent_path}") + + return latent_path + + +def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: + """Save images to directory + + Args: + sample: Video tensor + args: command line arguments + original_base_name: Original base name (if latents are loaded from files) + + Returns: + str: Path to saved images directory + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = get_time_flag() + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" + image_name = f"{time_flag}_{seed}{original_name}" + + x = torch.clamp(sample, -1.0, 1.0) + x = ((x + 1.0) * 127.5).to(torch.uint8).cpu().numpy() + x = x.transpose(1, 2, 0) # C, H, W -> H, W, C + + image = Image.fromarray(x) + image.save(os.path.join(save_path, f"{image_name}.png")) + + logger.info(f"Sample images saved to: {save_path}/{image_name}") + + return f"{save_path}/{image_name}" + + +def save_output( + args: argparse.Namespace, + vae: WanVAE_, + latent: torch.Tensor, + device: torch.device, + original_base_name: Optional[str] = None, +) -> None: + """save output + + Args: + args: command line arguments + vae: VAE model + latent: latent tensor + device: device to use + original_base_name: original base 
name (if latents are loaded from files) + """ + height, width = latent.shape[-2], latent.shape[-1] # BCTHW + height *= 8 # qwen_image_autoencoder_kl.SCALE_FACTOR + width *= 8 # qwen_image_autoencoder_kl.SCALE_FACTOR + # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}") + if args.output_type == "latent" or args.output_type == "latent_images": + # save latent + save_latent(latent, args, height, width) + if args.output_type == "latent": + return + + if vae is None: + logger.error("VAE is None, cannot decode latents for saving video/images.") + return + + if latent.ndim == 2: # S,C. For packed latents from other inference scripts + latent = latent.unsqueeze(0) + height, width = check_inputs(args) # Get height/width from args + latent = latent.view( + 1, + vae.latent_channels, + 1, # Frame dimension + height // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR, + width // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR, + ) + + image = decode_latent(vae, latent, device) + + if args.output_type == "images" or args.output_type == "latent_images": + # save images + if original_base_name is None: + original_name = "" + else: + original_name = f"_{original_base_name}" + save_images(image, args, original_name) + + +def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]: + """Process multiple prompts for batch mode + + Args: + prompt_lines: List of prompt lines + base_args: Base command line arguments + + Returns: + List[Dict]: List of prompt data dictionaries + """ + prompts_data = [] + + for line in prompt_lines: + line = line.strip() + if not line or line.startswith("#"): # Skip empty lines and comments + continue + + # Parse prompt line and create override dictionary + prompt_data = parse_prompt_line(line) + logger.info(f"Parsed prompt data: {prompt_data}") + prompts_data.append(prompt_data) + + return prompts_data + + +def load_shared_models(args: argparse.Namespace) -> Dict: + """Load shared models 
for batch processing or interactive mode. + Models are loaded to CPU to save memory. VAE is NOT loaded here. + DiT model is also NOT loaded here, handled by process_batch_prompts or generate. + + Args: + args: Base command line arguments + + Returns: + Dict: Dictionary of shared models (text/image encoders) + """ + shared_models = {} + # Load text encoders to CPU + text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer, text_encoder = anima_text_encoder.load_qwen3( + args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True + ) + t5xxl_tokenizer = anima_text_encoder.load_t5xxl_tokenizer() + shared_models["tokenizer"] = tokenizer + shared_models["text_encoder"] = text_encoder + shared_models["t5xxl_tokenizer"] = t5xxl_tokenizer + return shared_models + + +def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None: + """Process multiple prompts with model reuse and batched precomputation + + Args: + prompts_data: List of prompt data dictionaries + args: Base command line arguments + """ + if not prompts_data: + logger.warning("No valid prompts found") + return + + gen_settings = get_generation_settings(args) + dit_weight_dtype = gen_settings.dit_weight_dtype + device = gen_settings.device + + # 1. Prepare VAE + logger.info("Loading VAE for batch generation...") + vae_for_batch = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True) + vae_for_batch.eval() + + all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first + for prompt_args in all_prompt_args_list: + check_inputs(prompt_args) # Validate each prompt's height/width + + # 2. Load DiT Model once + logger.info("Loading DiT model for batch generation...") + # Use args from the first prompt for DiT loading (LoRA etc. 
should be consistent for a batch) + first_prompt_args = all_prompt_args_list[0] + anima = load_dit_model(first_prompt_args, device, dit_weight_dtype) # Load directly to target device if possible + + shared_models_for_generate = {"model": anima} # Pass DiT via shared_models + + # 3. Precompute Text Data (Text Encoder) + logger.info("Loading Text Encoder for batch text preprocessing...") + + # Text Encoder loaded to CPU by load_text_encoder + text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer_batch, text_encoder_batch = anima_text_encoder.load_qwen3( + args.text_encoder, dtype=text_encoder_dtype, device="cpu", disable_mmap=True + ) + t5xxl_tokenizer_batch = anima_text_encoder.load_t5xxl_tokenizer() + + # Text Encoder to device for this phase + text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device + text_encoder_batch.to(text_encoder_device) # Moved into prepare_text_inputs logic + + all_precomputed_text_data = [] + conds_cache_batch = {} + + logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...") + temp_shared_models_txt = { + "tokenizer": tokenizer_batch, + "text_encoder": text_encoder_batch, # on GPU if not text_encoder_cpu + "t5xxl_tokenizer": t5xxl_tokenizer_batch, + "conds_cache": conds_cache_batch, + } + + for i, prompt_args_item in enumerate(all_prompt_args_list): + logger.info(f"Text preprocessing for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}") + + # prepare_text_inputs will move text_encoders to device temporarily + context, context_null = prepare_text_inputs(prompt_args_item, device, anima, temp_shared_models_txt) + text_data = {"context": context, "context_null": context_null} + all_precomputed_text_data.append(text_data) + + # Models should be removed from device after prepare_text_inputs + del tokenizer_batch, text_encoder_batch, t5xxl_tokenizer_batch, temp_shared_models_txt, conds_cache_batch + gc.collect() # Force cleanup of Text Encoder from GPU 
memory + clean_memory_on_device(device) + + all_latents = [] + + logger.info("Generating latents for all prompts...") + with torch.no_grad(): + for i, prompt_args_item in enumerate(all_prompt_args_list): + current_text_data = all_precomputed_text_data[i] + height, width = check_inputs(prompt_args_item) # Get height/width for each prompt + + logger.info(f"Generating latent for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}") + try: + # generate is called with precomputed data, so it won't load Text Encoders. + # It will use the DiT model from shared_models_for_generate. + latent = generate(prompt_args_item, gen_settings, shared_models_for_generate, current_text_data) + + if latent is None: # and prompt_args_item.save_merged_model: # Should be caught earlier + continue + + # Save latent if needed (using data from precomputed_image_data for H/W) + if prompt_args_item.output_type in ["latent", "latent_images"]: + save_latent(latent, prompt_args_item, height, width) + + all_latents.append(latent) + except Exception as e: + logger.error(f"Error generating latent for prompt: {prompt_args_item.prompt}. Error: {e}", exc_info=True) + all_latents.append(None) # Add placeholder for failed generations + continue + + # Free DiT model + logger.info("Releasing DiT model from memory...") + + del shared_models_for_generate["model"] + del anima + clean_memory_on_device(device) + synchronize_device(device) # Ensure memory is freed before loading VAE for decoding + + # 4. 
Decode latents and save outputs (using vae_for_batch) + if args.output_type != "latent": + logger.info("Decoding latents to videos/images using batched VAE...") + vae_for_batch.to(device) # Move VAE to device for decoding + + for i, latent in enumerate(all_latents): + if latent is None: # Skip failed generations + logger.warning(f"Skipping decoding for prompt {i+1} due to previous error.") + continue + + current_args = all_prompt_args_list[i] + logger.info(f"Decoding output {i+1}/{len(all_latents)} for prompt: {current_args.prompt}") + + # if args.output_type is "latent_images", we already saved latent above. + # so we skip saving latent here. + if current_args.output_type == "latent_images": + current_args.output_type = "images" + + # save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1). + save_output(current_args, vae_for_batch, latent, device) # Pass vae_for_batch + + vae_for_batch.to("cpu") # Move VAE back to CPU + + del vae_for_batch + clean_memory_on_device(device) + + +def process_interactive(args: argparse.Namespace) -> None: + """Process prompts in interactive mode + + Args: + args: Base command line arguments + """ + gen_settings = get_generation_settings(args) + device = gen_settings.device + shared_models = load_shared_models(args) + shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode + + vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True) + vae.eval() + + print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):") + + try: + import prompt_toolkit + except ImportError: + logger.warning("prompt_toolkit not found. 
Using basic input instead.") + prompt_toolkit = None + + if prompt_toolkit: + session = prompt_toolkit.PromptSession() + + def input_line(prompt: str) -> str: + return session.prompt(prompt) + + else: + + def input_line(prompt: str) -> str: + return input(prompt) + + try: + while True: + try: + line = input_line("> ") + if not line.strip(): + continue + if len(line.strip()) == 1 and line.strip() in ["\x04", "\x1a"]: # Ctrl+D or Ctrl+Z with prompt_toolkit + raise EOFError # Exit on Ctrl+D or Ctrl+Z + + # Parse prompt + prompt_data = parse_prompt_line(line) + prompt_args = apply_overrides(args, prompt_data) + + # Generate latent + # For interactive, precomputed data is None. shared_models contains text encoders. + latent = generate(prompt_args, gen_settings, shared_models) + + # # If not one_frame_inference, move DiT model to CPU after generation + # model = shared_models.get("model") + # model.to("cpu") # Move DiT model to CPU after generation + + # Save latent and video + # returned_vae from generate will be used for decoding here. + save_output(prompt_args, vae, latent, device) + + except KeyboardInterrupt: + print("\nInterrupted. 
Continue (Ctrl+D or Ctrl+Z (Windows) to exit)") + continue + + except EOFError: + print("\nExiting interactive mode") + + +def get_generation_settings(args: argparse.Namespace) -> GenerationSettings: + device = torch.device(args.device) + + dit_weight_dtype = torch.bfloat16 # default + if args.fp8_scaled: + dit_weight_dtype = None # various precision weights, so don't cast to specific dtype + elif args.fp8: + dit_weight_dtype = torch.float8_e4m3fn + + logger.info(f"Using device: {device}, DiT weight weight precision: {dit_weight_dtype}") + + gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype) + return gen_settings + + +def main(): + # Parse arguments + args = parse_args() + + # Check if latents are provided + latents_mode = args.latent_path is not None and len(args.latent_path) > 0 + + # Set device + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + logger.info(f"Using device: {device}") + args.device = device + + if latents_mode: + # Original latent decode mode + original_base_names = [] + latents_list = [] + seeds = [] + + # assert len(args.latent_path) == 1, "Only one latent path is supported for now" + + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + if "height" in metadata and "width" in metadata: + height = int(metadata["height"]) + width = int(metadata["width"]) + args.image_size = [height, width] + + seeds.append(seed) + logger.info(f"Loaded latent from {latent_path}. 
Shape: {latents.shape}") + + if latents.ndim == 5: # [BCTHW] + latents = latents.squeeze(0) # [CTHW] + + latents_list.append(latents) + + # latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape + + for i, latent in enumerate(latents_list): + args.seed = seeds[i] + + vae = qwen_image_autoencoder_kl.load_vae(args.vae, device=device, disable_mmap=True) + vae.eval() + save_output(args, vae, latent, device, original_base_names[i]) + + elif args.from_file: + # Batch mode from file + + # Read prompts from file + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_lines = f.readlines() + + # Process prompts + prompts_data = preprocess_prompts_for_batch(prompt_lines, args) + process_batch_prompts(prompts_data, args) + + elif args.interactive: + # Interactive mode + process_interactive(args) + + else: + # Single prompt mode (original behavior) + + # Generate latent + gen_settings = get_generation_settings(args) + + # For single mode, precomputed data is None, shared_models is None. + # generate will load all necessary models (Text Encoders, DiT). 
+ latent = generate(args, gen_settings) + # print(f"Generated latent shape: {latent.shape}") + # if args.save_merged_model: + # return + + clean_memory_on_device(device) + + # Save latent and video + vae = qwen_image_autoencoder_kl.load_vae(args.vae, device="cpu", disable_mmap=True) + + vae.eval() + save_output(args, vae, latent, device) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/anima_train.py b/anima_train.py index a86c30c3..081d6963 100644 --- a/anima_train.py +++ b/anima_train.py @@ -49,35 +49,32 @@ def train(args): args.skip_cache_check = args.skip_latents_validity_check if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled" - ) + logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled") args.cache_text_encoder_outputs = True if args.cpu_offload_checkpointing and not args.gradient_checkpointing: logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True - if getattr(args, 'unsloth_offload_checkpointing', False): + if getattr(args, "unsloth_offload_checkpointing", False): if not args.gradient_checkpointing: logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True - assert not args.cpu_offload_checkpointing, \ - "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" + assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" - assert ( - args.blocks_to_swap is None or args.blocks_to_swap == 0 - ) or not getattr(args, 
'unsloth_offload_checkpointing', False), \ - "blocks_to_swap is not supported with unsloth_offload_checkpointing" + assert (args.blocks_to_swap is None or args.blocks_to_swap == 0) or not getattr( + args, "unsloth_offload_checkpointing", False + ), "blocks_to_swap is not supported with unsloth_offload_checkpointing" # Flash attention: validate availability - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): try: import flash_attn # noqa: F401 + logger.info("Flash Attention enabled for DiT blocks") except ImportError: logger.warning("flash_attn package not installed, falling back to PyTorch SDPA") @@ -104,9 +101,7 @@ def train(args): user_config = config_util.load_user_config(args.dataset_config) ignored = ["train_data_dir", "in_json"] if any(getattr(args, attr) is not None for attr in ignored): - logger.warning( - "ignore following options because config file is found: {0}".format(", ".join(ignored)) - ) + logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored))) else: if use_dreambooth_method: logger.info("Using DreamBooth method.") @@ -150,7 +145,7 @@ def train(args): # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of # dataset-level caption dropout, so we save the rate and zero out subset-level # caption_dropout_rate to allow text encoder output caching. 
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) if caption_dropout_rate > 0: logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}") for dataset in train_dataset_group.datasets: @@ -175,9 +170,7 @@ def train(args): return if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used" + assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used" if args.cache_text_encoder_outputs: assert ( @@ -193,7 +186,7 @@ def train(args): # parse transformer_dtype transformer_dtype = None - if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None: + if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None: transformer_dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, @@ -203,12 +196,8 @@ def train(args): # Load tokenizers and set strategies logger.info("Loading tokenizers...") - qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder( - args.qwen3_path, dtype=weight_dtype, device="cpu" - ) - t5_tokenizer = anima_utils.load_t5_tokenizer( - getattr(args, 't5_tokenizer_path', None) - ) + qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu") + t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None)) # Set tokenize strategy tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( @@ -220,7 +209,7 @@ def train(args): strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy) # Set text encoding strategy - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy( dropout_rate=caption_dropout_rate, ) @@ -266,10 
+255,8 @@ def train(args): ) # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) - if caption_dropout_rate > 0.0: - with accelerator.autocast(): - text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder]) + with accelerator.autocast(): + text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder]) accelerator.wait_for_everyone() @@ -299,17 +286,17 @@ def train(args): dtype=weight_dtype, device="cpu", transformer_dtype=transformer_dtype, - llm_adapter_path=getattr(args, 'llm_adapter_path', None), - disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False), + llm_adapter_path=getattr(args, "llm_adapter_path", None), + disable_mmap=getattr(args, "disable_mmap_load_safetensors", False), ) if args.gradient_checkpointing: dit.enable_gradient_checkpointing( cpu_offload=args.cpu_offload_checkpointing, - unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False), + unsloth_offload=getattr(args, "unsloth_offload_checkpointing", False), ) - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): dit.set_flash_attn(True) train_dit = args.learning_rate != 0 @@ -335,11 +322,11 @@ def train(args): param_groups = anima_train_utils.get_anima_param_groups( dit, base_lr=args.learning_rate, - self_attn_lr=getattr(args, 'self_attn_lr', None), - cross_attn_lr=getattr(args, 'cross_attn_lr', None), - mlp_lr=getattr(args, 'mlp_lr', None), - mod_lr=getattr(args, 'mod_lr', None), - llm_adapter_lr=getattr(args, 'llm_adapter_lr', None), + self_attn_lr=getattr(args, "self_attn_lr", None), + cross_attn_lr=getattr(args, "cross_attn_lr", None), + mlp_lr=getattr(args, "mlp_lr", None), + mod_lr=getattr(args, "mod_lr", None), + llm_adapter_lr=getattr(args, "llm_adapter_lr", None), ) else: param_groups = [] @@ -366,8 +353,8 @@ def train(args): # Build param_id → lr mapping from param_groups to 
propagate per-component LRs param_lr_map = {} for group in param_groups: - for p in group['params']: - param_lr_map[id(p)] = group['lr'] + for p in group["params"]: + param_lr_map[id(p)] = group["lr"] grouped_params = [] param_group = {} @@ -557,9 +544,7 @@ def train(args): accelerator.print(f" num examples: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch: {len(train_dataloader)}") accelerator.print(f" num epochs: {num_train_epochs}") - accelerator.print( - f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) + accelerator.print(f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps: {args.max_train_steps}") @@ -580,6 +565,7 @@ def train(args): if "wandb" in [tracker.name for tracker in accelerator.trackers]: import wandb + wandb.define_metric("epoch") wandb.define_metric("loss/epoch", step_metric="epoch") @@ -589,8 +575,16 @@ def train(args): # For --sample_at_first optimizer_eval_fn() anima_train_utils.sample_images( - accelerator, args, 0, global_step, dit, vae, vae_scale, - qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, + accelerator, + args, + 0, + global_step, + dit, + vae, + vae_scale, + qwen3_text_encoder, + tokenize_strategy, + text_encoding_strategy, sample_prompts_te_outputs, ) optimizer_train_fn() @@ -600,7 +594,9 @@ def train(args): # Show model info unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None if unwrapped_dit is not None: - logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}") + logger.info( + f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}" + ) if qwen3_text_encoder is not None: logger.info(f"qwen3 device: 
{next(qwen3_text_encoder.parameters()).device}") if vae is not None: @@ -640,9 +636,7 @@ def train(args): text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: # Cached outputs - text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs( - *text_encoder_outputs_list - ) + text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list else: # Encode on-the-fly @@ -678,10 +672,7 @@ def train(args): bs = latents.shape[0] h_latent = latents.shape[-2] w_latent = latents.shape[-1] - padding_mask = torch.zeros( - bs, 1, h_latent, w_latent, - dtype=weight_dtype, device=accelerator.device - ) + padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device) # DiT forward (LLM adapter runs inside forward for DDP gradient sync) if is_swapping_blocks: @@ -708,9 +699,7 @@ def train(args): # Loss huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None) - loss = train_util.conditional_loss( - model_pred.float(), target.float(), args.loss_type, "none", huber_c - ) + loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,) @@ -748,8 +737,16 @@ def train(args): optimizer_eval_fn() anima_train_utils.sample_images( - accelerator, args, None, global_step, dit, vae, vae_scale, - qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, + accelerator, + args, + None, + global_step, + dit, + vae, + vae_scale, + qwen3_text_encoder, + tokenize_strategy, + text_encoding_strategy, sample_prompts_te_outputs, ) @@ -773,8 +770,10 @@ def train(args): if len(accelerator.trackers) > 0: logs = {"loss": 
current_loss} train_util.append_lr_to_logs_with_names( - logs, lr_scheduler, args.optimizer_type, - ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [] + logs, + lr_scheduler, + args.optimizer_type, + ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [], ) accelerator.log(logs, step=global_step) @@ -807,8 +806,16 @@ def train(args): ) anima_train_utils.sample_images( - accelerator, args, epoch + 1, global_step, dit, vae, vae_scale, - qwen3_text_encoder, tokenize_strategy, text_encoding_strategy, + accelerator, + args, + epoch + 1, + global_step, + dit, + vae, + vae_scale, + qwen3_text_encoder, + tokenize_strategy, + text_encoding_strategy, sample_prompts_te_outputs, ) diff --git a/anima_train_network.py b/anima_train_network.py index 57ad1681..d003aa64 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -39,17 +39,15 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): val_dataset_group: Optional[train_util.DatasetGroup], ): if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs: - logger.warning( - "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled" - ) + logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled") args.cache_text_encoder_outputs = True # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of # dataset-level caption dropout, so zero out subset-level rates to allow caching. 
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) if caption_dropout_rate > 0: logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}") - if hasattr(train_dataset_group, 'datasets'): + if hasattr(train_dataset_group, "datasets"): for dataset in train_dataset_group.datasets: for subset in dataset.subsets: subset.caption_dropout_rate = 0.0 @@ -63,26 +61,28 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): args.blocks_to_swap is None or args.blocks_to_swap == 0 ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing" - if getattr(args, 'unsloth_offload_checkpointing', False): + if getattr(args, "unsloth_offload_checkpointing", False): if not args.gradient_checkpointing: logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled") args.gradient_checkpointing = True - assert not args.cpu_offload_checkpointing, \ - "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" + assert ( + not args.cpu_offload_checkpointing + ), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing" assert ( args.blocks_to_swap is None or args.blocks_to_swap == 0 ), "blocks_to_swap is not supported with unsloth_offload_checkpointing" # Flash attention: validate availability - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): try: import flash_attn # noqa: F401 + logger.info("Flash Attention enabled for DiT blocks") except ImportError: logger.warning("flash_attn package not installed, falling back to PyTorch SDPA") args.flash_attn = False - if getattr(args, 'blockwise_fused_optimizers', False): + if getattr(args, "blockwise_fused_optimizers", False): raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer") train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial 
downscale = 8 @@ -92,14 +92,12 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): def load_target_model(self, args, weight_dtype, accelerator): # Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy) logger.info("Loading Qwen3 text encoder...") - self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder( - args.qwen3_path, dtype=weight_dtype, device="cpu" - ) + self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=weight_dtype, device="cpu") self.qwen3_text_encoder.eval() # Parse transformer_dtype transformer_dtype = None - if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None: + if hasattr(args, "transformer_dtype") and args.transformer_dtype is not None: transformer_dtype_map = { "float16": torch.float16, "bfloat16": torch.bfloat16, @@ -114,18 +112,18 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): dtype=weight_dtype, device="cpu", transformer_dtype=transformer_dtype, - llm_adapter_path=getattr(args, 'llm_adapter_path', None), - disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False), + llm_adapter_path=getattr(args, "llm_adapter_path", None), + disable_mmap=getattr(args, "disable_mmap_load_safetensors", False), ) # Flash attention - if getattr(args, 'flash_attn', False): + if getattr(args, "flash_attn", False): dit.set_flash_attn(True) # Store unsloth preference so that when the base NetworkTrainer calls # dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth. # The base trainer only passes cpu_offload, so we store the flag on the model. 
- self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False) + self._use_unsloth_offload_checkpointing = getattr(args, "unsloth_offload_checkpointing", False) # Block swap self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 @@ -135,9 +133,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): # Load VAE logger.info("Loading Anima VAE...") - self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae( - args.vae_path, dtype=weight_dtype, device="cpu" - ) + self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu") # Return format: (model_type, text_encoders, vae, unet) return "anima", [self.qwen3_text_encoder], self.vae, dit @@ -146,7 +142,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): # Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet) self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy( qwen3_path=args.qwen3_path, - t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None), + t5_tokenizer_path=getattr(args, "t5_tokenizer_path", None), qwen3_max_length=args.qwen3_max_token_length, t5_max_length=args.t5_max_token_length, ) @@ -159,12 +155,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): return [tokenize_strategy.qwen3_tokenizer] def get_latents_caching_strategy(self, args): - return strategy_anima.AnimaLatentsCachingStrategy( - args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check - ) + return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check) def get_text_encoding_strategy(self, args): - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) + caption_dropout_rate = getattr(args, "caption_dropout_rate", 0.0) self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy( dropout_rate=caption_dropout_rate, ) @@ -237,12 +231,10 
@@ class AnimaNetworkTrainer(train_network.NetworkTrainer): self.sample_prompts_te_outputs = sample_prompts_te_outputs # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted - caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0) text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy() - if caption_dropout_rate > 0.0: - tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy() - with accelerator.autocast(): - text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders) + tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy() + with accelerator.autocast(): + text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders) accelerator.wait_for_everyone() @@ -264,8 +256,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): qwen3_te = te[0] if te is not None else None anima_train_utils.sample_images( - accelerator, args, epoch, global_step, unet, vae, self.vae_scale, - qwen3_te, self.tokenize_strategy, self.text_encoding_strategy, + accelerator, + args, + epoch, + global_step, + unet, + vae, + self.vae_scale, + qwen3_te, + self.tokenize_strategy, + self.text_encoding_strategy, self.sample_prompts_te_outputs, ) @@ -329,10 +329,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): bs = latents.shape[0] h_latent = latents.shape[-2] w_latent = latents.shape[-1] - padding_mask = torch.zeros( - bs, 1, h_latent, w_latent, - dtype=weight_dtype, device=accelerator.device - ) + padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device) # Prepare block swap if self.is_swapping_blocks: @@ -354,9 +351,7 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): target = noise - latents # Loss weighting - weighting = anima_train_utils.compute_loss_weighting_for_anima( - weighting_scheme=args.weighting_scheme, sigmas=sigmas - ) + 
weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # Differential output preservation if "custom_attributes" in batch: @@ -386,10 +381,22 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): return model_pred, target, timesteps, weighting def process_batch( - self, batch, text_encoders, unet, network, vae, noise_scheduler, - vae_dtype, weight_dtype, accelerator, args, - text_encoding_strategy, tokenize_strategy, - is_train=True, train_text_encoder=True, train_unet=True, + self, + batch, + text_encoders, + unet, + network, + vae, + noise_scheduler, + vae_dtype, + weight_dtype, + accelerator, + args, + text_encoding_strategy, + tokenize_strategy, + is_train=True, + train_text_encoder=True, + train_unet=True, ) -> torch.Tensor: """Override base process_batch for 5D video latents (B, C, T, H, W). @@ -424,13 +431,21 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): # Text encoder conditions text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) + anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy if text_encoder_outputs_list is not None: - text_encoder_conds = text_encoder_outputs_list + caption_dropout_rates = text_encoder_outputs_list[-1] + text_encoder_outputs_list = text_encoder_outputs_list[:-1] + + # Apply caption dropout to cached outputs + text_encoder_conds = anima_text_encoding_strategy.drop_cached_text_encoder_outputs( + *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates + ) if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder: with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast(): input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]] - encoded_text_encoder_conds = text_encoding_strategy.encode_tokens( + # TODO stop gradient for uncond embeddings when using caption dropout? 
+ encoded_text_encoder_conds = anima_text_encoding_strategy.encode_tokens( tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids, @@ -441,13 +456,23 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): if len(text_encoder_conds) == 0: text_encoder_conds = encoded_text_encoder_conds else: + # Fill in only missing parts (partial caching) for i in range(len(encoded_text_encoder_conds)): if encoded_text_encoder_conds[i] is not None: text_encoder_conds[i] = encoded_text_encoder_conds[i] noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( - args, accelerator, noise_scheduler, latents, batch, - text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train, + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=is_train, ) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) @@ -479,8 +504,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): def update_metadata(self, metadata, args): metadata["ss_weighting_scheme"] = args.weighting_scheme metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift - metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal') - metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0) + metadata["ss_timestep_sample_method"] = getattr(args, "timestep_sample_method", "logit_normal") + metadata["ss_sigmoid_scale"] = getattr(args, "sigmoid_scale", 1.0) def is_text_encoder_not_needed_for_training(self, args): return args.cache_text_encoder_outputs diff --git a/docs/anima_train_network.md b/docs/anima_train_network.md index fe6b2354..7cd30e26 100644 --- a/docs/anima_train_network.md +++ b/docs/anima_train_network.md @@ -118,7 +118,7 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \ --optimizer_type="AdamW8bit" \ --lr_scheduler="constant" \ 
--timestep_sample_method="logit_normal" \ - --discrete_flow_shift=3.0 \ + --discrete_flow_shift=1.0 \ --max_train_epochs=10 \ --save_every_n_epochs=1 \ --mixed_precision="bf16" \ @@ -162,7 +162,7 @@ Besides the arguments explained in the [train_network.py guide](train_network.md * `--timestep_sample_method=` - Timestep sampling method. Choose from `logit_normal` (default) or `uniform`. * `--discrete_flow_shift=` - - Shift for the timestep distribution in Rectified Flow training. Default `3.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`. + - Shift for the timestep distribution in Rectified Flow training. Default `1.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`. 1.0 means no shift. * `--sigmoid_scale=` - Scale factor for `logit_normal` timestep sampling. Default `1.0`. * `--qwen3_max_token_length=` diff --git a/library/anima_models.py b/library/anima_models.py index 6aad9d8c..d3adff5f 100644 --- a/library/anima_models.py +++ b/library/anima_models.py @@ -13,11 +13,10 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint as torch_checkpoint -from library import custom_offloading_utils +from library import custom_offloading_utils, attention from library.device_utils import clean_memory_on_device - def to_device(x, device): if isinstance(x, torch.Tensor): return x.to(device) @@ -39,11 +38,13 @@ def to_cpu(x): else: return x + # Unsloth Offloaded Gradient Checkpointing # Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team try: from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable except ImportError: + def detach_variable(inputs, device=None): """Detach tensors from computation graph, optionally moving to a device. 
@@ -80,11 +81,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function): """ @staticmethod - @torch.amp.custom_fwd(device_type='cuda') + @torch.amp.custom_fwd(device_type="cuda") def forward(ctx, forward_function, hidden_states, *args): # Remember the original device for backward pass (multi-GPU support) ctx.input_device = hidden_states.device - saved_hidden_states = hidden_states.to('cpu', non_blocking=True) + saved_hidden_states = hidden_states.to("cpu", non_blocking=True) with torch.no_grad(): output = forward_function(hidden_states, *args) ctx.save_for_backward(saved_hidden_states) @@ -96,7 +97,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function): return output @staticmethod - @torch.amp.custom_bwd(device_type='cuda') + @torch.amp.custom_bwd(device_type="cuda") def backward(ctx, *grads): (hidden_states,) = ctx.saved_tensors hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach() @@ -108,8 +109,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function): output_tensors = [] grad_tensors = [] - for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,), - grads if isinstance(grads, tuple) else (grads,)): + for out, grad in zip( + outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,) + ): if isinstance(out, torch.Tensor) and out.requires_grad: output_tensors.append(out) grad_tensors.append(grad) @@ -123,24 +125,24 @@ def unsloth_checkpoint(function, *args): return UnslothOffloadedGradientCheckpointer.apply(function, *args) -# Flash Attention support -try: - from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func - FLASH_ATTN_AVAILABLE = True -except ImportError: - _flash_attn_func = None - FLASH_ATTN_AVAILABLE = False +# # Flash Attention support +# try: +# from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func +# FLASH_ATTN_AVAILABLE = True +# except ImportError: +# 
_flash_attn_func = None +# FLASH_ATTN_AVAILABLE = False -def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: - """Computes multi-head attention using Flash Attention. +# def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor: +# """Computes multi-head attention using Flash Attention. - Input format: (batch, seq_len, n_heads, head_dim) - Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output. - """ - # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D) - out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D) - return rearrange(out, "b s h d -> b s (h d)") +# Input format: (batch, seq_len, n_heads, head_dim) +# Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output. +# """ +# # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D) +# out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D) +# return rearrange(out, "b s h d -> b s (h d)") from .utils import setup_logging @@ -174,14 +176,10 @@ def _apply_rotary_pos_emb_base( if start_positions is not None: max_offset = torch.max(start_positions) - assert ( - max_offset + cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
freqs = freqs[:cur_seq_len] if tensor_format == "bshd": @@ -205,13 +203,9 @@ def apply_rotary_pos_emb( cu_seqlens: Union[torch.Tensor, None] = None, cp_size: int = 1, ) -> torch.Tensor: - assert not ( - cp_size > 1 and start_positions is not None - ), "start_positions != None with CP SIZE > 1 is not supported!" + assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!" - assert ( - tensor_format != "thd" or cu_seqlens is not None - ), "cu_seqlens must not be None when tensor_format is 'thd'." + assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'." assert fused == False @@ -223,9 +217,7 @@ def apply_rotary_pos_emb( _apply_rotary_pos_emb_base( x.unsqueeze(1), freqs, - start_positions=( - start_positions[idx : idx + 1] if start_positions is not None else None - ), + start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None), interleaved=interleaved, ) for idx, x in enumerate(torch.split(t, seqlens)) @@ -262,7 +254,7 @@ class RMSNorm(torch.nn.Module): def _norm(self, x: torch.Tensor) -> torch.Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - @torch.amp.autocast(device_type='cuda', dtype=torch.float32) + @torch.amp.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x: torch.Tensor) -> torch.Tensor: output = self._norm(x.float()).type_as(x) return output * self.weight @@ -308,9 +300,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1]) k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... 
v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1]) - result_B_S_HD = rearrange( - F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)" - ) + result_B_S_HD = rearrange(F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)") return result_B_S_HD @@ -399,18 +389,23 @@ class Attention(nn.Module): return q, k, v - def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: - result = self.attn_op(q, k, v) # [B, S, H, D] - return self.output_dropout(self.output_proj(result)) + # def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # result = self.attn_op(q, k, v) # [B, S, H, D] + # return self.output_dropout(self.output_proj(result)) def forward( self, x: torch.Tensor, + attn_params: attention.AttentionParams, context: Optional[torch.Tensor] = None, rope_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb) - return self.compute_attention(q, k, v) + # return self.compute_attention(q, k, v) + qkv = [q, k, v] + del q, k, v + result = attention.attention(qkv, attn_params=attn_params) + return self.output_dropout(self.output_proj(result)) # Positional Embeddings @@ -484,12 +479,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb): dim_t = self._dim_t self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device) - self.dim_spatial_range = ( - torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h - ) - self.dim_temporal_range = ( - torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t - ) + self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h + self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t def 
generate_embeddings( self, @@ -679,9 +670,7 @@ class FourierFeatures(nn.Module): def reset_parameters(self) -> None: generator = torch.Generator() generator.manual_seed(0) - self.freqs = ( - 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device) - ) + self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device) self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device) def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor: @@ -713,9 +702,7 @@ class PatchEmbed(nn.Module): m=spatial_patch_size, n=spatial_patch_size, ), - nn.Linear( - in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False - ), + nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False), ) self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size @@ -765,9 +752,7 @@ class FinalLayer(nn.Module): nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False), ) else: - self.adaln_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False) - ) + self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)) self.init_weights() @@ -790,9 +775,9 @@ class FinalLayer(nn.Module): ): if self.use_adaln_lora: assert adaln_lora_B_T_3D is not None - shift_B_T_D, scale_B_T_D = ( - self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] - ).chunk(2, dim=-1) + shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk( + 2, dim=-1 + ) else: shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) @@ -833,7 +818,11 @@ class Block(nn.Module): self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, 
eps=1e-6) self.cross_attn = Attention( - x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd", + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_format="bshd", ) self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) @@ -904,6 +893,7 @@ class Block(nn.Module): x_B_T_H_W_D: torch.Tensor, emb_B_T_D: torch.Tensor, crossattn_emb: torch.Tensor, + attn_params: attention.AttentionParams, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, @@ -919,13 +909,13 @@ class Block(nn.Module): shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D ).chunk(3, dim=-1) - shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = ( - self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D - ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk( + 3, dim=-1 + ) else: - shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( - emb_B_T_D - ).chunk(3, dim=-1) + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk( + 3, dim=-1 + ) shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( emb_B_T_D ).chunk(3, dim=-1) @@ -954,11 +944,14 @@ class Block(nn.Module): result = rearrange( self.self_attn( rearrange(normalized_x, "b t h w d -> b (t h w) d"), + attn_params, None, rope_emb=rope_emb_L_1_1_D, ), "b (t h w) d -> b t h w d", - t=T, h=H, w=W, + t=T, + h=H, + w=W, ) x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result @@ -967,11 +960,14 @@ class Block(nn.Module): result = rearrange( self.cross_attn( rearrange(normalized_x, "b t h w d -> b (t h w) d"), + attn_params, crossattn_emb, rope_emb=rope_emb_L_1_1_D, ), "b (t 
h w) d -> b t h w d", - t=T, h=H, w=W, + t=T, + h=H, + w=W, ) x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D @@ -987,6 +983,7 @@ class Block(nn.Module): x_B_T_H_W_D: torch.Tensor, emb_B_T_D: torch.Tensor, crossattn_emb: torch.Tensor, + attn_params: attention.AttentionParams, rope_emb_L_1_1_D: Optional[torch.Tensor] = None, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, @@ -996,8 +993,13 @@ class Block(nn.Module): # Unsloth: async non-blocking CPU RAM offload (fastest offload method) return unsloth_checkpoint( self._forward, - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + extra_per_block_pos_emb, ) elif self.cpu_offload_checkpointing: # Standard cpu offload: blocking transfers @@ -1008,26 +1010,42 @@ class Block(nn.Module): device_inputs = to_device(inputs, device) outputs = func(*device_inputs) return to_cpu(outputs) + return custom_forward return torch_checkpoint( create_custom_forward(self._forward), - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + extra_per_block_pos_emb, use_reentrant=False, ) else: # Standard gradient checkpointing (no offload) return torch_checkpoint( self._forward, - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + extra_per_block_pos_emb, use_reentrant=False, ) else: return self._forward( - x_B_T_H_W_D, emb_B_T_D, crossattn_emb, - rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb, + x_B_T_H_W_D, + emb_B_T_D, + crossattn_emb, + attn_params, + rope_emb_L_1_1_D, + adaln_lora_B_T_3D, + 
extra_per_block_pos_emb, ) @@ -1069,6 +1087,8 @@ class MiniTrainDIT(nn.Module): extra_t_extrapolation_ratio: float = 1.0, rope_enable_fps_modulation: bool = True, use_llm_adapter: bool = False, + attn_mode: str = "torch", + split_attn: bool = False, ) -> None: super().__init__() self.max_img_h = max_img_h @@ -1097,6 +1117,9 @@ class MiniTrainDIT(nn.Module): self.rope_enable_fps_modulation = rope_enable_fps_modulation self.use_llm_adapter = use_llm_adapter + self.attn_mode = attn_mode + self.split_attn = split_attn + # Block swap support self.blocks_to_swap = None self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None @@ -1156,7 +1179,6 @@ class MiniTrainDIT(nn.Module): self.final_layer.init_weights() self.t_embedding_norm.reset_parameters() - def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False): for block in self.blocks: block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload) @@ -1169,18 +1191,17 @@ class MiniTrainDIT(nn.Module): def device(self): return next(self.parameters()).device + # def set_flash_attn(self, use_flash_attn: bool): + # """Toggle flash attention for all DiT blocks (self-attn + cross-attn). - def set_flash_attn(self, use_flash_attn: bool): - """Toggle flash attention for all DiT blocks (self-attn + cross-attn). - - LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn). - """ - if use_flash_attn and not FLASH_ATTN_AVAILABLE: - raise ImportError("flash_attn package is required for --flash_attn but is not installed") - attn_op = flash_attention_op if use_flash_attn else torch_attention_op - for block in self.blocks: - block.self_attn.attn_op = attn_op - block.cross_attn.attn_op = attn_op + # LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn). 
+ # """ + # if use_flash_attn and not FLASH_ATTN_AVAILABLE: + # raise ImportError("flash_attn package is required for --flash_attn but is not installed") + # attn_op = flash_attention_op if use_flash_attn else torch_attention_op + # for block in self.blocks: + # block.self_attn.attn_op = attn_op + # block.cross_attn.attn_op = attn_op def build_patch_embed(self) -> None: in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels @@ -1232,9 +1253,7 @@ class MiniTrainDIT(nn.Module): padding_mask = transforms.functional.resize( padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST ) - x_B_C_T_H_W = torch.cat( - [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 - ) + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1) x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) if self.extra_per_block_abs_pos_emb: @@ -1258,7 +1277,6 @@ class MiniTrainDIT(nn.Module): ) return x_B_C_Tt_Hp_Wp - def enable_block_swap(self, num_blocks: int, device: torch.device): self.blocks_to_swap = num_blocks @@ -1266,9 +1284,7 @@ class MiniTrainDIT(nn.Module): self.blocks_to_swap <= self.num_blocks - 2 ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks." - self.offloader = custom_offloading_utils.ModelOffloader( - self.blocks, self.blocks_to_swap, device - ) + self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device) logger.info(f"Anima: Block swap enabled. 
Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.") def move_to_device_except_swap_blocks(self, device: torch.device): @@ -1310,7 +1326,7 @@ class MiniTrainDIT(nn.Module): t5_attn_mask: Optional T5 attention mask """ # Run LLM adapter inside forward for correct DDP gradient synchronization - if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'): + if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"): crossattn_emb = self.llm_adapter( source_hidden_states=crossattn_emb, target_input_ids=t5_input_ids, @@ -1337,6 +1353,8 @@ class MiniTrainDIT(nn.Module): "extra_per_block_pos_emb": extra_pos_emb, } + attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn) + for block_idx, block in enumerate(self.blocks): if self.blocks_to_swap: self.offloader.wait_for_block(block_idx) @@ -1345,6 +1363,7 @@ class MiniTrainDIT(nn.Module): x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, + attn_params, **block_kwargs, ) @@ -1485,24 +1504,36 @@ class LLMAdapterTransformerBlock(nn.Module): self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim) self.mlp = nn.Sequential( - nn.Linear(model_dim, int(model_dim * mlp_ratio)), - nn.GELU(), - nn.Linear(int(model_dim * mlp_ratio), model_dim) + nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim) ) - def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, - position_embeddings=None, position_embeddings_context=None): + def forward( + self, + x, + context, + target_attention_mask=None, + source_attention_mask=None, + position_embeddings=None, + position_embeddings_context=None, + ): if self.has_self_attn: normed = self.norm_self_attn(x) - attn_out = self.self_attn(normed, mask=target_attention_mask, - position_embeddings=position_embeddings, - position_embeddings_context=position_embeddings) + 
attn_out = self.self_attn( + normed, + mask=target_attention_mask, + position_embeddings=position_embeddings, + position_embeddings_context=position_embeddings, + ) x = x + attn_out normed = self.norm_cross_attn(x) - attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, - position_embeddings=position_embeddings, - position_embeddings_context=position_embeddings_context) + attn_out = self.cross_attn( + normed, + mask=source_attention_mask, + context=context, + position_embeddings=position_embeddings, + position_embeddings_context=position_embeddings_context, + ) x = x + attn_out x = x + self.mlp(self.norm_mlp(x)) @@ -1518,8 +1549,9 @@ class LLMAdapter(nn.Module): Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states. """ - def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, - embed=None, self_attn=False, layer_norm=False): + def __init__( + self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False + ): super().__init__() if embed is not None: self.embed = nn.Embedding.from_pretrained(embed.weight) @@ -1530,11 +1562,12 @@ class LLMAdapter(nn.Module): else: self.in_proj = nn.Identity() self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads) - self.blocks = nn.ModuleList([ - LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, - self_attn=self_attn, layer_norm=layer_norm) - for _ in range(num_layers) - ]) + self.blocks = nn.ModuleList( + [ + LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm) + for _ in range(num_layers) + ] + ) self.out_proj = nn.Linear(model_dim, target_dim) self.norm = LLMAdapterRMSNorm(target_dim) @@ -1556,41 +1589,119 @@ class LLMAdapter(nn.Module): position_embeddings = self.rotary_emb(x, position_ids) position_embeddings_context = self.rotary_emb(x, position_ids_context) for block in self.blocks: - x = block(x, 
context, target_attention_mask=target_attention_mask, - source_attention_mask=source_attention_mask, - position_embeddings=position_embeddings, - position_embeddings_context=position_embeddings_context) + x = block( + x, + context, + target_attention_mask=target_attention_mask, + source_attention_mask=source_attention_mask, + position_embeddings=position_embeddings, + position_embeddings_context=position_embeddings_context, + ) return self.norm(self.out_proj(x)) +class Anima(nn.Module): + """ + Wrapper class for the MiniTrainDIT and LLM Adapter. + """ + + LATENT_CHANNELS = 16 + + def __init__(self, dit_config: dict): + super().__init__() + self.net = MiniTrainDIT(**dit_config) + + @property + def device(self): + return self.net.device + + @property + def dtype(self): + return self.net.dtype + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs) + + def preprocess_text_embeds( + self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None + ): + if target_input_ids is not None: + return self.net.llm_adapter( + source_hidden_states, + target_input_ids, + target_attention_mask=target_attention_mask, + source_attention_mask=source_attention_mask, + ) + else: + return source_hidden_states + + # VAE Wrapper # VAE normalization constants ANIMA_VAE_MEAN = [ - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, ] ANIMA_VAE_STD = [ - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 
1.1253, 2.8251, 1.9160 + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, ] # DiT config detection from state_dict -KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer'] +KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"] -def get_dit_config(state_dict, key_prefix=''): +def get_dit_config(state_dict, key_prefix=""): """Derive DiT configuration from state_dict weight shapes.""" dit_config = {} dit_config["max_img_h"] = 512 dit_config["max_img_w"] = 512 dit_config["max_frames"] = 128 concat_padding_mask = True - dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask) + dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int( + concat_padding_mask + ) dit_config["out_channels"] = 16 dit_config["patch_spatial"] = 2 dit_config["patch_temporal"] = 1 - dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0] + dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0] dit_config["concat_padding_mask"] = concat_padding_mask dit_config["crossattn_emb_channels"] = 1024 dit_config["pos_emb_cls"] = "rope3d" diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py index ef0016b5..edac2fb7 100644 --- a/library/anima_train_utils.py +++ b/library/anima_train_utils.py @@ -32,6 +32,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas # Anima-specific training arguments + def add_anima_training_arguments(parser: argparse.ArgumentParser): """Add Anima-specific training arguments to the parser.""" parser.add_argument( @@ -169,20 +170,20 @@ def get_noisy_model_input_and_timesteps( """ bs = latents.shape[0] - timestep_sample_method = getattr(args, 
'timestep_sample_method', 'logit_normal') - sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0) - shift = getattr(args, 'discrete_flow_shift', 1.0) + timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal") + sigmoid_scale = getattr(args, "sigmoid_scale", 1.0) + shift = getattr(args, "discrete_flow_shift", 1.0) - if timestep_sample_method == 'logit_normal': + if timestep_sample_method == "logit_normal": dist = torch.distributions.normal.Normal(0, 1) - elif timestep_sample_method == 'uniform': + elif timestep_sample_method == "uniform": dist = torch.distributions.uniform.Uniform(0, 1) else: raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}") t = dist.sample((bs,)).to(device) - if timestep_sample_method == 'logit_normal': + if timestep_sample_method == "logit_normal": t = t * sigmoid_scale t = torch.sigmoid(t) @@ -196,10 +197,10 @@ def get_noisy_model_input_and_timesteps( # Create noisy input: (1 - t) * latents + t * noise t_expanded = t.view(-1, *([1] * (latents.ndim - 1))) - ip_noise_gamma = getattr(args, 'ip_noise_gamma', None) + ip_noise_gamma = getattr(args, "ip_noise_gamma", None) if ip_noise_gamma: xi = torch.randn_like(latents, device=latents.device, dtype=dtype) - if getattr(args, 'ip_noise_gamma_random_strength', False): + if getattr(args, "ip_noise_gamma_random_strength", False): ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi) else: @@ -213,6 +214,7 @@ def get_noisy_model_input_and_timesteps( # Loss weighting + def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor: """Compute loss weighting for Anima training. 
@@ -276,15 +278,15 @@ def get_anima_param_groups( # Store original name for debugging p.original_name = name - if 'llm_adapter' in name: + if "llm_adapter" in name: llm_adapter_params.append(p) - elif '.self_attn' in name: + elif ".self_attn" in name: self_attn_params.append(p) - elif '.cross_attn' in name: + elif ".cross_attn" in name: cross_attn_params.append(p) - elif '.mlp' in name: + elif ".mlp" in name: mlp_params.append(p) - elif '.adaln_modulation' in name: + elif ".adaln_modulation" in name: mod_params.append(p) else: base_params.append(p) @@ -311,9 +313,9 @@ def get_anima_param_groups( p.requires_grad_(False) logger.info(f" Frozen {name} params ({len(params)} parameters)") elif len(params) > 0: - param_groups.append({'params': params, 'lr': lr}) + param_groups.append({"params": params, "lr": lr}) - total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad) + total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad) logger.info(f"Total trainable parameters: {total_trainable:,}") return param_groups @@ -328,10 +330,9 @@ def save_anima_model_on_train_end( dit: anima_models.MiniTrainDIT, ): """Save Anima model at the end of training.""" + def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec( - None, args, False, False, False, is_stable_diffusion_ckpt=True - ) + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) dit_sd = dit.state_dict() # Save with 'net.' 
prefix for ComfyUI compatibility anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) @@ -350,10 +351,9 @@ def save_anima_model_on_epoch_end_or_stepwise( dit: anima_models.MiniTrainDIT, ): """Save Anima model at epoch end or specific steps.""" + def sd_saver(ckpt_file, epoch_no, global_step): - sai_metadata = train_util.get_sai_model_spec( - None, args, False, False, False, is_stable_diffusion_ckpt=True - ) + sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True) dit_sd = dit.state_dict() anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype) @@ -410,9 +410,7 @@ def do_sample( generator = torch.manual_seed(seed) else: generator = None - noise = torch.randn( - latent.size(), dtype=torch.float32, generator=generator, device="cpu" - ).to(dtype).to(device) + noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device) # Timestep schedule: linear from 1.0 to 0.0 sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype) @@ -512,10 +510,20 @@ def sample_images( with torch.no_grad(), accelerator.autocast(): for prompt_dict in prompts: _sample_image_inference( - accelerator, args, dit, text_encoder, vae, vae_scale, - tokenize_strategy, text_encoding_strategy, - save_dir, prompt_dict, epoch, steps, - sample_prompts_te_outputs, prompt_replacement, + accelerator, + args, + dit, + text_encoder, + vae, + vae_scale, + tokenize_strategy, + text_encoding_strategy, + save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, ) # Restore RNG state @@ -527,10 +535,20 @@ def sample_images( def _sample_image_inference( - accelerator, args, dit, text_encoder, vae, vae_scale, - tokenize_strategy, text_encoding_strategy, - save_dir, prompt_dict, epoch, steps, - sample_prompts_te_outputs, prompt_replacement, + accelerator, + args, + dit, + text_encoder, + vae, + vae_scale, + tokenize_strategy, + text_encoding_strategy, + 
save_dir, + prompt_dict, + epoch, + steps, + sample_prompts_te_outputs, + prompt_replacement, ): """Generate a single sample image.""" prompt = prompt_dict.get("prompt", "") @@ -585,7 +603,7 @@ def _sample_image_inference( t5_attn_mask = t5_attn_mask.to(accelerator.device) # Process through LLM adapter if available - if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'): + if dit.use_llm_adapter and hasattr(dit, "llm_adapter"): crossattn_emb = dit.llm_adapter( source_hidden_states=prompt_embeds, target_input_ids=t5_input_ids, @@ -613,7 +631,7 @@ def _sample_image_inference( neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long) neg_t5_am = neg_t5_am.to(accelerator.device) - if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'): + if dit.use_llm_adapter and hasattr(dit, "llm_adapter"): neg_crossattn_emb = dit.llm_adapter( source_hidden_states=neg_pe, target_input_ids=neg_t5_ids, @@ -627,9 +645,16 @@ def _sample_image_inference( # Generate sample clean_memory_on_device(accelerator.device) latents = do_sample( - height, width, seed, dit, crossattn_emb, - sample_steps, dit.t_embedding_norm.weight.dtype, - accelerator.device, scale, neg_crossattn_emb, + height, + width, + seed, + dit, + crossattn_emb, + sample_steps, + dit.t_embedding_norm.weight.dtype, + accelerator.device, + scale, + neg_crossattn_emb, ) # Decode latents @@ -662,4 +687,5 @@ def _sample_image_inference( if "wandb" in [tracker.name for tracker in accelerator.trackers]: wandb_tracker = accelerator.get_tracker("wandb") import wandb + wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) diff --git a/library/anima_utils.py b/library/anima_utils.py index 8c171e0e..430c20c2 100644 --- a/library/anima_utils.py +++ b/library/anima_utils.py @@ -6,7 +6,10 @@ import torch import torch.nn as nn from safetensors.torch import load_file, save_file from accelerate.utils import set_module_tensor_to_device # kept for potential future use +from accelerate import 
init_empty_weights +from library.fp8_optimization_utils import apply_fp8_monkey_patch +from library.lora_utils import load_safetensors_with_lora_and_fp8 from .utils import setup_logging setup_logging() @@ -18,7 +21,7 @@ from library import anima_models # Keys that should stay in high precision (float32/bfloat16, not quantized) -KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer'] +KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"] def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]: @@ -53,6 +56,7 @@ def load_anima_dit( logger.info(f"Loading Anima DiT from {dit_path}") if disable_mmap: from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap + state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True) else: state_dict = load_file(dit_path, device="cpu") @@ -60,8 +64,8 @@ def load_anima_dit( # Remove 'net.' 
prefix if present new_state_dict = {} for k, v in state_dict.items(): - if k.startswith('net.'): - k = k[len('net.'):] + if k.startswith("net."): + k = k[len("net.") :] new_state_dict[k] = v state_dict = new_state_dict @@ -71,18 +75,20 @@ def load_anima_dit( # Detect LLM adapter if llm_adapter_path is not None: use_llm_adapter = True - dit_config['use_llm_adapter'] = True + dit_config["use_llm_adapter"] = True llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu") - elif 'llm_adapter.out_proj.weight' in state_dict: + elif "llm_adapter.out_proj.weight" in state_dict: use_llm_adapter = True - dit_config['use_llm_adapter'] = True + dit_config["use_llm_adapter"] = True llm_adapter_state_dict = None # Loaded as part of DiT else: use_llm_adapter = False llm_adapter_state_dict = None - logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, " - f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}") + logger.info( + f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, " + f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}" + ) # Build model normally on CPU — buffers get proper values from __init__ dit = anima_models.MiniTrainDIT(**dit_config) @@ -96,9 +102,11 @@ def load_anima_dit( missing, unexpected = dit.load_state_dict(state_dict, strict=False) if missing: # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint) - unexpected_missing = [k for k in missing if not any( - buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq') - )] + unexpected_missing = [ + k + for k in missing + if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq")) + ] if unexpected_missing: logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' 
if len(unexpected_missing) > 10 else ''}") if unexpected: @@ -106,9 +114,7 @@ def load_anima_dit( # Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest) for name, p in dit.named_parameters(): - dtype_to_use = dtype if ( - any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1 - ) else transformer_dtype + dtype_to_use = dtype if (any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1) else transformer_dtype p.data = p.data.to(dtype=dtype_to_use) dit.to(device) @@ -116,6 +122,127 @@ def load_anima_dit( return dit +FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""] +FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer"] + + +def load_anima_model( + device: Union[str, torch.device], + dit_path: str, + attn_mode: str, + split_attn: bool, + loading_device: Union[str, torch.device], + dit_weight_dtype: Optional[torch.dtype], + fp8_scaled: bool = False, + lora_weights_list: Optional[Dict[str, torch.Tensor]] = None, + lora_multipliers: Optional[list[float]] = None, +) -> anima_models.Anima: + """ + Load a HunyuanImage model from the specified checkpoint. + + Args: + device (Union[str, torch.device]): Device for optimization or merging + dit_path (str): Path to the DiT model checkpoint. + attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc. + split_attn (bool): Whether to use split attention. + loading_device (Union[str, torch.device]): Device to load the model weights on. + dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights. + If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype. + fp8_scaled (bool): Whether to use fp8 scaling for the model weights. + lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any. + lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any. 
+ """ + # dit_weight_dtype is None for fp8_scaled + assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None) + + device = torch.device(device) + loading_device = torch.device(loading_device) + + # We currently support fixed DiT config for Anima models + dit_config = { + "max_img_h": 512, + "max_img_w": 512, + "max_frames": 128, + "in_channels": 16, + "out_channels": 16, + "patch_spatial": 2, + "patch_temporal": 1, + "model_channels": 2048, + "concat_padding_mask": True, + "crossattn_emb_channels": 1024, + "pos_emb_cls": "rope3d", + "pos_emb_learnable": True, + "pos_emb_interpolation": "crop", + "min_fps": 1, + "max_fps": 30, + "use_adaln_lora": True, + "adaln_lora_dim": 256, + "num_blocks": 28, + "num_heads": 16, + "extra_per_block_abs_pos_emb": False, + "rope_h_extrapolation_ratio": 4.0, + "rope_w_extrapolation_ratio": 4.0, + "rope_t_extrapolation_ratio": 1.0, + "extra_h_extrapolation_ratio": 1.0, + "extra_w_extrapolation_ratio": 1.0, + "extra_t_extrapolation_ratio": 1.0, + "rope_enable_fps_modulation": False, + "use_llm_adapter": True, + "attn_mode": attn_mode, + "split_attn": split_attn, + } + # model = create_model(attn_mode, split_attn, dit_weight_dtype) + with init_empty_weights(): + model = anima_models.Anima(dit_config) + if dit_weight_dtype is not None: + model.to(dit_weight_dtype) + + # load model weights with dynamic fp8 optimization and LoRA merging if needed + logger.info(f"Loading DiT model from {dit_path}, device={loading_device}") + + sd = load_safetensors_with_lora_and_fp8( + model_files=dit_path, + lora_weights_list=lora_weights_list, + lora_multipliers=lora_multipliers, + fp8_optimization=fp8_scaled, + calc_device=device, + move_to_device=(loading_device == device), + dit_weight_dtype=dit_weight_dtype, + target_keys=FP8_OPTIMIZATION_TARGET_KEYS, + exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS, + ) + + if fp8_scaled: + apply_fp8_monkey_patch(model, sd, use_scaled_mm=False) + + if loading_device.type != 
"cpu": + # make sure all the model weights are on the loading_device + logger.info(f"Moving weights to {loading_device}") + for key in sd.keys(): + sd[key] = sd[key].to(loading_device) + + missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) + if missing: + # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint) + unexpected_missing = [ + k + for k in missing + if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq")) + ] + if unexpected_missing: + # Raise error to avoid silent failures + raise RuntimeError( + f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}" + ) + missing = {} # all missing keys were expected + if unexpected: + # Raise error to avoid silent failures + raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}") + + return model + + def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"): """Load WanVAE from a safetensors/pth file. 
@@ -139,14 +266,14 @@ def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: st from library.anima_vae import WanVAE_ # Build model - with torch.device('meta'): + with torch.device("meta"): vae = WanVAE_(**vae_config) # Load state dict - if vae_path.endswith('.safetensors'): - vae_sd = load_file(vae_path, device='cpu') + if vae_path.endswith(".safetensors"): + vae_sd = load_file(vae_path, device="cpu") else: - vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True) + vae_sd = torch.load(vae_path, map_location="cpu", weights_only=True) vae.load_state_dict(vae_sd, assign=True) vae = vae.eval().requires_grad_(False).to(device, dtype=dtype) @@ -175,7 +302,7 @@ def load_qwen3_tokenizer(qwen3_path: str): if os.path.isdir(qwen3_path): tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) else: - config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b') + config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b") if not os.path.exists(config_dir): raise FileNotFoundError( f"Qwen3 config directory not found at {config_dir}. 
" @@ -209,12 +336,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16 if os.path.isdir(qwen3_path): # Directory with full model tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True) - model = transformers.AutoModelForCausalLM.from_pretrained( - qwen3_path, torch_dtype=dtype, local_files_only=True - ).model + model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model else: # Single safetensors file - use configs/qwen3_06b/ for config - config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b') + config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b") if not os.path.exists(config_dir): raise FileNotFoundError( f"Qwen3 config directory not found at {config_dir}. " @@ -227,16 +352,16 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16 model = transformers.Qwen3ForCausalLM(qwen3_config).model # Load weights - if qwen3_path.endswith('.safetensors'): - state_dict = load_file(qwen3_path, device='cpu') + if qwen3_path.endswith(".safetensors"): + state_dict = load_file(qwen3_path, device="cpu") else: - state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True) + state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True) # Remove 'model.' 
prefix if present new_sd = {} for k, v in state_dict.items(): - if k.startswith('model.'): - new_sd[k[len('model.'):]] = v + if k.startswith("model."): + new_sd[k[len("model.") :]] = v else: new_sd[k] = v @@ -265,11 +390,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None): return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True) # Use bundled config - config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old') + config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old") if os.path.exists(config_dir): return T5TokenizerFast( - vocab_file=os.path.join(config_dir, 'spiece.model'), - tokenizer_file=os.path.join(config_dir, 'tokenizer.json'), + vocab_file=os.path.join(config_dir, "spiece.model"), + tokenizer_file=os.path.join(config_dir, "tokenizer.json"), ) raise FileNotFoundError( @@ -291,9 +416,9 @@ def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dt for k, v in dit_state_dict.items(): if dtype is not None: v = v.to(dtype) - prefixed_sd['net.' + k] = v.contiguous() + prefixed_sd["net." 
+ k] = v.contiguous() - save_file(prefixed_sd, save_path, metadata={'format': 'pt'}) + save_file(prefixed_sd, save_path, metadata={"format": "pt"}) logger.info(f"Saved Anima model to {save_path}") diff --git a/library/anima_vae.py b/library/anima_vae.py index 872bdfa2..3f6c7d1b 100644 --- a/library/anima_vae.py +++ b/library/anima_vae.py @@ -16,8 +16,7 @@ class CausalConv3d(nn.Conv3d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) self.padding = (0, 0, 0) def forward(self, x, cache_x=None): @@ -41,12 +40,10 @@ class RMS_norm(nn.Module): self.channel_first = channel_first self.scale = dim**0.5 self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias class Upsample(nn.Upsample): @@ -61,65 +58,48 @@ class Upsample(nn.Upsample): class Resample(nn.Module): def __init__(self, dim, mode): - assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', - 'downsample3d') + assert mode in ("none", "upsample2d", "upsample3d", "downsample2d", "downsample3d") super().__init__() self.dim = dim self.mode = mode # layers - if mode == 'upsample2d': + if mode == "upsample2d": self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - elif mode == 'upsample3d': + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + elif mode == "upsample3d": self.resample = nn.Sequential( - 
Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - self.time_conv = CausalConv3d( - dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - elif mode == 'downsample2d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == 'downsample3d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d( - dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) else: self.resample = nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): b, c, t, h, w = x.size() - if self.mode == 'upsample3d': + if self.mode == "upsample3d": if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = 'Rep' + feat_cache[idx] = "Rep" feat_idx[0] += 1 else: cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) - if feat_cache[idx] == 'Rep': + cache_x = 
torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": x = self.time_conv(x) else: x = self.time_conv(x, feat_cache[idx]) @@ -127,15 +107,14 @@ class Resample(nn.Module): feat_idx[0] += 1 x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), - 3) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) x = x.reshape(b, c, t * 2, h, w) t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.resample(x) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - if self.mode == 'downsample3d': + if self.mode == "downsample3d": if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: @@ -144,8 +123,7 @@ class Resample(nn.Module): else: cache_x = x[:, :, -1:, :, :].clone() - x = self.time_conv( - torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x feat_idx[0] += 1 return x @@ -166,8 +144,8 @@ class Resample(nn.Module): nn.init.zeros_(conv_weight) c1, c2, t, h, w = conv_weight.size() init_matrix = torch.eye(c1 // 2, c2) - conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix conv.weight.data.copy_(conv_weight) nn.init.zeros_(conv.bias.data) @@ -181,12 +159,15 @@ class ResidualBlock(nn.Module): # layers self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), nn.SiLU(), + RMS_norm(in_dim, images=False), + nn.SiLU(), CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), nn.SiLU(), 
nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1)) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ - if in_dim != out_dim else nn.Identity() + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() def forward(self, x, feat_cache=None, feat_idx=[0]): h = self.shortcut(x) @@ -196,11 +177,7 @@ class ResidualBlock(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -229,13 +206,10 @@ class AttentionBlock(nn.Module): def forward(self, x): identity = x b, c, t, h, w = x.size() - x = rearrange(x, 'b c t h w -> (b t) c h w') + x = rearrange(x, "b c t h w -> (b t) c h w") x = self.norm(x) # compute query, key, value - q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, - -1).permute(0, 1, 3, - 2).contiguous().chunk( - 3, dim=-1) + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) # apply attention x = F.scaled_dot_product_attention( @@ -247,20 +221,22 @@ class AttentionBlock(nn.Module): # output x = self.proj(x) - x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) return x + identity class Encoder3d(nn.Module): - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, 
True, False], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -288,21 +264,18 @@ class Encoder3d(nn.Module): # downsample block if i != len(dim_mult) - 1: - mode = 'downsample3d' if temperal_downsample[ - i] else 'downsample2d' + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" downsamples.append(Resample(out_dim, mode=mode)) scale /= 2.0 self.downsamples = nn.Sequential(*downsamples) # middle blocks self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), - ResidualBlock(out_dim, out_dim, dropout)) + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), ResidualBlock(out_dim, out_dim, dropout) + ) # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, z_dim, 3, padding=1)) + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): if feat_cache is not None: @@ -310,11 +283,7 @@ class Encoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -342,11 +311,7 @@ class Encoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] 
+= 1 @@ -357,14 +322,16 @@ class Encoder3d(nn.Module): class Decoder3d(nn.Module): - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -375,15 +342,15 @@ class Decoder3d(nn.Module): # dimensions dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2**(len(dim_mult) - 2) + scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) # middle blocks self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout)) + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), ResidualBlock(dims[0], dims[0], dropout) + ) # upsample blocks upsamples = [] @@ -399,15 +366,13 @@ class Decoder3d(nn.Module): # upsample block if i != len(dim_mult) - 1: - mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + mode = "upsample3d" if temperal_upsample[i] else "upsample2d" upsamples.append(Resample(out_dim, mode=mode)) scale *= 2.0 self.upsamples = nn.Sequential(*upsamples) # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) + self.head = nn.Sequential(RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, 3, 3, padding=1)) def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 @@ -416,11 +381,7 @@ class Decoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = 
torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -448,11 +409,7 @@ class Decoder3d(nn.Module): cache_x = x[:, :, -CACHE_T:, :, :].clone() if cache_x.shape[2] < 2 and feat_cache[idx] is not None: # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -471,14 +428,16 @@ def count_conv3d(model): class WanVAE_(nn.Module): - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): super().__init__() self.dim = dim self.z_dim = z_dim @@ -489,12 +448,10 @@ class WanVAE_(nn.Module): self.temperal_upsample = temperal_downsample[::-1] # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, - attn_scales, self.temperal_downsample, dropout) + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, - attn_scales, self.temperal_upsample, dropout) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def forward(self, x): mu, log_var = self.encode(x) @@ -510,20 +467,15 @@ class WanVAE_(nn.Module): for i in range(iter_): self._enc_conv_idx = [0] if i == 0: - out = self.encoder( - x[:, :, :1, :, :], - 
feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx + ) out = torch.cat([out, out_], 2) mu, log_var = self.conv1(out).chunk(2, dim=1) if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) else: mu = (mu - scale[0]) * scale[1] self.clear_cache() @@ -533,8 +485,7 @@ class WanVAE_(nn.Module): self.clear_cache() # z: [b,c,t,h,w] if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( - 1, self.z_dim, 1, 1, 1) + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) else: z = z / scale[1] + scale[0] iter_ = z.shape[2] @@ -542,15 +493,9 @@ class WanVAE_(nn.Module): for i in range(iter_): self._conv_idx = [0] if i == 0: - out = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) else: - out_ = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) self.clear_cache() return out @@ -571,7 +516,7 @@ class WanVAE_(nn.Module): self._conv_num = count_conv3d(self.decoder) self._conv_idx = [0] self._feat_map = [None] * self._conv_num - #cache encode + # cache encode self._enc_conv_num = count_conv3d(self.encoder) self._enc_conv_idx = [0] self._enc_feat_map = [None] * 
self._enc_conv_num diff --git a/library/qwen_image_autoencoder_kl.py b/library/qwen_image_autoencoder_kl.py new file mode 100644 index 00000000..61fc7550 --- /dev/null +++ b/library/qwen_image_autoencoder_kl.py @@ -0,0 +1,1452 @@ +# Copied and modified from Diffusers (via Musubi-Tuner). Original copyright notice follows. + +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. 
+# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - arXiv: https://arxiv.org/abs/2503.20314 + +import json +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import logging + +from library.safetensors_utils import load_safetensors + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +CACHE_T = 2 + +SCALE_FACTOR = 8 # VAE downsampling factor + + +# region diffusers-vae + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + if generator is not None and generator.device.type != self.parameters.device.type: + rand_device = generator.device + else: + rand_device = self.parameters.device + sample = torch.randn(self.mean.shape, generator=generator, device=rand_device, dtype=self.parameters.dtype).to( + self.parameters.device + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + 
self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +# endregion diffusers-vae + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. 
+ + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
+ """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 
4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently." 
+ super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = nn.SiLU() # get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. 
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + input_channels (int): Number of input channels. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + input_channels: int = 3, + non_linearity: str = "silu", + ): + super().__init__() + assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently." 
+ self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = nn.SiLU() # get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) 
+ x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. 
+ + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + output_channels (int): Number of output channels. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + output_channels: int = 3, + non_linearity: str = "silu", + ): + super().__init__() + assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently." 
+ self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = nn.SiLU() # get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, output_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) 
+ + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = False + + # @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + input_channels: int = 3, + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + self.latents_mean = latents_mean + self.latents_std = latents_std + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 
2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) if self.decoder is not None else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) if self.encoder is not None else 0, + } + + @property + def dtype(self): + return self.encoder.parameters().__next__().dtype + + @property + def device(self): + return self.encoder.parameters().__next__().device + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: 
Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. 
If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + # @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[Dict[str, torch.Tensor], Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a dictionary is returned, otherwise a plain `tuple` is returned. 
+ """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return {"latent_dist": posterior} + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return {"sample": out} + + # @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Dict[str, torch.Tensor], torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice)["sample"] for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z)["sample"] + + if not return_dict: + return (decoded,) + return {"sample": decoded} + + def decode_to_pixels(self, latents: torch.Tensor) -> torch.Tensor: + vae_scale_factor = 2 ** len(self.temperal_downsample) + # latents = qwen_image_utils.unpack_latents(latent, height, width, vae_scale_factor) + latents = latents.to(self.dtype) + latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + image = self.decode(latents, return_dict=False)[0][:, :, 0] # -1 to 1 + # return (image * 0.5 + 0.5).clamp(0.0, 1.0) # Convert to [0, 1] range + return image.clamp(-1.0, 1.0) + + def encode_pixels_to_latents(self, pixels: torch.Tensor) -> torch.Tensor: + """ + Convert pixel values to latents and apply normalization using mean/std. 
+ + Args: + pixels (torch.Tensor): Input pixels in [0, 1] range with shape [B, C, H, W] or [B, C, T, H, W] + + Returns: + torch.Tensor: Normalized latents + """ + # # Convert from [0, 1] to [-1, 1] range + # pixels = (pixels * 2.0 - 1.0).clamp(-1.0, 1.0) + + # Handle 2D input by adding temporal dimension + if pixels.dim() == 4: + pixels = pixels.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W] + + pixels = pixels.to(self.dtype) + + # Encode to latent space + posterior = self.encode(pixels, return_dict=False)[0] + latents = posterior.mode() # Use mode instead of sampling for deterministic results + # latents = posterior.sample() + + # Apply normalization using mean/std + latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * latents_std + + return latents + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. 
+ """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + 
result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Dict[str, torch.Tensor], torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a dictionary instead of a plain tuple. + + Returns: + `dict` or `tuple`: + If return_dict is True, a dictionary is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return {"sample": dec} + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[Dict[str, torch.Tensor], torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`Dict[str, torch.Tensor]`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec + + +# region utils + +# This region is not included in the original implementation. Added for musubi-tuner/sd-scripts. 
# Convert ComfyUI keys to standard keys if necessary
def convert_comfyui_state_dict(sd):
    """Map ComfyUI-format VAE state-dict keys to the official (diffusers-style) names.

    Detection: ComfyUI checkpoints contain a top-level ``conv1.bias`` (the official
    layout names it ``quant_conv.bias``). When absent, `sd` is returned unchanged.
    Only the module-path prefix (everything before the final ``.``) is remapped;
    the parameter suffix (``weight`` / ``bias``) is preserved.
    """
    if "conv1.bias" not in sd:
        return sd

    # Key mapping from ComfyUI VAE to official VAE, auto-generated by a script
    key_map = {
        "conv1": "quant_conv",
        "conv2": "post_quant_conv",
        "decoder.conv1": "decoder.conv_in",
        "decoder.head.0": "decoder.norm_out",
        "decoder.head.2": "decoder.conv_out",
        "decoder.middle.0.residual.0": "decoder.mid_block.resnets.0.norm1",
        "decoder.middle.0.residual.2": "decoder.mid_block.resnets.0.conv1",
        "decoder.middle.0.residual.3": "decoder.mid_block.resnets.0.norm2",
        "decoder.middle.0.residual.6": "decoder.mid_block.resnets.0.conv2",
        "decoder.middle.1.norm": "decoder.mid_block.attentions.0.norm",
        "decoder.middle.1.proj": "decoder.mid_block.attentions.0.proj",
        "decoder.middle.1.to_qkv": "decoder.mid_block.attentions.0.to_qkv",
        "decoder.middle.2.residual.0": "decoder.mid_block.resnets.1.norm1",
        "decoder.middle.2.residual.2": "decoder.mid_block.resnets.1.conv1",
        "decoder.middle.2.residual.3": "decoder.mid_block.resnets.1.norm2",
        "decoder.middle.2.residual.6": "decoder.mid_block.resnets.1.conv2",
        "decoder.upsamples.0.residual.0": "decoder.up_blocks.0.resnets.0.norm1",
        "decoder.upsamples.0.residual.2": "decoder.up_blocks.0.resnets.0.conv1",
        "decoder.upsamples.0.residual.3": "decoder.up_blocks.0.resnets.0.norm2",
        "decoder.upsamples.0.residual.6": "decoder.up_blocks.0.resnets.0.conv2",
        "decoder.upsamples.1.residual.0": "decoder.up_blocks.0.resnets.1.norm1",
        "decoder.upsamples.1.residual.2": "decoder.up_blocks.0.resnets.1.conv1",
        "decoder.upsamples.1.residual.3": "decoder.up_blocks.0.resnets.1.norm2",
        "decoder.upsamples.1.residual.6": "decoder.up_blocks.0.resnets.1.conv2",
        "decoder.upsamples.10.residual.0": "decoder.up_blocks.2.resnets.2.norm1",
        "decoder.upsamples.10.residual.2": "decoder.up_blocks.2.resnets.2.conv1",
        "decoder.upsamples.10.residual.3": "decoder.up_blocks.2.resnets.2.norm2",
        "decoder.upsamples.10.residual.6": "decoder.up_blocks.2.resnets.2.conv2",
        "decoder.upsamples.11.resample.1": "decoder.up_blocks.2.upsamplers.0.resample.1",
        "decoder.upsamples.12.residual.0": "decoder.up_blocks.3.resnets.0.norm1",
        "decoder.upsamples.12.residual.2": "decoder.up_blocks.3.resnets.0.conv1",
        "decoder.upsamples.12.residual.3": "decoder.up_blocks.3.resnets.0.norm2",
        "decoder.upsamples.12.residual.6": "decoder.up_blocks.3.resnets.0.conv2",
        "decoder.upsamples.13.residual.0": "decoder.up_blocks.3.resnets.1.norm1",
        "decoder.upsamples.13.residual.2": "decoder.up_blocks.3.resnets.1.conv1",
        "decoder.upsamples.13.residual.3": "decoder.up_blocks.3.resnets.1.norm2",
        "decoder.upsamples.13.residual.6": "decoder.up_blocks.3.resnets.1.conv2",
        "decoder.upsamples.14.residual.0": "decoder.up_blocks.3.resnets.2.norm1",
        "decoder.upsamples.14.residual.2": "decoder.up_blocks.3.resnets.2.conv1",
        "decoder.upsamples.14.residual.3": "decoder.up_blocks.3.resnets.2.norm2",
        "decoder.upsamples.14.residual.6": "decoder.up_blocks.3.resnets.2.conv2",
        "decoder.upsamples.2.residual.0": "decoder.up_blocks.0.resnets.2.norm1",
        "decoder.upsamples.2.residual.2": "decoder.up_blocks.0.resnets.2.conv1",
        "decoder.upsamples.2.residual.3": "decoder.up_blocks.0.resnets.2.norm2",
        "decoder.upsamples.2.residual.6": "decoder.up_blocks.0.resnets.2.conv2",
        "decoder.upsamples.3.resample.1": "decoder.up_blocks.0.upsamplers.0.resample.1",
        "decoder.upsamples.3.time_conv": "decoder.up_blocks.0.upsamplers.0.time_conv",
        "decoder.upsamples.4.residual.0": "decoder.up_blocks.1.resnets.0.norm1",
        "decoder.upsamples.4.residual.2": "decoder.up_blocks.1.resnets.0.conv1",
        "decoder.upsamples.4.residual.3": "decoder.up_blocks.1.resnets.0.norm2",
        "decoder.upsamples.4.residual.6": "decoder.up_blocks.1.resnets.0.conv2",
        "decoder.upsamples.4.shortcut": "decoder.up_blocks.1.resnets.0.conv_shortcut",
        "decoder.upsamples.5.residual.0": "decoder.up_blocks.1.resnets.1.norm1",
        "decoder.upsamples.5.residual.2": "decoder.up_blocks.1.resnets.1.conv1",
        "decoder.upsamples.5.residual.3": "decoder.up_blocks.1.resnets.1.norm2",
        "decoder.upsamples.5.residual.6": "decoder.up_blocks.1.resnets.1.conv2",
        "decoder.upsamples.6.residual.0": "decoder.up_blocks.1.resnets.2.norm1",
        "decoder.upsamples.6.residual.2": "decoder.up_blocks.1.resnets.2.conv1",
        "decoder.upsamples.6.residual.3": "decoder.up_blocks.1.resnets.2.norm2",
        "decoder.upsamples.6.residual.6": "decoder.up_blocks.1.resnets.2.conv2",
        "decoder.upsamples.7.resample.1": "decoder.up_blocks.1.upsamplers.0.resample.1",
        "decoder.upsamples.7.time_conv": "decoder.up_blocks.1.upsamplers.0.time_conv",
        "decoder.upsamples.8.residual.0": "decoder.up_blocks.2.resnets.0.norm1",
        "decoder.upsamples.8.residual.2": "decoder.up_blocks.2.resnets.0.conv1",
        "decoder.upsamples.8.residual.3": "decoder.up_blocks.2.resnets.0.norm2",
        "decoder.upsamples.8.residual.6": "decoder.up_blocks.2.resnets.0.conv2",
        "decoder.upsamples.9.residual.0": "decoder.up_blocks.2.resnets.1.norm1",
        "decoder.upsamples.9.residual.2": "decoder.up_blocks.2.resnets.1.conv1",
        "decoder.upsamples.9.residual.3": "decoder.up_blocks.2.resnets.1.norm2",
        "decoder.upsamples.9.residual.6": "decoder.up_blocks.2.resnets.1.conv2",
        "encoder.conv1": "encoder.conv_in",
        "encoder.downsamples.0.residual.0": "encoder.down_blocks.0.norm1",
        "encoder.downsamples.0.residual.2": "encoder.down_blocks.0.conv1",
        "encoder.downsamples.0.residual.3": "encoder.down_blocks.0.norm2",
        "encoder.downsamples.0.residual.6": "encoder.down_blocks.0.conv2",
        "encoder.downsamples.1.residual.0": "encoder.down_blocks.1.norm1",
        "encoder.downsamples.1.residual.2": "encoder.down_blocks.1.conv1",
        "encoder.downsamples.1.residual.3": "encoder.down_blocks.1.norm2",
        "encoder.downsamples.1.residual.6": "encoder.down_blocks.1.conv2",
        "encoder.downsamples.10.residual.0": "encoder.down_blocks.10.norm1",
        "encoder.downsamples.10.residual.2": "encoder.down_blocks.10.conv1",
        "encoder.downsamples.10.residual.3": "encoder.down_blocks.10.norm2",
        "encoder.downsamples.10.residual.6": "encoder.down_blocks.10.conv2",
        "encoder.downsamples.2.resample.1": "encoder.down_blocks.2.resample.1",
        "encoder.downsamples.3.residual.0": "encoder.down_blocks.3.norm1",
        "encoder.downsamples.3.residual.2": "encoder.down_blocks.3.conv1",
        "encoder.downsamples.3.residual.3": "encoder.down_blocks.3.norm2",
        "encoder.downsamples.3.residual.6": "encoder.down_blocks.3.conv2",
        "encoder.downsamples.3.shortcut": "encoder.down_blocks.3.conv_shortcut",
        "encoder.downsamples.4.residual.0": "encoder.down_blocks.4.norm1",
        "encoder.downsamples.4.residual.2": "encoder.down_blocks.4.conv1",
        "encoder.downsamples.4.residual.3": "encoder.down_blocks.4.norm2",
        "encoder.downsamples.4.residual.6": "encoder.down_blocks.4.conv2",
        "encoder.downsamples.5.resample.1": "encoder.down_blocks.5.resample.1",
        "encoder.downsamples.5.time_conv": "encoder.down_blocks.5.time_conv",
        "encoder.downsamples.6.residual.0": "encoder.down_blocks.6.norm1",
        "encoder.downsamples.6.residual.2": "encoder.down_blocks.6.conv1",
        "encoder.downsamples.6.residual.3": "encoder.down_blocks.6.norm2",
        "encoder.downsamples.6.residual.6": "encoder.down_blocks.6.conv2",
        "encoder.downsamples.6.shortcut": "encoder.down_blocks.6.conv_shortcut",
        "encoder.downsamples.7.residual.0": "encoder.down_blocks.7.norm1",
        "encoder.downsamples.7.residual.2": "encoder.down_blocks.7.conv1",
        "encoder.downsamples.7.residual.3": "encoder.down_blocks.7.norm2",
        "encoder.downsamples.7.residual.6": "encoder.down_blocks.7.conv2",
        "encoder.downsamples.8.resample.1": "encoder.down_blocks.8.resample.1",
        "encoder.downsamples.8.time_conv": "encoder.down_blocks.8.time_conv",
        "encoder.downsamples.9.residual.0": "encoder.down_blocks.9.norm1",
        "encoder.downsamples.9.residual.2": "encoder.down_blocks.9.conv1",
        "encoder.downsamples.9.residual.3": "encoder.down_blocks.9.norm2",
        "encoder.downsamples.9.residual.6": "encoder.down_blocks.9.conv2",
        "encoder.head.0": "encoder.norm_out",
        "encoder.head.2": "encoder.conv_out",
        "encoder.middle.0.residual.0": "encoder.mid_block.resnets.0.norm1",
        "encoder.middle.0.residual.2": "encoder.mid_block.resnets.0.conv1",
        "encoder.middle.0.residual.3": "encoder.mid_block.resnets.0.norm2",
        "encoder.middle.0.residual.6": "encoder.mid_block.resnets.0.conv2",
        "encoder.middle.1.norm": "encoder.mid_block.attentions.0.norm",
        "encoder.middle.1.proj": "encoder.mid_block.attentions.0.proj",
        "encoder.middle.1.to_qkv": "encoder.mid_block.attentions.0.to_qkv",
        "encoder.middle.2.residual.0": "encoder.mid_block.resnets.1.norm1",
        "encoder.middle.2.residual.2": "encoder.mid_block.resnets.1.conv1",
        "encoder.middle.2.residual.3": "encoder.mid_block.resnets.1.norm2",
        "encoder.middle.2.residual.6": "encoder.mid_block.resnets.1.conv2",
    }

    new_state_dict = {}
    for key, value in sd.items():
        prefix = key.rsplit(".", 1)[0]
        if prefix in key_map:
            # BUGFIX-hardening: reconstruct the key from the mapped prefix plus the
            # original suffix instead of str.replace(), which substitutes every
            # occurrence of the prefix substring anywhere in the key.
            new_state_dict[key_map[prefix] + key[len(prefix):]] = value
        else:
            new_state_dict[key] = value

    logger.info("Converted ComfyUI AutoencoderKL state dict keys to official format")
    return new_state_dict


def load_vae(
    vae_path: str, input_channels: int = 3, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False
) -> "AutoencoderKLQwenImage":
    """Load the Qwen-Image VAE from a safetensors file.

    Builds the model from a hard-coded config (matching the upstream diffusers
    release), converts ComfyUI-format keys if needed, and loads weights strictly.

    Args:
        vae_path: Path to the VAE safetensors file.
        input_channels: Number of input channels for the encoder.
        device: Device to load weights onto.
        disable_mmap: Load without memory-mapping the file.

    Returns:
        The VAE moved to `device`.
    """
    # BUGFIX: `json` is not in this module's top-level imports; import locally to
    # avoid a NameError at call time.
    import json

    # NOTE: "temperal_downsample" is the (misspelled) upstream diffusers config
    # key — do not "fix" it, it is whitelisted in _typos.toml.
    VAE_CONFIG_JSON = """
{
    "_class_name": "AutoencoderKLQwenImage",
    "_diffusers_version": "0.34.0.dev0",
    "attn_scales": [],
    "base_dim": 96,
    "dim_mult": [
        1,
        2,
        4,
        4
    ],
    "dropout": 0.0,
    "latents_mean": [
        -0.7571,
        -0.7089,
        -0.9113,
        0.1075,
        -0.1745,
        0.9653,
        -0.1517,
        1.5508,
        0.4134,
        -0.0715,
        0.5517,
        -0.3632,
        -0.1922,
        -0.9497,
        0.2503,
        -0.2921
    ],
    "latents_std": [
        2.8184,
        1.4541,
        2.3275,
        2.6558,
        1.2196,
        1.7708,
        2.6052,
        2.0743,
        3.2687,
        2.1526,
        2.8652,
        1.5579,
        1.6382,
        1.1253,
        2.8251,
        1.916
    ],
    "num_res_blocks": 2,
    "temperal_downsample": [
        false,
        true,
        true
    ],
    "z_dim": 16
}
"""
    logger.info("Initializing VAE")
    config = json.loads(VAE_CONFIG_JSON)
    vae = AutoencoderKLQwenImage(
        base_dim=config["base_dim"],
        z_dim=config["z_dim"],
        dim_mult=config["dim_mult"],
        num_res_blocks=config["num_res_blocks"],
        attn_scales=config["attn_scales"],
        temperal_downsample=config["temperal_downsample"],
        dropout=config["dropout"],
        latents_mean=config["latents_mean"],
        latents_std=config["latents_std"],
        input_channels=input_channels,
    )

    logger.info(f"Loading VAE from {vae_path}")
    # NOTE(review): `load_safetensors` is not among the visible top-level imports
    # either — confirm it is imported elsewhere in this module.
    state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)

    # Convert ComfyUI VAE keys to official VAE keys
    state_dict = convert_comfyui_state_dict(state_dict)

    info = vae.load_state_dict(state_dict, strict=True, assign=True)
    logger.info(f"Loaded VAE: {info}")

    vae.to(device)
    return vae
qwen3_attn_mask = qwen3_encoding["attention_mask"] # Tokenize with T5 (for LLM Adapter target tokens) t5_encoding = self.t5_tokenizer.batch_encode_plus( - text, - return_tensors="pt", - truncation=True, - padding="max_length", - max_length=self.t5_max_length, + text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length ) t5_input_ids = t5_encoding["input_ids"] t5_attn_mask = t5_encoding["attention_mask"] @@ -84,23 +76,17 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): T5 tokens are passed through unchanged (only used by LLM Adapter). """ - def __init__( - self, - dropout_rate: float = 0.0, - ) -> None: - self.dropout_rate = dropout_rate + def __init__(self) -> None: + super().__init__() + # Cached unconditional embeddings (from encoding empty caption "") # Must be initialized via cache_uncond_embeddings() before text encoder is deleted self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden) - self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len) - self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len) - self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len) + self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len) + self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len) + self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len) - def cache_uncond_embeddings( - self, - tokenize_strategy: TokenizeStrategy, - models: List[Any], - ) -> None: + def cache_uncond_embeddings(self, tokenize_strategy: TokenizeStrategy, models: List[Any]) -> None: """Pre-encode empty caption "" and cache the unconditional embeddings. Must be called before the text encoder is deleted from GPU. 
@@ -110,7 +96,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...") tokens = tokenize_strategy.tokenize("") with torch.no_grad(): - uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False) + uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens) # Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste self._uncond_prompt_embeds = uncond_outputs[0].cpu() self._uncond_attn_mask = uncond_outputs[1].cpu() @@ -119,11 +105,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): logger.info(" Unconditional embeddings cached successfully") def encode_tokens( - self, - tokenize_strategy: TokenizeStrategy, - models: List[Any], - tokens: List[torch.Tensor], - enable_dropout: bool = True, + self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] ) -> List[torch.Tensor]: """Encode Qwen3 tokens and return embeddings + T5 token IDs. 
@@ -134,82 +116,19 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): Returns: [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] """ + # Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs() qwen3_text_encoder = models[0] qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens - # Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main) - batch_size = qwen3_input_ids.shape[0] - non_drop_indices = [] - for i in range(batch_size): - drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate) - if not drop: - non_drop_indices.append(i) + encoder_device = qwen3_text_encoder.device - encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device + qwen3_input_ids = qwen3_input_ids.to(encoder_device) + qwen3_attn_mask = qwen3_attn_mask.to(encoder_device) + outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask) + prompt_embeds = outputs.last_hidden_state - if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size: - # Only encode non-dropped items to save compute - nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device) - nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device) - elif len(non_drop_indices) == batch_size: - nd_input_ids = qwen3_input_ids.to(encoder_device) - nd_attn_mask = qwen3_attn_mask.to(encoder_device) - else: - nd_input_ids = None - nd_attn_mask = None - - if nd_input_ids is not None: - outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask) - nd_encoded_text = outputs.last_hidden_state - # Zero out padding positions - nd_encoded_text[~nd_attn_mask.bool()] = 0 - - # Build full batch: fill non-dropped with encoded, dropped with unconditional - if len(non_drop_indices) == batch_size: - prompt_embeds = nd_encoded_text - attn_mask = 
qwen3_attn_mask.to(encoder_device) - else: - # Get unconditional embeddings - if self._uncond_prompt_embeds is not None: - uncond_pe = self._uncond_prompt_embeds[0] - uncond_am = self._uncond_attn_mask[0] - uncond_t5_ids = self._uncond_t5_input_ids[0] - uncond_t5_am = self._uncond_t5_attn_mask[0] - else: - # Encode empty caption on-the-fly (text encoder still available) - uncond_tokens = tokenize_strategy.tokenize("") - uncond_ids = uncond_tokens[0].to(encoder_device) - uncond_mask = uncond_tokens[1].to(encoder_device) - uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask) - uncond_pe = uncond_out.last_hidden_state[0] - uncond_pe[~uncond_mask[0].bool()] = 0 - uncond_am = uncond_mask[0] - uncond_t5_ids = uncond_tokens[2][0] - uncond_t5_am = uncond_tokens[3][0] - - seq_len = qwen3_input_ids.shape[1] - hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1] - dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype - - prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype) - attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype) - - if len(non_drop_indices) > 0: - prompt_embeds[non_drop_indices] = nd_encoded_text - attn_mask[non_drop_indices] = nd_attn_mask - - # Fill dropped items with unconditional embeddings - t5_input_ids = t5_input_ids.clone() - t5_attn_mask = t5_attn_mask.clone() - drop_indices = [i for i in range(batch_size) if i not in non_drop_indices] - for i in drop_indices: - prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype) - attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype) - t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype) - t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype) - - return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] + return 
[prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask] def drop_cached_text_encoder_outputs( self, @@ -217,6 +136,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): attn_mask: torch.Tensor, t5_input_ids: torch.Tensor, t5_attn_mask: torch.Tensor, + caption_dropout_rates: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """Apply dropout to cached text encoder outputs. @@ -224,37 +144,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy): Replaces dropped items with pre-cached unconditional embeddings (from encoding "") to match diffusion-pipe-main behavior. """ - if prompt_embeds is not None and self.dropout_rate > 0.0: - # Clone to avoid in-place modification of cached tensors - prompt_embeds = prompt_embeds.clone() - if attn_mask is not None: - attn_mask = attn_mask.clone() - if t5_input_ids is not None: - t5_input_ids = t5_input_ids.clone() - if t5_attn_mask is not None: - t5_attn_mask = t5_attn_mask.clone() + if caption_dropout_rates is None or all(caption_dropout_rates == 0.0): + return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] - for i in range(prompt_embeds.shape[0]): - if random.random() < self.dropout_rate: - if self._uncond_prompt_embeds is not None: - # Use pre-cached unconditional embeddings - prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) - if attn_mask is not None: - attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype) - if t5_input_ids is not None: - t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype) - if t5_attn_mask is not None: - t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype) - else: - # Fallback: zero out (should not happen if cache_uncond_embeddings was called) - logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout") - prompt_embeds[i] = 
torch.zeros_like(prompt_embeds[i]) - if attn_mask is not None: - attn_mask[i] = torch.zeros_like(attn_mask[i]) - if t5_input_ids is not None: - t5_input_ids[i] = torch.zeros_like(t5_input_ids[i]) - if t5_attn_mask is not None: - t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i]) + assert self._uncond_prompt_embeds is not None, "Unconditional embeddings not cached, cannot apply caption dropout" + + # Clone to avoid in-place modification of cached tensors + prompt_embeds = prompt_embeds.clone() + if attn_mask is not None: + attn_mask = attn_mask.clone() + if t5_input_ids is not None: + t5_input_ids = t5_input_ids.clone() + if t5_attn_mask is not None: + t5_attn_mask = t5_attn_mask.clone() + + for i in range(prompt_embeds.shape[0]): + if random.random() < caption_dropout_rates[i].item(): + # Use pre-cached unconditional embeddings + prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype) + if attn_mask is not None: + attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype) + if t5_input_ids is not None: + t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype) + if t5_attn_mask is not None: + t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype) return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] @@ -297,6 +210,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): return False if "t5_attn_mask" not in npz: return False + if "caption_dropout_rate" not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -309,7 +224,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): attn_mask = data["attn_mask"] t5_input_ids = data["t5_input_ids"] t5_attn_mask = data["t5_attn_mask"] - return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask] + caption_dropout_rate = 
data["caption_dropout_rate"] + return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate] def cache_batch_outputs( self, @@ -344,6 +260,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): attn_mask_i = attn_mask[i] t5_input_ids_i = t5_input_ids[i] t5_attn_mask_i = t5_attn_mask[i] + caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32) if self.cache_to_disk: np.savez( @@ -352,9 +269,16 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): attn_mask=attn_mask_i, t5_input_ids=t5_input_ids_i, t5_attn_mask=t5_attn_mask_i, + caption_dropout_rate=caption_dropout_rate, ) else: - info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i) + info.text_encoder_outputs = ( + prompt_embeds_i, + attn_mask_i, + t5_input_ids_i, + t5_attn_mask_i, + caption_dropout_rate, + ) class AnimaLatentsCachingStrategy(LatentsCachingStrategy): @@ -374,18 +298,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy): return self.ANIMA_LATENTS_NPZ_SUFFIX def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: - return ( - os.path.splitext(absolute_path)[0] - + f"_{image_size[0]:04d}x{image_size[1]:04d}" - + self.ANIMA_LATENTS_NPZ_SUFFIX - ) + return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX - def is_disk_cached_latents_expected( - self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool - ): - return self._default_is_disk_cached_latents_expected( - 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True - ) + def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) def load_latents_from_disk( self, npz_path: str, 
bucket_reso: Tuple[int, int] diff --git a/library/train_util.py b/library/train_util.py index 6874076d..bd1b15c1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -179,12 +179,15 @@ def split_train_val( class ImageInfo: - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + def __init__( + self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0 + ) -> None: self.image_key: str = image_key self.num_repeats: int = num_repeats self.caption: str = caption self.is_reg: bool = is_reg self.absolute_path: str = absolute_path + self.caption_dropout_rate: float = caption_dropout_rate self.image_size: Tuple[int, int] = None self.resized_size: Tuple[int, int] = None self.bucket_reso: Tuple[int, int] = None @@ -197,7 +200,7 @@ class ImageInfo: ) self.cond_img_path: Optional[str] = None self.image: Optional[Image.Image] = None # optional, original PIL Image - self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs + self.text_encoder_outputs_npz: Optional[str] = None # filename. 
set in cache_text_encoder_outputs # new self.text_encoder_outputs: Optional[List[torch.Tensor]] = None @@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset): num_train_images += num_repeats * len(img_paths) for img_path, caption, size in zip(img_paths, captions, sizes): - info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) + info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate) info.resize_interpolation = ( subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation ) @@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset): if caption is None: caption = "" - image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path) + image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate) image_info.resize_interpolation = ( subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation ) diff --git a/networks/lora_anima.py b/networks/lora_anima.py index c375ead7..7f44786e 100644 --- a/networks/lora_anima.py +++ b/networks/lora_anima.py @@ -1,18 +1,17 @@ -# LoRA network module for Anima -import math +# LoRA network module for Anima +import ast import os +import re from typing import Dict, List, Optional, Tuple, Type, Union -import numpy as np import torch from library.utils import setup_logging +from networks.lora_flux import LoRAModule, LoRAInfModule -setup_logging() import logging +setup_logging() logger = logging.getLogger(__name__) -from networks.lora_flux import LoRAModule, LoRAInfModule - def create_network( multiplier: float, @@ -29,68 +28,28 @@ def create_network( if network_alpha is None: network_alpha = 1.0 - # type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim] - self_attn_dim = kwargs.get("self_attn_dim", None) - cross_attn_dim = kwargs.get("cross_attn_dim", None) - mlp_dim = kwargs.get("mlp_dim", None) - mod_dim = 
kwargs.get("mod_dim", None) - llm_adapter_dim = kwargs.get("llm_adapter_dim", None) - - if self_attn_dim is not None: - self_attn_dim = int(self_attn_dim) - if cross_attn_dim is not None: - cross_attn_dim = int(cross_attn_dim) - if mlp_dim is not None: - mlp_dim = int(mlp_dim) - if mod_dim is not None: - mod_dim = int(mod_dim) - if llm_adapter_dim is not None: - llm_adapter_dim = int(llm_adapter_dim) - - type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim] - if all([d is None for d in type_dims]): - type_dims = None - - # emb_dims: [x_embedder, t_embedder, final_layer] - emb_dims = kwargs.get("emb_dims", None) - if emb_dims is not None: - emb_dims = emb_dims.strip() - if emb_dims.startswith("[") and emb_dims.endswith("]"): - emb_dims = emb_dims[1:-1] - emb_dims = [int(d) for d in emb_dims.split(",")] - assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)" - - # block selection - def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: - if selection == "all": - return [True] * total_blocks - if selection == "none" or selection == "": - return [False] * total_blocks - - selected = [False] * total_blocks - ranges = selection.split(",") - for r in ranges: - if "-" in r: - start, end = map(str.strip, r.split("-")) - start, end = int(start), int(end) - assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end - for i in range(start, end + 1): - selected[i] = True - else: - index = int(r) - assert 0 <= index < total_blocks - selected[index] = True - return selected - - train_block_indices = kwargs.get("train_block_indices", None) - if train_block_indices is not None: - num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999 - train_block_indices = parse_block_selection(train_block_indices, num_blocks) - # train LLM adapter train_llm_adapter = kwargs.get("train_llm_adapter", False) if train_llm_adapter is not None: - train_llm_adapter 
= True if train_llm_adapter == "True" else False + train_llm_adapter = True if train_llm_adapter.lower() == "true" else False + + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + if not isinstance(exclude_patterns, list): + exclude_patterns = [exclude_patterns] + + # add default exclude patterns + exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*") + + # regular expression for module selection: exclude and include + include_patterns = kwargs.get("include_patterns", None) + if include_patterns is not None: + include_patterns = ast.literal_eval(include_patterns) + if not isinstance(include_patterns, list): + include_patterns = [include_patterns] # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) @@ -103,7 +62,7 @@ def create_network( # verbose verbose = kwargs.get("verbose", False) if verbose is not None: - verbose = True if verbose == "True" else False + verbose = True if verbose.lower() == "true" else False network = LoRANetwork( text_encoders, @@ -115,9 +74,8 @@ def create_network( rank_dropout=rank_dropout, module_dropout=module_dropout, train_llm_adapter=train_llm_adapter, - type_dims=type_dims, - emb_dims=emb_dims, - train_block_indices=train_block_indices, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, verbose=verbose, ) @@ -137,6 +95,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file + weights_sd = load_file(file) else: weights_sd = torch.load(file, map_location="cpu") @@ -173,8 +132,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh class LoRANetwork(torch.nn.Module): - # Target modules: DiT blocks - ANIMA_TARGET_REPLACE_MODULE = ["Block"] + # Target modules: DiT blocks, embedders, final 
layer. embedders and final layer are excluded by default. + ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"] # Target modules: LLM Adapter blocks ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"] # Target modules for text encoder (Qwen3) @@ -197,9 +156,8 @@ class LoRANetwork(torch.nn.Module): modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_llm_adapter: bool = False, - type_dims: Optional[List[int]] = None, - emb_dims: Optional[List[int]] = None, - train_block_indices: Optional[List[bool]] = None, + exclude_patterns: Optional[List[str]] = None, + include_patterns: Optional[List[str]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -210,21 +168,36 @@ class LoRANetwork(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout self.train_llm_adapter = train_llm_adapter - self.type_dims = type_dims - self.emb_dims = emb_dims - self.train_block_indices = train_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None self.loraplus_text_encoder_lr_ratio = None if modules_dim is not None: - logger.info(f"create LoRA network from weights") + logger.info("create LoRA network from weights") if self.emb_dims is None: self.emb_dims = [0] * 3 else: logger.info(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + + # compile regular expression if specified + def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]: + re_patterns = [] + if patterns is not None: + for pattern in patterns: + try: + re_pattern = re.compile(pattern) + except re.error as e: + logger.error(f"Invalid pattern '{pattern}': {e}") + continue + re_patterns.append(re_pattern) + return re_patterns + + exclude_re_patterns = str_to_re_patterns(exclude_patterns) + include_re_patterns = str_to_re_patterns(include_patterns) # create module instances def create_modules( @@ -232,15 +205,9 @@ class LoRANetwork(torch.nn.Module): text_encoder_idx: Optional[int], root_module: torch.nn.Module, target_replace_modules: List[str], - filter: Optional[str] = None, default_dim: Optional[int] = None, - include_conv2d_if_filter: bool = False, ) -> Tuple[List[LoRAModule], List[str]]: - prefix = ( - self.LORA_PREFIX_ANIMA - if is_unet - else self.LORA_PREFIX_TEXT_ENCODER - ) + prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER loras = [] skipped = [] @@ -255,14 +222,16 @@ class LoRANetwork(torch.nn.Module): is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) if is_linear or is_conv2d: - lora_name = prefix + "." + (name + "." if name else "") + child_name - lora_name = lora_name.replace(".", "_") + original_name = (name + "." 
if name else "") + child_name + lora_name = f"{prefix}.{original_name}".replace(".", "_") - force_incl_conv2d = False - if filter is not None: - if filter not in lora_name: - continue - force_incl_conv2d = include_conv2d_if_filter + # exclude/include filter + excluded = any(pattern.match(original_name) for pattern in exclude_re_patterns) + included = any(pattern.match(original_name) for pattern in include_re_patterns) + if excluded and not included: + if verbose: + logger.info(f"exclude: {original_name}") + continue dim = None alpha_val = None @@ -276,40 +245,6 @@ class LoRANetwork(torch.nn.Module): dim = default_dim if default_dim is not None else self.lora_dim alpha_val = self.alpha - if is_unet and type_dims is not None: - # type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim] - # Order matters: check most specific identifiers first to avoid mismatches. - identifier_order = [ - (4, ("llm_adapter",)), - (3, ("adaln_modulation",)), - (0, ("self_attn",)), - (1, ("cross_attn",)), - (2, ("mlp",)), - ] - for idx, ids in identifier_order: - d = type_dims[idx] - if d is not None and all(id_str in lora_name for id_str in ids): - dim = d # 0 means skip - break - - # block index filtering - if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name: - # Extract block index from lora_name: "lora_unet_blocks_0_self_attn..." 
- parts = lora_name.split("_") - for pi, part in enumerate(parts): - if part == "blocks" and pi + 1 < len(parts): - try: - block_index = int(parts[pi + 1]) - if not self.train_block_indices[block_index]: - dim = 0 - except (ValueError, IndexError): - pass - break - - elif force_incl_conv2d: - dim = default_dim if default_dim is not None else self.lora_dim - alpha_val = self.alpha - if dim is None or dim == 0: if is_linear or is_conv2d_1x1: skipped.append(lora_name) @@ -339,9 +274,7 @@ class LoRANetwork(torch.nn.Module): if text_encoder is None: continue logger.info(f"create LoRA for Text Encoder {i+1}:") - te_loras, te_skipped = create_modules( - False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE - ) + te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.") self.text_encoder_loras.extend(te_loras) skipped_te += te_skipped @@ -354,19 +287,6 @@ class LoRANetwork(torch.nn.Module): self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) - # emb_dims: [x_embedder, t_embedder, final_layer] - if self.emb_dims: - for filter_name, in_dim in zip( - ["x_embedder", "t_embedder", "final_layer"], - self.emb_dims, - ): - loras, _ = create_modules( - True, None, unet, None, - filter=filter_name, default_dim=in_dim, - include_conv2d_if_filter=(filter_name == "x_embedder"), - ) - self.unet_loras.extend(loras) - logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.") if verbose: for lora in self.unet_loras: @@ -396,6 +316,7 @@ class LoRANetwork(torch.nn.Module): def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file + weights_sd = load_file(file) else: weights_sd = torch.load(file, map_location="cpu") @@ -443,10 +364,10 @@ class LoRANetwork(torch.nn.Module): sd_for_lora = {} 
for key in weights_sd.keys(): if key.startswith(lora.lora_name): - sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key] + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] lora.merge_to(sd_for_lora, dtype, device) - logger.info(f"weights are merged") + logger.info("weights are merged") def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): self.loraplus_lr_ratio = loraplus_lr_ratio @@ -498,10 +419,7 @@ class LoRANetwork(torch.nn.Module): if self.text_encoder_loras: loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio - te1_loras = [ - lora for lora in self.text_encoder_loras - if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER) - ] + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)] if len(te1_loras) > 0: logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio) diff --git a/tests/test_anima_cache.py b/tests/manual_test_anima_cache.py similarity index 99% rename from tests/test_anima_cache.py rename to tests/manual_test_anima_cache.py index 1684eb53..880406f4 100644 --- a/tests/test_anima_cache.py +++ b/tests/manual_test_anima_cache.py @@ -2,7 +2,7 @@ Diagnostic script to test Anima latent & text encoder caching independently. Usage: - python test_anima_cache.py \ + python manual_test_anima_cache.py \ --image_dir /path/to/images \ --qwen3_path /path/to/qwen3 \ --vae_path /path/to/vae.safetensors \ diff --git a/tests/test_anima_real_training.py b/tests/manual_test_anima_real_training.py similarity index 100% rename from tests/test_anima_real_training.py rename to tests/manual_test_anima_real_training.py