Merge pull request #271 from Isotr0py/dev

Add '--lr_scheduler_type' and '--lr_scheduler_args' arguments
This commit is contained in:
Kohya S
2023-03-19 10:26:34 +09:00
committed by GitHub
5 changed files with 49 additions and 40 deletions

View File

@@ -198,14 +198,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler,
optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles,
power=args.lr_scheduler_power,
)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:

View File

@@ -1,6 +1,7 @@
# common functions for training
import argparse
import ast
import importlib
import json
import pathlib
@@ -1720,6 +1721,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ..."',
)
parser.add_argument("--lr_scheduler_type", type=str, default="",
help="custom scheduler module")
parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
                    help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
parser.add_argument(
"--lr_scheduler",
type=str,
@@ -2284,14 +2290,7 @@ def get_optimizer(args, trainable_params):
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
power: float = 1.0,
):
def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
Unified API to get any scheduler from its name.
Args:
@@ -2312,6 +2311,44 @@ def get_scheduler_fix(
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = args.lr_scheduler
num_warmup_steps = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
for arg in args.lr_scheduler_args:
key, value = arg.split('=')
value = value.split(",")
for i in range(len(value)):
if value[i].lower() == "true" or value[i].lower() == "false":
value[i] = (value[i].lower() == "true")
else:
value[i] = ast.literal_eval(value[i])
if len(value) == 1:
value = value[0]
else:
value = list(value) # some may use list?
lr_scheduler_kwargs[key] = value
# using any lr_scheduler from other library
if args.lr_scheduler_type:
lr_scheduler_type = args.lr_scheduler_type
print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
if "." not in lr_scheduler_type: # default to use torch.optim
lr_scheduler_module = torch.optim.lr_scheduler
else:
values = lr_scheduler_type.split(".")
lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
lr_scheduler_type = values[-1]
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
return lr_scheduler
if name.startswith("adafactor"):
assert (
type(optimizer) == transformers.optimization.Adafactor

View File

@@ -166,14 +166,7 @@ def train(args):
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler,
optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
num_cycles=args.lr_scheduler_num_cycles,
power=args.lr_scheduler_power,
)
  lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:

View File

@@ -201,14 +201,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler,
optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles,
power=args.lr_scheduler_power,
)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:

View File

@@ -261,14 +261,7 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(
args.lr_scheduler,
optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles,
power=args.lr_scheduler_power,
)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(