diff --git a/train_db.py b/train_db.py index bbef3da7..8ac503ea 100644 --- a/train_db.py +++ b/train_db.py @@ -92,10 +92,7 @@ def train(args): gc.collect() # 学習を準備する:モデルを適切な状態にする - if args.stop_text_encoder_training is None: - args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end - - train_text_encoder = args.stop_text_encoder_training >= 0 + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 text_encoder.requires_grad_(train_text_encoder) if not train_text_encoder: @@ -143,6 +140,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)