Refactor caching in train scripts

This commit is contained in:
kohya-ss
2024-10-12 20:18:41 +09:00
parent ff4083b910
commit c80c304779
14 changed files with 95 additions and 47 deletions

View File

@@ -84,7 +84,7 @@ def train(args):
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, False
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
@@ -230,7 +230,7 @@ def train(args):
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator.is_main_process)
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()