mirror of https://github.com/kohya-ss/sd-scripts.git
fix num_processes, fix indent
@@ -1721,10 +1721,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
     )
 
-    parser.add_argument("--lr_scheduler_type", type=str, default="",
-                        help="custom scheduler module")
-    parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
-                        help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
+    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
+    parser.add_argument(
+        "--lr_scheduler_args",
+        type=str,
+        default=None,
+        nargs="*",
+        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
+    )
 
     parser.add_argument(
         "--lr_scheduler",
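For reference, the reworked options behave like ordinary argparse flags: nargs="*" collects each space-separated KEY=VALUE token into a list for later parsing. A minimal sketch of that behavior (the scheduler module path and the eta_min value below are illustrative, not taken from this commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
    parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs="*",
                        help='additional arguments for scheduler (like "T_max=100")')

    # nargs="*" keeps the raw tokens; typed conversion happens in get_scheduler_fix
    args = parser.parse_args(
        ["--lr_scheduler_type", "torch.optim.lr_scheduler.CosineAnnealingLR",
         "--lr_scheduler_args", "T_max=100", "eta_min=1e-6"]
    )
    print(args.lr_scheduler_args)  # ['T_max=100', 'eta_min=1e-6']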
@@ -2293,23 +2297,6 @@ def get_optimizer(args, trainable_params):
 def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
-    """
-    Unified API to get any scheduler from its name.
-    Args:
-        name (`str` or `SchedulerType`):
-            The name of the scheduler to use.
-        optimizer (`torch.optim.Optimizer`):
-            The optimizer that will be used during training.
-        num_warmup_steps (`int`, *optional*):
-            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_training_steps (`int`, *optional*):
-            The number of training steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_cycles (`int`, *optional*):
-            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
-        power (`float`, *optional*, defaults to 1.0):
-            Power factor. See `POLYNOMIAL` scheduler.
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-    """
     name = args.lr_scheduler
     num_warmup_steps = args.lr_warmup_steps
@@ -2320,12 +2307,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     lr_scheduler_kwargs = {}  # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
         for arg in args.lr_scheduler_args:
-            key, value = arg.split('=')
+            key, value = arg.split("=")
 
             value = value.split(",")
             for i in range(len(value)):
                 if value[i].lower() == "true" or value[i].lower() == "false":
-                    value[i] = (value[i].lower() == "true")
+                    value[i] = value[i].lower() == "true"
                 else:
                     value[i] = ast.literal_eval(value[i])
             if len(value) == 1:
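The loop above turns each KEY=VALUE token into a typed scheduler kwarg: "true"/"false" are special-cased as booleans, everything else goes through ast.literal_eval, and comma-separated values become a sequence. A standalone sketch of the same logic, assuming the truncated tail of the hunk collapses single-element lists to a scalar (parse_scheduler_args is a hypothetical name, not a function in the repo):

    import ast

    def parse_scheduler_args(tokens):
        kwargs = {}
        for tok in tokens:
            key, value = tok.split("=")
            values = value.split(",")
            for i in range(len(values)):
                if values[i].lower() == "true" or values[i].lower() == "false":
                    values[i] = values[i].lower() == "true"  # parse booleans
                else:
                    values[i] = ast.literal_eval(values[i])  # ints, floats, tuples, ...
            # assumed: a single value is stored as a scalar, several as a tuple
            kwargs[key] = values[0] if len(values) == 1 else tuple(values)
        return kwargs

    print(parse_scheduler_args(["T_max=100", "eta_min=1e-6", "verbose=True"]))
    # {'T_max': 100, 'eta_min': 1e-06, 'verbose': True}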
@@ -166,7 +166,7 @@ def train(args):
         args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
 
     # prepare lr scheduler. TODO: the handling of gradient_accumulation_steps may be wrong somehow; check later
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
+    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # experimental feature: perform fp16 training including gradients; cast the entire model to fp16
     if args.full_fp16:
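Why num_processes is now threaded through: under Hugging Face accelerate, each process calls lr_scheduler.step() on its own replica, so with N GPUs the schedule advances N times per global optimizer step and would finish warmup and decay N times too early unless the step counts are stretched. A hedged sketch of that correction, assuming get_scheduler_fix multiplies its step counts by num_processes (the exact formula inside train_util may differ, e.g. it may also fold in gradient accumulation):

    def scaled_schedule_steps(max_train_steps, lr_warmup_steps, num_processes):
        # stretch the schedule so it spans all per-process scheduler steps
        num_training_steps = max_train_steps * num_processes
        num_warmup_steps = lr_warmup_steps * num_processes
        return num_training_steps, num_warmup_steps

    print(scaled_schedule_steps(1600, 100, 2))  # (3200, 200)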