support multi gpu in caching text encoder outputs

This commit is contained in:
Kohya S
2023-07-09 16:02:56 +09:00
parent 3579b4570f
commit 0416f26a76
5 changed files with 32 additions and 22 deletions

View File

@@ -255,6 +255,11 @@ class NetworkTrainer:
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
# prepare network
net_kwargs = {}
if args.network_args is not None:
@@ -419,11 +424,6 @@ class NetworkTrainer:
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype
)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)