From 687044519b6c4f6166145b20cba2d7f2e1df9b8a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 19 Jan 2023 21:43:34 +0900 Subject: [PATCH] Fix TE training stops at max steps if epochs set --- train_db.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_db.py b/train_db.py index bbef3da7..8ac503ea 100644 --- a/train_db.py +++ b/train_db.py @@ -92,10 +92,7 @@ def train(args): gc.collect() # 学習を準備する:モデルを適切な状態にする - if args.stop_text_encoder_training is None: - args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end - - train_text_encoder = args.stop_text_encoder_training >= 0 + train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 text_encoder.requires_grad_(train_text_encoder) if not train_text_encoder: @@ -143,6 +140,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)