From 209c02dbb6952e1006a625c2cdd653a91db25bd0 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Fri, 12 Sep 2025 21:40:42 +0900 Subject: [PATCH] feat: HunyuanImage LoRA training --- _typos.toml | 2 +- hunyuan_image_minimal_inference.py | 30 +++-- hunyuan_image_train_network.py | 164 ++++++++++++++++---------- library/attention.py | 46 +++++++- library/hunyuan_image_models.py | 27 ++++- library/hunyuan_image_modules.py | 56 ++++++--- library/hunyuan_image_text_encoder.py | 2 +- library/hunyuan_image_vae.py | 4 +- library/strategy_hunyuan_image.py | 49 ++++++-- library/train_util.py | 34 +++++- networks/lora_hunyuan_image.py | 13 +- train_network.py | 74 +++++++----- 12 files changed, 352 insertions(+), 149 deletions(-) diff --git a/_typos.toml b/_typos.toml index cc167eaa..362ba8a6 100644 --- a/_typos.toml +++ b/_typos.toml @@ -30,7 +30,7 @@ yos="yos" wn="wn" hime="hime" OT="OT" -byt5="byt5" +byt="byt" # [files] # # Extend the default list of files to check diff --git a/hunyuan_image_minimal_inference.py b/hunyuan_image_minimal_inference.py index ba8ca78e..3de0b1cd 100644 --- a/hunyuan_image_minimal_inference.py +++ b/hunyuan_image_minimal_inference.py @@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace: # inference parser.add_argument( - "--guidance_scale", type=float, default=4.0, help="Guidance scale for classifier free guidance. Default is 4.0." + "--guidance_scale", type=float, default=5.0, help="Guidance scale for classifier free guidance. Default is 5.0." ) parser.add_argument("--prompt", type=str, default=None, help="prompt for generation") parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string") @@ -508,7 +508,7 @@ def prepare_text_inputs( prompt = args.prompt cache_key = prompt if cache_key in conds_cache: - embed, mask = conds_cache[cache_key] + embed, mask, embed_byt5, mask_byt5, ocr_mask = conds_cache[cache_key] else: move_models_to_device_if_needed() @@ -527,7 +527,7 @@ def prepare_text_inputs( negative_prompt = args.negative_prompt cache_key = negative_prompt if cache_key in conds_cache: - negative_embed, negative_mask = conds_cache[cache_key] + negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5, negative_ocr_mask = conds_cache[cache_key] else: move_models_to_device_if_needed() @@ -614,9 +614,10 @@ def generate( shared_models["model"] = model else: # use shared model + logger.info("Using shared DiT model.") model: hunyuan_image_models.HYImageDiffusionTransformer = shared_models["model"] - # model.move_to_device_except_swap_blocks(device) # Handles block swap correctly - # model.prepare_block_swap_before_forward() + model.move_to_device_except_swap_blocks(device) # Handles block swap correctly + model.prepare_block_swap_before_forward() return generate_body(args, model, context, context_null, device, seed) @@ -678,9 +679,18 @@ def generate_body( # Denoising loop do_cfg = args.guidance_scale != 1.0 + # print(f"embed shape: {embed.shape}, mean: {embed.mean()}, std: {embed.std()}") + # print(f"embed_byt5 shape: {embed_byt5.shape}, mean: {embed_byt5.mean()}, std: {embed_byt5.std()}") + # print(f"negative_embed shape: {negative_embed.shape}, mean: {negative_embed.mean()}, std: {negative_embed.std()}") + # print(f"negative_embed_byt5 shape: {negative_embed_byt5.shape}, mean: {negative_embed_byt5.mean()}, std: {negative_embed_byt5.std()}") + # print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}") + # print(f"mask 
shape: {mask.shape}, sum: {mask.sum()}") + # print(f"mask_byt5 shape: {mask_byt5.shape}, sum: {mask_byt5.sum()}") + # print(f"negative_mask shape: {negative_mask.shape}, sum: {negative_mask.sum()}") + # print(f"negative_mask_byt5 shape: {negative_mask_byt5.shape}, sum: {negative_mask_byt5.sum()}") with tqdm(total=len(timesteps), desc="Denoising steps") as pbar: for i, t in enumerate(timesteps): - t_expand = t.expand(latents.shape[0]).to(latents.dtype) + t_expand = t.expand(latents.shape[0]).to(torch.int64) with torch.no_grad(): noise_pred = model(latents, t_expand, embed, mask, embed_byt5, mask_byt5) @@ -1040,6 +1050,9 @@ def process_interactive(args: argparse.Namespace) -> None: shared_models = load_shared_models(args) shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode + vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) + vae.eval() + print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):") try: @@ -1059,9 +1072,6 @@ def process_interactive(args: argparse.Namespace) -> None: def input_line(prompt: str) -> str: return input(prompt) - vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) - vae.eval() - try: while True: try: @@ -1088,7 +1098,7 @@ def process_interactive(args: argparse.Namespace) -> None: # Save latent and video # returned_vae from generate will be used for decoding here. - save_output(prompt_args, vae, latent[0], device) + save_output(prompt_args, vae, latent, device) except KeyboardInterrupt: print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)") diff --git a/hunyuan_image_train_network.py b/hunyuan_image_train_network.py index b1281fa0..291d5132 100644 --- a/hunyuan_image_train_network.py +++ b/hunyuan_image_train_network.py @@ -1,5 +1,6 @@ import argparse import copy +import gc from typing import Any, Optional, Union import argparse import os @@ -12,7 +13,7 @@ import torch.nn as nn from PIL import Image from accelerate import Accelerator, PartialState -from library import hunyuan_image_models, hunyuan_image_vae, strategy_base, train_util +from library import flux_utils, hunyuan_image_models, hunyuan_image_vae, strategy_base, train_util from library.device_utils import clean_memory_on_device, init_ipex init_ipex() @@ -24,7 +25,6 @@ from library import ( hunyuan_image_text_encoder, hunyuan_image_utils, hunyuan_image_vae, - sai_model_spec, sd3_train_utils, strategy_base, strategy_hunyuan_image, @@ -79,8 +79,6 @@ def sample_images( dit = accelerator.unwrap_model(dit) if text_encoders is not None: text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders] - if controlnet is not None: - controlnet = accelerator.unwrap_model(controlnet) # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders]) prompts = train_util.load_prompts(args.sample_prompts) @@ -162,10 +160,10 @@ def sample_image_inference( sample_steps = prompt_dict.get("sample_steps", 20) width = prompt_dict.get("width", 512) height = prompt_dict.get("height", 512) - cfg_scale = prompt_dict.get("scale", 1.0) + cfg_scale = prompt_dict.get("scale", 3.5) seed = prompt_dict.get("seed") prompt: str = prompt_dict.get("prompt", "") - flow_shift: float = prompt_dict.get("flow_shift", 4.0) + flow_shift: float = prompt_dict.get("flow_shift", 5.0) # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) if prompt_replacement is not None: @@ -208,11 +206,10 @@ def sample_image_inference( text_encoder_conds = [] if 
sample_prompts_te_outputs and prpt in sample_prompts_te_outputs: text_encoder_conds = sample_prompts_te_outputs[prpt] - print(f"Using cached text encoder outputs for prompt: {prpt}") + # print(f"Using cached text encoder outputs for prompt: {prpt}") if text_encoders is not None: - print(f"Encoding prompt: {prpt}") + # print(f"Encoding prompt: {prpt}") tokens_and_masks = tokenize_strategy.tokenize(prpt) - # strategy has apply_t5_attn_mask option encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) # if text_encoder_conds is not cached, use encoded_text_encoder_conds @@ -255,16 +252,21 @@ def sample_image_inference( from hunyuan_image_minimal_inference import generate_body # import here to avoid circular import - latents = generate_body(gen_args, dit, arg_c, arg_c_null, accelerator.device, seed) + dit_is_training = dit.training + dit.eval() + x = generate_body(gen_args, dit, arg_c, arg_c_null, accelerator.device, seed) + if dit_is_training: + dit.train() + clean_memory_on_device(accelerator.device) # latent to image - clean_memory_on_device(accelerator.device) org_vae_device = vae.device # will be on cpu vae.to(accelerator.device) # distributed_state.device is same as accelerator.device - with torch.autocast(accelerator.device.type, vae.dtype, enabled=True), torch.no_grad(): - x = x / hunyuan_image_vae.VAE_SCALE_FACTOR - x = vae.decode(x) + with torch.no_grad(): + x = x / vae.scaling_factor + x = vae.decode(x.to(vae.device, dtype=vae.dtype)) vae.to(org_vae_device) + clean_memory_on_device(accelerator.device) x = x.clamp(-1, 1) @@ -299,6 +301,7 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): super().__init__() self.sample_prompts_te_outputs = None self.is_swapping_blocks: bool = False + self.rotary_pos_emb_cache = {} def assert_extra_args( self, @@ -341,12 +344,42 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): def load_target_model(self, args, weight_dtype, accelerator): self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0 - # currently offload to cpu for some models + vl_dtype = torch.float8_e4m3fn if args.fp8_vl else torch.bfloat16 + vl_device = "cpu" + _, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors + ) + _, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors + ) + + vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + vae.to(dtype=torch.float16) # VAE is always fp16 + vae.eval() + if args.vae_enable_tiling: + vae.enable_tiling() + logger.info("VAE tiling is enabled") + + model_version = hunyuan_image_utils.MODEL_VERSION_2_1 + return model_version, [text_encoder_vlm, text_encoder_byt5], vae, None # unet will be loaded later + + def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]: + if args.cache_text_encoder_outputs: + logger.info("Replace text encoders with dummy models to save memory") + + # This doesn't free memory, so we move text encoders to meta device in cache_text_encoder_outputs_if_needed + text_encoders = [flux_utils.dummy_clip_l() for _ in text_encoders] + clean_memory_on_device(accelerator.device) + gc.collect() + loading_dtype = None if args.fp8_scaled else weight_dtype loading_device = "cpu" if self.is_swapping_blocks 
else accelerator.device split_attn = True attn_mode = "torch" + if args.xformers: + attn_mode = "xformers" + logger.info("xformers is enabled for attention") model = hunyuan_image_models.load_hunyuan_image_model( accelerator.device, @@ -363,19 +396,7 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}") model.enable_block_swap(args.blocks_to_swap, accelerator.device) - vl_dtype = torch.bfloat16 - vl_device = "cpu" - _, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl( - args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors - ) - _, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5( - args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors - ) - - vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors) - - model_version = hunyuan_image_utils.MODEL_VERSION_2_1 - return model_version, [text_encoder_vlm, text_encoder_byt5], vae, model + return model, text_encoders def get_tokenize_strategy(self, args): return strategy_hunyuan_image.HunyuanImageTokenizeStrategy(args.tokenizer_cache_dir) @@ -404,7 +425,6 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): def get_text_encoder_outputs_caching_strategy(self, args): if args.cache_text_encoder_outputs: - # if the text encoders is trained, we need tokenization, so is_partial is True return strategy_hunyuan_image.HunyuanImageTextEncoderOutputsCachingStrategy( args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False ) @@ -417,11 +437,9 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): if args.cache_text_encoder_outputs: if not args.lowram: # メモリ消費を減らす - logger.info("move vae and unet to cpu to save memory") + logger.info("move vae to cpu to save memory") org_vae_device = vae.device - org_unet_device = unet.device vae.to("cpu") - unet.to("cpu") clean_memory_on_device(accelerator.device) logger.info("move text encoders to gpu") @@ -457,17 +475,14 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): accelerator.wait_for_everyone() - # move back to cpu - logger.info("move VLM back to cpu") - text_encoders[0].to("cpu") - logger.info("move byT5 back to cpu") - text_encoders[1].to("cpu") + # text encoders are not needed for training, so we move to meta device + logger.info("move text encoders to meta device to save memory") + text_encoders = [te.to("meta") for te in text_encoders] clean_memory_on_device(accelerator.device) if not args.lowram: - logger.info("move vae and unet back to original device") + logger.info("move vae back to original device") vae.to(org_vae_device) - unet.to(org_unet_device) else: # Text Encoderから毎回出力を取得するので、GPUに乗せておく text_encoders[0].to(accelerator.device) @@ -477,21 +492,19 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): text_encoders = text_encoder # for compatibility text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders) - flux_train_utils.sample_images( - accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs - ) + sample_images(accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs) def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, 
shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) return noise_scheduler - def encode_images_to_latents(self, args, vae, images): - return vae.encode(images) + def encode_images_to_latents(self, args, vae: hunyuan_image_vae.HunyuanVAE2D, images): + return vae.encode(images).sample() def shift_scale_latents(self, args, latents): # for encoding, we need to scale the latents - return latents * hunyuan_image_vae.VAE_SCALE_FACTOR + return latents * hunyuan_image_vae.LATENT_SCALING_FACTOR def get_noise_pred_and_target( self, @@ -509,12 +522,16 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): ): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) - bsz = latents.shape[0] # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + noisy_model_input, _, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, accelerator.device, weight_dtype ) + # bfloat16 is too low precision for 0-1000 TODO fix get_noisy_model_input_and_timesteps + timesteps = (sigmas[:, 0, 0, 0] * 1000).to(torch.int64) + # print( + # f"timestep: {timesteps}, noisy_model_input shape: {noisy_model_input.shape}, mean: {noisy_model_input.mean()}, std: {noisy_model_input.std()}" + # ) if args.gradient_checkpointing: noisy_model_input.requires_grad_(True) @@ -526,31 +543,33 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): # ocr_mask is for inference only, so it is not used here vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = text_encoder_conds + # print(f"embed shape: {vlm_embed.shape}, mean: {vlm_embed.mean()}, std: {vlm_embed.std()}") + # print(f"embed_byt5 shape: {byt5_embed.shape}, mean: {byt5_embed.mean()}, std: {byt5_embed.std()}") + # print(f"latents shape: {latents.shape}, mean: {latents.mean()}, std: {latents.std()}") + # print(f"mask shape: {vlm_mask.shape}, sum: {vlm_mask.sum()}") + # print(f"mask_byt5 shape: {byt5_mask.shape}, sum: {byt5_mask.sum()}") with torch.set_grad_enabled(is_train), accelerator.autocast(): - model_pred = unet(noisy_model_input, timesteps / 1000, vlm_embed, vlm_mask, byt5_embed, byt5_mask) + model_pred = unet( + noisy_model_input, timesteps, vlm_embed, vlm_mask, byt5_embed, byt5_mask # , self.rotary_pos_emb_cache + ) - # model prediction and weighting is omitted for HunyuanImage-2.1 currently + # apply model prediction type + model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas) # flow matching loss target = noise - latents # differential output preservation is not used for HunyuanImage-2.1 currently - return model_pred, target, timesteps, None + return model_pred, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss def get_sai_model_spec(self, args): - # if self.model_type != "chroma": - # model_description = "schnell" if self.is_schnell else "dev" - # else: - # model_description = "chroma" - # return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description) - train_util.get_sai_model_spec_dataclass(None, args, False, True, False, hunyuan_image="2.1") + return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, hunyuan_image="2.1").to_metadata_dict() def update_metadata(self, metadata, args): - metadata["ss_model_type"] = args.model_type metadata["ss_logit_mean"] = args.logit_mean metadata["ss_logit_std"] = 
args.logit_std metadata["ss_mode_scale"] = args.mode_scale @@ -569,6 +588,9 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): def cast_text_encoder(self): return False # VLM is bf16, byT5 is fp16, so do not cast to other dtype + def cast_vae(self): + return False # VAE is fp16, so do not cast to other dtype + def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype): # fp8 text encoder for HunyuanImage-2.1 is not supported currently pass @@ -597,6 +619,17 @@ def setup_parser() -> argparse.ArgumentParser: parser = train_network.setup_parser() train_util.add_dit_training_arguments(parser) + parser.add_argument( + "--text_encoder", + type=str, + help="path to Qwen2.5-VL (*.sft or *.safetensors), should be bfloat16 / Qwen2.5-VLのパス(*.sftまたは*.safetensors)、bfloat16が前提", + ) + parser.add_argument( + "--byt5", + type=str, + help="path to byt5 (*.sft or *.safetensors), should be float16 / byt5のパス(*.sftまたは*.safetensors)、float16が前提", + ) + parser.add_argument( "--timestep_sampling", choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"], @@ -613,17 +646,24 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--model_prediction_type", choices=["raw", "additive", "sigma_scaled"], - default="sigma_scaled", + default="raw", help="How to interpret and process the model prediction: " - "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)." + "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling). Default is raw unlike FLUX.1." " / モデル予測の解釈と処理方法:" - "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。", + "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。デフォルトはFLUX.1とは異なりrawです。", ) parser.add_argument( "--discrete_flow_shift", type=float, - default=3.0, - help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", + default=5.0, + help="Discrete flow shift for the Euler Discrete Scheduler, default is 5.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは5.0。", + ) + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う") + parser.add_argument("--fp8_vl", action="store_true", help="use fp8 for VLM text encoder / VLMテキストエンコーダにfp8を使用する") + parser.add_argument( + "--vae_enable_tiling", + action="store_true", + help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする", ) return parser diff --git a/library/attention.py b/library/attention.py index 10a09614..f1e7c0b0 100644 --- a/library/attention.py +++ b/library/attention.py @@ -1,9 +1,19 @@ import torch -from typing import Optional +from typing import Optional, Union + +try: + import xformers.ops as xops +except ImportError: + xops = None def attention( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_lens: list[int], attn_mode: str = "torch", drop_rate: float = 0.0 + qkv_or_q: Union[torch.Tensor, list], + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + seq_lens: Optional[list[int]] = None, + attn_mode: str = "torch", + drop_rate: float = 0.0, ) -> torch.Tensor: """ Compute scaled dot-product attention with variable sequence lengths. @@ -12,7 +22,7 @@ def attention( processing each sequence individually. Args: - q: Query tensor [B, L, H, D]. + qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors. k: Key tensor [B, L, H, D]. v: Value tensor [B, L, H, D]. seq_lens: Valid sequence length for each batch element. 
@@ -22,6 +32,17 @@ def attention( Returns: Attention output tensor [B, L, H*D]. """ + if isinstance(qkv_or_q, list): + q, k, v = qkv_or_q + qkv_or_q.clear() + del qkv_or_q + else: + q = qkv_or_q + del qkv_or_q + assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor" + if seq_lens is None: + seq_lens = [q.shape[1]] * q.shape[0] + # Determine tensor layout based on attention implementation if attn_mode == "torch" or attn_mode == "sageattn": transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA @@ -29,6 +50,7 @@ def attention( transpose_fn = lambda x: x # [B, L, H, D] for other implementations # Process each batch element with its valid sequence length + q_seq_len = q.shape[1] q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))] k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))] v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))] @@ -40,10 +62,24 @@ def attention( q[i] = None k[i] = None v[i] = None - x.append(x_i) + x.append(torch.nn.functional.pad(x_i, (0, 0, 0, q_seq_len - x_i.shape[2]), value=0)) # Pad to max seq len, B, H, L, D x = torch.cat(x, dim=0) del q, k, v - # Currently only PyTorch SDPA is implemented + + elif attn_mode == "xformers": + x = [] + for i in range(len(q)): + x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) + q[i] = None + k[i] = None + v[i] = None + x.append(torch.nn.functional.pad(x_i, (0, 0, 0, 0, 0, q_seq_len - x_i.shape[1]), value=0)) # B, L, H, D + x = torch.cat(x, dim=0) + del q, k, v + + else: + # Currently only PyTorch SDPA and xformers are implemented + raise ValueError(f"Unsupported attention mode: {attn_mode}") x = transpose_fn(x) # [B, L, H, D] x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D] diff --git a/library/hunyuan_image_models.py b/library/hunyuan_image_models.py index 9e3a00e8..ce2d23dd 100644 --- a/library/hunyuan_image_models.py +++ b/library/hunyuan_image_models.py @@ -30,11 +30,7 @@ from library.hunyuan_image_modules import ( from library.hunyuan_image_utils import get_nd_rotary_pos_embed FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"] -FP8_OPTIMIZATION_EXCLUDE_KEYS = [ - "norm", - "_mod", - "modulation", -] +FP8_OPTIMIZATION_EXCLUDE_KEYS = ["norm", "_mod", "modulation", "_emb"] # region DiT Model @@ -142,6 +138,14 @@ class HYImageDiffusionTransformer(nn.Module): self.num_double_blocks = len(self.double_blocks) self.num_single_blocks = len(self.single_blocks) + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + def enable_gradient_checkpointing(self, cpu_offload: bool = False): self.gradient_checkpointing = True self.cpu_offload_checkpointing = cpu_offload @@ -273,6 +277,7 @@ class HYImageDiffusionTransformer(nn.Module): encoder_attention_mask: torch.Tensor, byt5_text_states: Optional[torch.Tensor] = None, byt5_text_mask: Optional[torch.Tensor] = None, + rotary_pos_emb_cache: Optional[Dict[Tuple[int, int], Tuple[torch.Tensor, torch.Tensor]]] = None, ) -> torch.Tensor: """ Forward pass through the HunyuanImage diffusion transformer. 
@@ -296,7 +301,15 @@ class HYImageDiffusionTransformer(nn.Module): # Calculate spatial dimensions for rotary position embeddings _, _, oh, ow = x.shape th, tw = oh, ow # Height and width (patch_size=[1,1] means no spatial downsampling) - freqs_cis = self.get_rotary_pos_embed((th, tw)) + if rotary_pos_emb_cache is not None: + if (th, tw) in rotary_pos_emb_cache: + freqs_cis = rotary_pos_emb_cache[(th, tw)] + freqs_cis = (freqs_cis[0].to(img.device), freqs_cis[1].to(img.device)) + else: + freqs_cis = self.get_rotary_pos_embed((th, tw)) + rotary_pos_emb_cache[(th, tw)] = (freqs_cis[0].cpu(), freqs_cis[1].cpu()) + else: + freqs_cis = self.get_rotary_pos_embed((th, tw)) # Reshape image latents to sequence format: [B, C, H, W] -> [B, H*W, C] img = self.img_in(img) @@ -349,9 +362,11 @@ class HYImageDiffusionTransformer(nn.Module): vec = vec.to(input_device) img = x[:, :img_seq_len, ...] + del x # Apply final projection to output space img = self.final_layer(img, vec) + del vec # Reshape from sequence to spatial format: [B, L, C] -> [B, C, H, W] img = self.unpatchify_2d(img, th, tw) diff --git a/library/hunyuan_image_modules.py b/library/hunyuan_image_modules.py index 633cd310..ef4d5e5d 100644 --- a/library/hunyuan_image_modules.py +++ b/library/hunyuan_image_modules.py @@ -50,7 +50,7 @@ class ByT5Mapper(nn.Module): Returns: Transformed embeddings [..., out_dim1]. """ - residual = x + residual = x if self.use_residual else None x = self.layernorm(x) x = self.fc1(x) x = self.act_fn(x) @@ -411,6 +411,7 @@ class SingleTokenRefiner(nn.Module): context_aware_representations = self.c_embedder(context_aware_representations) c = timestep_aware_representations + context_aware_representations + del timestep_aware_representations, context_aware_representations x = self.input_embedder(x) x = self.individual_token_refiner(x, c, txt_lens) return x @@ -447,6 +448,7 @@ class FinalLayer(nn.Module): def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) x = modulate(self.norm_final(x), shift=shift, scale=scale) + del shift, scale, c x = self.linear(x) return x @@ -494,6 +496,7 @@ class RMSNorm(nn.Module): Normalized and scaled tensor. 
""" output = self._norm(x.float()).type_as(x) + del x output = output * self.weight return output @@ -634,8 +637,10 @@ class MMDoubleStreamBlock(nn.Module): # Process image stream for attention img_modulated = self.img_norm1(img) img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) + del img_mod1_shift, img_mod1_scale img_qkv = self.img_attn_qkv(img_modulated) + del img_modulated img_q, img_k, img_v = img_qkv.chunk(3, dim=-1) del img_qkv @@ -649,17 +654,15 @@ class MMDoubleStreamBlock(nn.Module): # Apply rotary position embeddings to image tokens if freqs_cis is not None: - img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) - assert ( - img_qq.shape == img_q.shape and img_kk.shape == img_k.shape - ), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}" - img_q, img_k = img_qq, img_kk + img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + del freqs_cis # Process text stream for attention txt_modulated = self.txt_norm1(txt) txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) txt_qkv = self.txt_attn_qkv(txt_modulated) + del txt_modulated txt_q, txt_k, txt_v = txt_qkv.chunk(3, dim=-1) del txt_qkv @@ -672,31 +675,44 @@ class MMDoubleStreamBlock(nn.Module): txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) # Concatenate image and text tokens for joint attention + img_seq_len = img.shape[1] q = torch.cat([img_q, txt_q], dim=1) + del img_q, txt_q k = torch.cat([img_k, txt_k], dim=1) + del img_k, txt_k v = torch.cat([img_v, txt_v], dim=1) - attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode) + del img_v, txt_v + + qkv = [q, k, v] + del q, k, v + attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode) + del qkv # Split attention outputs back to separate streams - img_attn, txt_attn = (attn[:, : img_q.shape[1]].contiguous(), attn[:, img_q.shape[1] :].contiguous()) + img_attn, txt_attn = (attn[:, : img_seq_len].contiguous(), attn[:, img_seq_len :].contiguous()) + del attn # Apply attention projection and residual connection for image stream img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + del img_attn, img_mod1_gate # Apply MLP and residual connection for image stream img = img + apply_gate( self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), gate=img_mod2_gate, ) + del img_mod2_shift, img_mod2_scale, img_mod2_gate # Apply attention projection and residual connection for text stream txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + del txt_attn, txt_mod1_gate # Apply MLP and residual connection for text stream txt = txt + apply_gate( self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), gate=txt_mod2_gate, ) + del txt_mod2_shift, txt_mod2_scale, txt_mod2_gate return img, txt @@ -797,6 +813,7 @@ class MMSingleStreamBlock(nn.Module): # Compute Q, K, V, and MLP input qkv_mlp = self.linear1(x_mod) + del x_mod q, k, v, mlp = qkv_mlp.split([self.hidden_size, self.hidden_size, self.hidden_size, self.mlp_hidden_dim], dim=-1) del qkv_mlp @@ -810,27 +827,34 @@ class MMSingleStreamBlock(nn.Module): # Separate image and text tokens img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + del q img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] - img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :] + del k + # img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :] + # del v # 
Apply rotary position embeddings only to image tokens - img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) - assert ( - img_qq.shape == img_q.shape and img_kk.shape == img_k.shape - ), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}" - img_q, img_k = img_qq, img_kk + img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + del freqs_cis # Recombine and compute joint attention q = torch.cat([img_q, txt_q], dim=1) + del img_q, txt_q k = torch.cat([img_k, txt_k], dim=1) - v = torch.cat([img_v, txt_v], dim=1) - attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode) + del img_k, txt_k + # v = torch.cat([img_v, txt_v], dim=1) + # del img_v, txt_v + qkv = [q, k, v] + del q, k, v + attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode) + del qkv # Combine attention and MLP outputs, apply gating # output = self.linear2(attn, self.mlp_act(mlp)) mlp = self.mlp_act(mlp) output = torch.cat([attn, mlp], dim=2).contiguous() + del attn, mlp output = self.linear2(output) return x + apply_gate(output, gate=mod_gate) diff --git a/library/hunyuan_image_text_encoder.py b/library/hunyuan_image_text_encoder.py index 1300b39b..960f14b3 100644 --- a/library/hunyuan_image_text_encoder.py +++ b/library/hunyuan_image_text_encoder.py @@ -598,7 +598,7 @@ def get_byt5_prompt_embeds_from_tokens( ) -> Tuple[list[bool], torch.Tensor, torch.Tensor]: byt5_max_length = BYT5_MAX_LENGTH - if byt5_text_ids is None or byt5_text_mask is None: + if byt5_text_ids is None or byt5_text_mask is None or byt5_text_mask.sum() == 0: return ( [False], torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device), diff --git a/library/hunyuan_image_vae.py b/library/hunyuan_image_vae.py index 6eb035c3..570d4caa 100644 --- a/library/hunyuan_image_vae.py +++ b/library/hunyuan_image_vae.py @@ -17,6 +17,8 @@ logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 32 # 32x spatial compression +LATENT_SCALING_FACTOR = 0.75289 # Latent scaling factor for Hunyuan Image-2.1 + def swish(x: Tensor) -> Tensor: """Swish activation function: x * sigmoid(x).""" @@ -378,7 +380,7 @@ class HunyuanVAE2D(nn.Module): layers_per_block = 2 ffactor_spatial = 32 # 32x spatial compression sample_size = 384 # Minimum sample size for tiling - scaling_factor = 0.75289 # Latent scaling factor + scaling_factor = LATENT_SCALING_FACTOR # 0.75289 # Latent scaling factor self.ffactor_spatial = ffactor_spatial self.scaling_factor = scaling_factor diff --git a/library/strategy_hunyuan_image.py b/library/strategy_hunyuan_image.py index 2188ed37..5c704728 100644 --- a/library/strategy_hunyuan_image.py +++ b/library/strategy_hunyuan_image.py @@ -21,14 +21,27 @@ class HunyuanImageTokenizeStrategy(TokenizeStrategy): Qwen2Tokenizer, hunyuan_image_text_encoder.QWEN_2_5_VL_IMAGE_ID, tokenizer_cache_dir=tokenizer_cache_dir ) self.byt5_tokenizer = self._load_tokenizer( - AutoTokenizer, hunyuan_image_text_encoder.BYT5_TOKENIZER_PATH, tokenizer_cache_dir=tokenizer_cache_dir + AutoTokenizer, hunyuan_image_text_encoder.BYT5_TOKENIZER_PATH, subfolder="", tokenizer_cache_dir=tokenizer_cache_dir ) def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text vlm_tokens, vlm_mask = hunyuan_image_text_encoder.get_qwen_tokens(self.vlm_tokenizer, text) - byt5_tokens, byt5_mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, text) + + # byt5_tokens, byt5_mask = 
hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, text) + byt5_tokens = [] + byt5_mask = [] + for t in text: + tokens, mask = hunyuan_image_text_encoder.get_byt5_text_tokens(self.byt5_tokenizer, t) + if tokens is None: + tokens = torch.zeros((1, 1), dtype=torch.long) + mask = torch.zeros((1, 1), dtype=torch.long) + byt5_tokens.append(tokens) + byt5_mask.append(mask) + max_len = max([m.shape[1] for m in byt5_mask]) + byt5_tokens = torch.cat([torch.nn.functional.pad(t, (0, max_len - t.shape[1]), value=0) for t in byt5_tokens], dim=0) + byt5_mask = torch.cat([torch.nn.functional.pad(m, (0, max_len - m.shape[1]), value=0) for m in byt5_mask], dim=0) return [vlm_tokens, vlm_mask, byt5_tokens, byt5_mask] @@ -46,11 +59,24 @@ class HunyuanImageTextEncodingStrategy(TextEncodingStrategy): # autocast and no_grad are handled in hunyuan_image_text_encoder vlm_embed, vlm_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds_from_tokens(qwen2vlm, vlm_tokens, vlm_mask) - ocr_mask, byt5_embed, byt5_mask = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens( - byt5, byt5_tokens, byt5_mask - ) - return [vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask] + # ocr_mask, byt5_embed, byt5_mask = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens( + # byt5, byt5_tokens, byt5_mask + # ) + ocr_mask, byt5_embed, byt5_updated_mask = [], [], [] + for i in range(byt5_tokens.shape[0]): + ocr_m, byt5_e, byt5_m = hunyuan_image_text_encoder.get_byt5_prompt_embeds_from_tokens( + byt5, byt5_tokens[i : i + 1], byt5_mask[i : i + 1] + ) + ocr_mask.append(torch.zeros((1,), dtype=torch.long) + (1 if ocr_m[0] else 0)) # 1 or 0 + byt5_embed.append(byt5_e) + byt5_updated_mask.append(byt5_m) + + ocr_mask = torch.cat(ocr_mask, dim=0).to(torch.bool) # [B] + byt5_embed = torch.cat(byt5_embed, dim=0) + byt5_updated_mask = torch.cat(byt5_updated_mask, dim=0) + + return [vlm_embed, vlm_mask, byt5_embed, byt5_updated_mask, ocr_mask] class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy): @@ -110,7 +136,6 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr tokens_and_masks = tokenize_strategy.tokenize(captions) with torch.no_grad(): - # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True vlm_embed, vlm_mask, byt5_embed, byt5_mask, ocr_mask = huyuan_image_text_encoding_strategy.encode_tokens( tokenize_strategy, models, tokens_and_masks ) @@ -124,7 +149,7 @@ class HunyuanImageTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStr vlm_mask = vlm_mask.cpu().numpy() byt5_embed = byt5_embed.cpu().numpy() byt5_mask = byt5_mask.cpu().numpy() - ocr_mask = np.array(ocr_mask, dtype=bool) + ocr_mask = ocr_mask.cpu().numpy() for i, info in enumerate(infos): vlm_embed_i = vlm_embed[i] @@ -175,7 +200,13 @@ class HunyuanImageLatentsCachingStrategy(LatentsCachingStrategy): def cache_batch_latents( self, vae: hunyuan_image_vae.HunyuanVAE2D, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool ): - encode_by_vae = lambda img_tensor: vae.encode(img_tensor).sample() + # encode_by_vae = lambda img_tensor: vae.encode(img_tensor).sample() + def encode_by_vae(img_tensor): + # no_grad is handled in _default_cache_batch_latents + nonlocal vae + with torch.autocast(device_type=vae.device.type, dtype=vae.dtype): + return vae.encode(img_tensor).sample() + vae_device = vae.device vae_dtype = vae.dtype diff --git a/library/train_util.py b/library/train_util.py index 8cd43463..756d88b1 100644 --- 
a/library/train_util.py +++ b/library/train_util.py @@ -1744,7 +1744,39 @@ class BaseDataset(torch.utils.data.Dataset): # [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)] if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None: return None - return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + + # old implementation without padding: all elements must have same length + # return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + + # new implementation with padding support + result = [] + for i in range(len(tensors_list[0])): + tensors = [x[i] for x in tensors_list] + if tensors[0].ndim == 0: + # scalar value: e.g. ocr mask + result.append(torch.stack([converter(x[i]) for x in tensors_list])) + continue + + min_len = min([len(x) for x in tensors]) + max_len = max([len(x) for x in tensors]) + + if min_len == max_len: + # no padding + result.append(torch.stack([converter(x) for x in tensors])) + else: + # padding + tensors = [converter(x) for x in tensors] + if tensors[0].ndim == 1: + # input_ids or mask + result.append( + torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]) + ) + else: + # text encoder outputs + result.append( + torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]) + ) + return result # set example example = {} diff --git a/networks/lora_hunyuan_image.py b/networks/lora_hunyuan_image.py index b0edde57..3e801f95 100644 --- a/networks/lora_hunyuan_image.py +++ b/networks/lora_hunyuan_image.py @@ -191,9 +191,8 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh class HunyuanImageLoRANetwork(lora_flux.LoRANetwork): - # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] - FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] - FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] + TARGET_REPLACE_MODULE_DOUBLE = ["MMDoubleStreamBlock"] + TARGET_REPLACE_MODULE_SINGLE = ["MMSingleStreamBlock"] LORA_PREFIX_HUNYUAN_IMAGE_DIT = "lora_unet" # make ComfyUI compatible @classmethod @@ -222,7 +221,7 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork): reg_lrs: Optional[Dict[str, float]] = None, verbose: Optional[bool] = False, ) -> None: - super().__init__() + nn.Module.__init__(self) self.multiplier = multiplier self.lora_dim = lora_dim @@ -259,8 +258,6 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork): if self.split_qkv: logger.info(f"split qkv for LoRA") - if self.train_blocks is not None: - logger.info(f"train {self.train_blocks} blocks only") # create module instances def create_modules( @@ -354,14 +351,14 @@ class HunyuanImageLoRANetwork(lora_flux.LoRANetwork): # create LoRA for U-Net target_replace_modules = ( - HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_DOUBLE + HunyuanImageLoRANetwork.TARGET_REPLACE_MODULE_SINGLE ) self.unet_loras: List[Union[lora_flux.LoRAModule, lora_flux.LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) self.text_encoder_loras = [] - logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + logger.info(f"create LoRA for HunyuanImage-2.1: {len(self.unet_loras)} modules.") if verbose: for 
lora in self.unet_loras: logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") diff --git a/train_network.py b/train_network.py index 00118877..c03c5fa0 100644 --- a/train_network.py +++ b/train_network.py @@ -1,3 +1,4 @@ +import gc import importlib import argparse import math @@ -10,11 +11,11 @@ import time import json from multiprocessing import Value import numpy as np -import toml from tqdm import tqdm import torch +import torch.nn as nn from torch.types import Number from library.device_utils import init_ipex, clean_memory_on_device @@ -175,7 +176,7 @@ class NetworkTrainer: if val_dataset_group is not None: val_dataset_group.verify_bucket_reso_steps(64) - def load_target_model(self, args, weight_dtype, accelerator) -> tuple: + def load_target_model(self, args, weight_dtype, accelerator) -> tuple[str, nn.Module, nn.Module, Optional[nn.Module]]: text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # モデルに xformers とか memory efficient attention を組み込む @@ -185,6 +186,9 @@ class NetworkTrainer: return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet + def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, List[nn.Module]]: + raise NotImplementedError() + def get_tokenize_strategy(self, args): return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) @@ -476,8 +480,11 @@ class NetworkTrainer: return loss.mean() def cast_text_encoder(self): - return True # default for other than HunyuanImage + return True # default for other than HunyuanImage + def cast_vae(self): + return True # default for other than HunyuanImage + def train(self, args): session_id = random.randint(0, 2**32) training_started_at = time.time() @@ -586,37 +593,18 @@ class NetworkTrainer: # mixed precisionに対応した型を用意しておき適宜castする weight_dtype, save_dtype = train_util.prepare_dtype(args) - vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + vae_dtype = (torch.float32 if args.no_half_vae else weight_dtype) if self.cast_vae() else None - # モデルを読み込む + # load target models: unet may be None for lazy loading model_version, text_encoder, vae, unet = self.load_target_model(args, weight_dtype, accelerator) + if vae_dtype is None: + vae_dtype = vae.dtype + logger.info(f"vae_dtype is set to {vae_dtype} by the model since cast_vae() is false") # text_encoder is List[CLIPTextModel] or CLIPTextModel text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] - # 差分追加学習のためにモデルを読み込む - sys.path.append(os.path.dirname(__file__)) - accelerator.print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) - - if args.base_weights is not None: - # base_weights が指定されている場合は、指定された重みを読み込みマージする - for i, weight_path in enumerate(args.base_weights): - if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: - multiplier = 1.0 - else: - multiplier = args.base_weights_multiplier[i] - - accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") - - module, weights_sd = network_module.create_network_from_weights( - multiplier, weight_path, vae, text_encoder, unet, for_inference=True - ) - module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") - - accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") - - # 学習を準備する + # prepare dataset for latents caching if needed if 
cache_latents: vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) @@ -643,6 +631,32 @@ class NetworkTrainer: if val_dataset_group is not None: self.cache_text_encoder_outputs_if_needed(args, accelerator, unet, vae, text_encoders, val_dataset_group, weight_dtype) + if unet is None: + # lazy load unet if needed. text encoders may be freed or replaced with dummy models for saving memory + unet, text_encoders = self.load_unet_lazily(args, weight_dtype, accelerator, text_encoders) + + # 差分追加学習のためにモデルを読み込む + sys.path.append(os.path.dirname(__file__)) + accelerator.print("import network module:", args.network_module) + network_module = importlib.import_module(args.network_module) + + if args.base_weights is not None: + # base_weights が指定されている場合は、指定された重みを読み込みマージする + for i, weight_path in enumerate(args.base_weights): + if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i: + multiplier = 1.0 + else: + multiplier = args.base_weights_multiplier[i] + + accelerator.print(f"merging module: {weight_path} with multiplier {multiplier}") + + module, weights_sd = network_module.create_network_from_weights( + multiplier, weight_path, vae, text_encoder, unet, for_inference=True + ) + module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu") + + accelerator.print(f"all weights merged: {', '.join(args.base_weights)}") + # prepare network net_kwargs = {} if args.network_args is not None: @@ -672,7 +686,7 @@ class NetworkTrainer: return network_has_multiplier = hasattr(network, "set_multiplier") - # TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works): + # TODO remove `hasattr` by setting up methods if not defined in the network like below (hacky but will work): # if not hasattr(network, "prepare_network"): # network.prepare_network = lambda args: None @@ -1305,6 +1319,8 @@ class NetworkTrainer: del t_enc text_encoders = [] text_encoder = None + gc.collect() + clean_memory_on_device(accelerator.device) # For --sample_at_first optimizer_eval_fn()
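
A minimal usage sketch of the updated attention() helper added in library/attention.py above (not part of the patch itself; shapes and values are illustrative only). Passing q, k, v as a single list matches the new call sites in MMDoubleStreamBlock / MMSingleStreamBlock, and per-sample outputs shorter than the full sequence are zero-padded back to the maximum length:

# illustrative sketch, assuming library/attention.py from this change is importable
import torch
from library.attention import attention

B, L, H, D = 2, 16, 4, 8
q = torch.randn(B, L, H, D)
k = torch.randn(B, L, H, D)
v = torch.randn(B, L, H, D)

# list form lets attention() clear the list and drop q/k/v references early;
# seq_lens gives the valid (unpadded) length of each batch element
out = attention([q, k, v], seq_lens=[16, 12], attn_mode="torch")
print(out.shape)  # torch.Size([2, 16, 32]) -> [B, L, H*D], short sequences zero-padded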