Add lr_scheduler_type and related scheduler arguments; simplify get_scheduler_fix to take args directly

This commit is contained in:
Isotr0py
2023-03-09 16:51:22 +08:00
parent 8d5ba29363
commit eb68892ab1
5 changed files with 49 additions and 20 deletions

View File

@@ -235,9 +235,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(