Add lr_scheduler_type and related scheduler arguments.

This commit is contained in:
Isotr0py
2023-03-09 16:51:22 +08:00
parent 8d5ba29363
commit eb68892ab1
5 changed files with 49 additions and 20 deletions

View File

@@ -150,9 +150,7 @@ def train(args):
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16: