Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 14:34:23 +00:00)

Merge 68877e789b into 1dae34b0af
@@ -197,10 +197,6 @@ def train(args):
         )
         return
 
-    if cache_latents:
-        assert (
-            train_dataset_group.is_latent_cacheable()
-        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
 
     if args.cache_text_encoder_outputs:
         assert (
@@ -225,7 +221,32 @@ def train(args):
         logit_scale,
         ckpt_info,
     ) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
     # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
 
+    if cache_latents:
+        assert (
+            train_dataset_group.is_latent_cacheable()
+        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+        # Force FP16 for caching even if training uses FP32
+        temp_vae_dtype = torch.float16 if not args.no_half_vae else torch.float32
+        vae = vae.to(accelerator.device, dtype=temp_vae_dtype)
+
+        # Optimize VAE performance
+        vae = vae.to(memory_format=torch.channels_last)
+        # if not isinstance(vae, torch._dynamo.eval_frame.OptimizedModule):
+        #     vae = torch.compile(vae, mode="reduce-overhead")
+
+        vae.requires_grad_(False)
+        vae.eval()
+        with torch.no_grad():
+            train_dataset_group.cache_latents(
+                vae,
+                args.vae_batch_size,
+                args.cache_latents_to_disk,
+                accelerator.is_main_process
+            )
+        vae.to("cpu")
+        clean_memory_on_device(accelerator.device)
+
     # verify load/save model formats
     if load_stable_diffusion_format:
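
For context, the caching pattern this commit moves after model load can be sketched standalone. The sketch below is illustrative, not sd-scripts code: it assumes a diffusers AutoencoderKL and a plain image tensor in [-1, 1], uses torch.cuda.empty_cache() in place of the repo's clean_memory_on_device helper, and names like cache_latents_sketch are hypothetical.

# Minimal sketch of the latent-caching pattern above, under the assumptions
# stated in the lead-in. Not the sd-scripts implementation.
import torch
from diffusers import AutoencoderKL

def cache_latents_sketch(
    vae: AutoencoderKL,
    images: torch.Tensor,          # (N, 3, H, W), float, scaled to [-1, 1]
    device: torch.device,
    no_half_vae: bool = False,     # mirrors the --no_half_vae flag
    batch_size: int = 4,           # mirrors --vae_batch_size
) -> torch.Tensor:
    # Encode in fp16 unless the caller opts out, even if training runs in fp32.
    vae_dtype = torch.float32 if no_half_vae else torch.float16
    vae = vae.to(device, dtype=vae_dtype)

    # channels_last can speed up convolution-heavy models on recent GPUs.
    vae = vae.to(memory_format=torch.channels_last)

    # The VAE is frozen during caching: no gradients, eval-mode behavior.
    vae.requires_grad_(False)
    vae.eval()

    latents = []
    with torch.no_grad():
        for start in range(0, images.shape[0], batch_size):
            batch = images[start : start + batch_size].to(device, dtype=vae_dtype)
            batch = batch.to(memory_format=torch.channels_last)
            # diffusers AutoencoderKL: encode() returns a distribution to sample.
            latents.append(vae.encode(batch).latent_dist.sample().float().cpu())

    # Offload the VAE and free VRAM for the rest of training setup.
    vae.to("cpu")
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return torch.cat(latents)

Caching in fp16 roughly halves VAE memory traffic even when training itself runs in fp32; the opt-out flag exists because the SDXL VAE is known to overflow in fp16 in some cases, producing NaN latents.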