mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge 68877e789b into 1dae34b0af
This commit is contained in:
@@ -197,10 +197,6 @@ def train(args):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if cache_latents:
|
|
||||||
assert (
|
|
||||||
train_dataset_group.is_latent_cacheable()
|
|
||||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
|
||||||
|
|
||||||
if args.cache_text_encoder_outputs:
|
if args.cache_text_encoder_outputs:
|
||||||
assert (
|
assert (
|
||||||
@@ -225,7 +221,32 @@ def train(args):
|
|||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||||
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
|
if cache_latents:
|
||||||
|
assert (
|
||||||
|
train_dataset_group.is_latent_cacheable()
|
||||||
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
|
|
||||||
|
# Force FP16 for caching even if training uses FP32
|
||||||
|
temp_vae_dtype = torch.float16 if not args.no_half_vae else torch.float32
|
||||||
|
vae = vae.to(accelerator.device, dtype=temp_vae_dtype)
|
||||||
|
|
||||||
|
# Optimize VAE performance
|
||||||
|
vae = vae.to(memory_format=torch.channels_last)
|
||||||
|
# if not isinstance(vae, torch._dynamo.eval_frame.OptimizedModule):
|
||||||
|
# vae = torch.compile(vae, mode="reduce-overhead")
|
||||||
|
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
train_dataset_group.cache_latents(
|
||||||
|
vae,
|
||||||
|
args.vae_batch_size,
|
||||||
|
args.cache_latents_to_disk,
|
||||||
|
accelerator.is_main_process
|
||||||
|
)
|
||||||
|
vae.to("cpu")
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
# verify load/save model formats
|
# verify load/save model formats
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
|
|||||||
Reference in New Issue
Block a user