diff --git a/sdxl_train.py b/sdxl_train.py index f454263a..7af9213b 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -197,10 +197,6 @@ def train(args): ) return - if cache_latents: - assert ( - train_dataset_group.is_latent_cacheable() - ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" if args.cache_text_encoder_outputs: assert ( @@ -225,7 +221,32 @@ def train(args): logit_scale, ckpt_info, ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) - # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype) + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + # Force FP16 for caching even if training uses FP32 + temp_vae_dtype = torch.float16 if not args.no_half_vae else torch.float32 + vae = vae.to(accelerator.device, dtype=temp_vae_dtype) + + # Optimize VAE performance + vae = vae.to(memory_format=torch.channels_last) + # if not isinstance(vae, torch._dynamo.eval_frame.OptimizedModule): + # vae = torch.compile(vae, mode="reduce-overhead") + + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process + ) + vae.to("cpu") + clean_memory_on_device(accelerator.device) # verify load/save model formats if load_stable_diffusion_format: