Support multi-GPU in caching text encoder outputs

This commit is contained in:
Kohya S
2023-07-09 16:02:56 +09:00
parent 3579b4570f
commit 0416f26a76
5 changed files with 32 additions and 22 deletions

View File

@@ -47,7 +47,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
return args.cache_text_encoder_outputs
def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
@@ -61,7 +61,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
torch.cuda.empty_cache()
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
args, accelerator, tokenizers, text_encoders, dataset, weight_dtype
)
accelerator.wait_for_everyone()
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU