mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix num_processes, fix indent
This commit is contained in:
@@ -1721,10 +1721,14 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
|||||||
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
|
help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--lr_scheduler_type", type=str, default="",
|
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
|
||||||
help="custom scheduler module")
|
parser.add_argument(
|
||||||
parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
|
"--lr_scheduler_args",
|
||||||
help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
|
type=str,
|
||||||
|
default=None,
|
||||||
|
nargs="*",
|
||||||
|
help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max=100")',
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr_scheduler",
|
"--lr_scheduler",
|
||||||
@@ -2290,26 +2294,9 @@ def get_optimizer(args, trainable_params):
|
|||||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
|
def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||||
"""
|
"""
|
||||||
Unified API to get any scheduler from its name.
|
Unified API to get any scheduler from its name.
|
||||||
Args:
|
|
||||||
name (`str` or `SchedulerType`):
|
|
||||||
The name of the scheduler to use.
|
|
||||||
optimizer (`torch.optim.Optimizer`):
|
|
||||||
The optimizer that will be used during training.
|
|
||||||
num_warmup_steps (`int`, *optional*):
|
|
||||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
|
||||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
|
||||||
num_training_steps (`int`, *optional*):
|
|
||||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
|
||||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
|
||||||
num_cycles (`int`, *optional*):
|
|
||||||
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
|
||||||
power (`float`, *optional*, defaults to 1.0):
|
|
||||||
Power factor. See `POLYNOMIAL` scheduler
|
|
||||||
last_epoch (`int`, *optional*, defaults to -1):
|
|
||||||
The index of the last epoch when resuming training.
|
|
||||||
"""
|
"""
|
||||||
name = args.lr_scheduler
|
name = args.lr_scheduler
|
||||||
num_warmup_steps = args.lr_warmup_steps
|
num_warmup_steps = args.lr_warmup_steps
|
||||||
@@ -2320,12 +2307,12 @@ def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
|
|||||||
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
|
lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
|
||||||
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
||||||
for arg in args.lr_scheduler_args:
|
for arg in args.lr_scheduler_args:
|
||||||
key, value = arg.split('=')
|
key, value = arg.split("=")
|
||||||
|
|
||||||
value = value.split(",")
|
value = value.split(",")
|
||||||
for i in range(len(value)):
|
for i in range(len(value)):
|
||||||
if value[i].lower() == "true" or value[i].lower() == "false":
|
if value[i].lower() == "true" or value[i].lower() == "false":
|
||||||
value[i] = (value[i].lower() == "true")
|
value[i] = value[i].lower() == "true"
|
||||||
else:
|
else:
|
||||||
value[i] = ast.literal_eval(value[i])
|
value[i] = ast.literal_eval(value[i])
|
||||||
if len(value) == 1:
|
if len(value) == 1:
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ def train(args):
|
|||||||
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
|
||||||
|
|
||||||
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
|
|||||||
Reference in New Issue
Block a user