Cache latents to disk in DreamBooth method

This commit is contained in:
Kohya S
2023-04-12 23:10:39 +09:00
parent 5050971ac6
commit 2e9f7b5f91
6 changed files with 67 additions and 15 deletions

View File

@@ -142,12 +142,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if args.gradient_checkpointing: