diff --git a/train_db.py b/train_db.py index a47da472..d1ef350c 100644 --- a/train_db.py +++ b/train_db.py @@ -98,6 +98,8 @@ def train(args): train_text_encoder = args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + print("Text Encoder is not trained.") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -153,6 +155,9 @@ def train(args): unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: