mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support multi gpu in caching text encoder outputs
This commit is contained in:
@@ -255,6 +255,11 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||
)
|
||||
|
||||
# prepare network
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
@@ -419,11 +424,6 @@ class NetworkTrainer:
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataloader, weight_dtype
|
||||
)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
Reference in New Issue
Block a user