support multi gpu in caching text encoder outputs

This commit is contained in:
Kohya S
2023-07-09 16:02:56 +09:00
parent 3579b4570f
commit 0416f26a76
5 changed files with 32 additions and 22 deletions

View File

@@ -204,12 +204,25 @@ def train(args):
text_encoder2.gradient_checkpointing_enable()
training_models.append(text_encoder1)
training_models.append(text_encoder2)
text_encoder1_cache = None
text_encoder2_cache = None
# set require_grad=True later
else:
text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False)
text_encoder1.eval()
text_encoder2.eval()
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
)
accelerator.wait_for_everyone()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
@@ -289,23 +302,16 @@ def train(args):
(unet,) = train_util.transform_models_if_DDP([unet])
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
text_encoder1.eval()
text_encoder2.eval()
# TextEncoderの出力をキャッシュする
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
)
accelerator.wait_for_everyone()
# Text Encoder doesn't work on CPU with fp16
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
text_encoder1_cache = None
text_encoder2_cache = None
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)