mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Fix TE training stops at max steps if ecpochs set
This commit is contained in:
@@ -92,10 +92,7 @@ def train(args):
|
||||
gc.collect()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
if args.stop_text_encoder_training is None:
|
||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||
|
||||
train_text_encoder = args.stop_text_encoder_training >= 0
|
||||
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
|
||||
unet.requires_grad_(True) # 念のため追加
|
||||
text_encoder.requires_grad_(train_text_encoder)
|
||||
if not train_text_encoder:
|
||||
@@ -143,6 +140,9 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
if args.stop_text_encoder_training is None:
|
||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
|
||||
|
||||
Reference in New Issue
Block a user