Refactor caching in train scripts

This commit is contained in:
kohya-ss
2024-10-12 20:18:41 +09:00
parent ff4083b910
commit c80c304779
14 changed files with 95 additions and 47 deletions

View File

@@ -188,8 +188,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# if the text encoders are trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
False,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
@@ -222,7 +222,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[1].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None: