Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
support multi gpu in caching text encoder outputs
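
This change makes output caching for the two SDXL text encoders safe under multi-GPU training: the cache is now built earlier in train() (over train_dataset_group rather than train_dataloader), and every process blocks on accelerator.wait_for_everyone() so no rank starts training before all ranks have finished caching. Below is a minimal sketch of that synchronization pattern, assuming Hugging Face accelerate; encode_prompts is a hypothetical stand-in for sdxl_train_util.cache_text_encoder_outputs.

    # Minimal sketch of the barrier pattern this commit adds (assumption:
    # encode_prompts is a hypothetical stand-in for
    # sdxl_train_util.cache_text_encoder_outputs).
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    @torch.no_grad()
    def encode_prompts(prompts):
        # hypothetical: run the frozen, eval-mode text encoders over the prompts
        return [torch.zeros(77, 1280) for _ in prompts]

    cache = encode_prompts(["a photo of a cat", "a photo of a dog"])

    # barrier: no process moves on to training until every rank finished caching
    accelerator.wait_for_everyone()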
@@ -204,12 +204,25 @@ def train(args):
             text_encoder2.gradient_checkpointing_enable()
         training_models.append(text_encoder1)
         training_models.append(text_encoder2)
 
+        text_encoder1_cache = None
+        text_encoder2_cache = None
+
         # set require_grad=True later
     else:
         text_encoder1.requires_grad_(False)
         text_encoder2.requires_grad_(False)
         text_encoder1.eval()
         text_encoder2.eval()
 
+        # cache the Text Encoder outputs
+        if args.cache_text_encoder_outputs:
+            # Text Encoders are eval and no grad
+            text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
+                args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataset_group, None
+            )
+            accelerator.wait_for_everyone()
+
     if not cache_latents:
         vae.requires_grad_(False)
         vae.eval()
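
The hunk above is where multi-GPU support actually lands: cache_text_encoder_outputs now receives the accelerator and the train_dataset_group, so each process can cache during startup, and the wait_for_everyone() barrier keeps the ranks in step. A hedged sketch of how such a helper can split the work across ranks follows; the real sharding logic lives in sdxl_train_util and may differ.

    # Hedged sketch of per-rank work splitting for output caching; the
    # round-robin split here is an illustrative assumption, not the actual
    # strategy used by sdxl_train_util.
    import torch
    from accelerate import Accelerator

    def cache_outputs_multi_gpu(accelerator, prompts, encode_fn):
        cache = {}
        with torch.no_grad():
            for i, prompt in enumerate(prompts):
                # each rank encodes only its own slice of the prompts
                if i % accelerator.num_processes != accelerator.process_index:
                    continue
                cache[prompt] = encode_fn(prompt)
        accelerator.wait_for_everyone()  # all ranks finish before training starts
        return cache

    if __name__ == "__main__":
        acc = Accelerator()
        cache = cache_outputs_multi_gpu(acc, ["a cat", "a dog"], lambda p: torch.zeros(77))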
@@ -289,23 +302,16 @@ def train(args):
         (unet,) = train_util.transform_models_if_DDP([unet])
         text_encoder1.to(weight_dtype)
         text_encoder2.to(weight_dtype)
         text_encoder1.eval()
         text_encoder2.eval()
 
-        # cache the Text Encoder outputs
+        # move the Text Encoders to CPU when caching their outputs
         if args.cache_text_encoder_outputs:
-            text_encoder1_cache, text_encoder2_cache = sdxl_train_util.cache_text_encoder_outputs(
-                args, accelerator, (tokenizer1, tokenizer2), (text_encoder1, text_encoder2), train_dataloader, None
-            )
-            accelerator.wait_for_everyone()
-            # Text Encoder doesn't work on CPU with fp16
+            # move the Text Encoders for sampling images; Text Encoder doesn't work on CPU with fp16
             text_encoder1.to("cpu", dtype=torch.float32)
             text_encoder2.to("cpu", dtype=torch.float32)
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
         else:
-            text_encoder1_cache = None
-            text_encoder2_cache = None
             # make sure Text Encoders are on GPU
             text_encoder1.to(accelerator.device)
             text_encoder2.to(accelerator.device)
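After caching, the encoders are parked on the CPU in float32 rather than fp16, because, as the diff's own comment notes, the text encoders don't work on CPU with fp16 (several kernels lack half-precision CPU implementations); releasing the GPU copy then reclaims memory for the UNet. A minimal sketch of that dtype handling, using a plain nn.Linear as a stand-in for the encoders:

    # Minimal sketch of the CPU/float32 parking above; nn.Linear stands in
    # for the text encoders (an assumption for brevity).
    import torch

    encoder = torch.nn.Linear(8, 8)

    # park on CPU in float32: fp16 matmuls and related ops have historically
    # raised "not implemented for 'Half'" on CPU
    encoder = encoder.to("cpu", dtype=torch.float32)
    out = encoder(torch.randn(1, 8))  # float32 input matches the module

    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release the GPU memory the encoders held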