This commit is contained in:
alefh123
2026-04-01 13:03:39 +00:00
committed by GitHub

View File

@@ -197,10 +197,6 @@ def train(args):
) )
return return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs: if args.cache_text_encoder_outputs:
assert ( assert (
@@ -225,7 +221,32 @@ def train(args):
logit_scale, logit_scale,
ckpt_info, ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype) ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# Cache latents in FP16 regardless of the training dtype; use FP32 only when args.no_half_vae is set
temp_vae_dtype = torch.float16 if not args.no_half_vae else torch.float32
vae = vae.to(accelerator.device, dtype=temp_vae_dtype)
# Optimize VAE performance
vae = vae.to(memory_format=torch.channels_last)
# if not isinstance(vae, torch._dynamo.eval_frame.OptimizedModule):
# vae = torch.compile(vae, mode="reduce-overhead")
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
# verify load/save model formats # verify load/save model formats
if load_stable_diffusion_format: if load_stable_diffusion_format: