mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Refactor caching in train scripts
This commit is contained in:
@@ -67,7 +67,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||
False, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||
)
|
||||
return latents_caching_strategy
|
||||
|
||||
@@ -80,7 +80,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
|
||||
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
|
||||
)
|
||||
else:
|
||||
return None
|
||||
@@ -102,9 +102,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(
|
||||
text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator.is_main_process
|
||||
)
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
|
||||
Reference in New Issue
Block a user