mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-09 06:45:09 +00:00
fix num_processes, fix indent
@@ -1721,11 +1721,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         help='additional arguments for optimizer (like "weight_decay=0.01 betas=0.9,0.999 ...") / オプティマイザの追加引数(例: "weight_decay=0.01 betas=0.9,0.999 ...")',
     )
-
-    parser.add_argument("--lr_scheduler_type", type=str, default="",
-                        help="custom scheduler module")
-    parser.add_argument("--lr_scheduler_args", type=str, default=None, nargs='*',
-                        help="additional arguments for scheduler (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / スケジューラの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
+
+    parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
+    parser.add_argument(
+        "--lr_scheduler_args",
+        type=str,
+        default=None,
+        nargs="*",
+        help='additional arguments for scheduler (like "T_max=100") / スケジューラの追加引数(例: "T_max100")',
+    )
 
     parser.add_argument(
         "--lr_scheduler",
         type=str,
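
The two flags above are plain argparse options. Below is a minimal standalone sketch, not part of this commit and with invented example values, of how they parse; a bare class name is enough for --lr_scheduler_type because the scheduler lookup later in this diff falls back to torch.optim.lr_scheduler when the value contains no dot.

import argparse

# Sketch only: register the two new scheduler flags on a throwaway parser.
parser = argparse.ArgumentParser()
parser.add_argument("--lr_scheduler_type", type=str, default="", help="custom scheduler module")
parser.add_argument(
    "--lr_scheduler_args",
    type=str,
    default=None,
    nargs="*",
    help='additional arguments for scheduler (like "T_max=100")',
)

# Example values (invented): a bare torch.optim.lr_scheduler class name plus one kwarg.
args = parser.parse_args(["--lr_scheduler_type", "CosineAnnealingLR", "--lr_scheduler_args", "T_max=100"])
print(args.lr_scheduler_type)  # CosineAnnealingLR
print(args.lr_scheduler_args)  # ['T_max=100']
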
@@ -2083,7 +2087,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
             # if value is dict, save all key and value into one dict
             for key, value in section_dict.items():
                 ignore_nesting_dict[key] = value
-
+
     config_args = argparse.Namespace(**ignore_nesting_dict)
     args = parser.parse_args(namespace=config_args)
     args.config_file = os.path.splitext(args.config_file)[0]
@@ -2290,26 +2294,9 @@ def get_optimizer(args, trainable_params):
 # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts


-def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
+def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     """
     Unified API to get any scheduler from its name.
-    Args:
-        name (`str` or `SchedulerType`):
-            The name of the scheduler to use.
-        optimizer (`torch.optim.Optimizer`):
-            The optimizer that will be used during training.
-        num_warmup_steps (`int`, *optional*):
-            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_training_steps (`int``, *optional*):
-            The number of training steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_cycles (`int`, *optional*):
-            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
-        power (`float`, *optional*, defaults to 1.0):
-            Power factor. See `POLYNOMIAL` scheduler
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
     """
     name = args.lr_scheduler
     num_warmup_steps = args.lr_warmup_steps
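
The num_processes parameter in the commit title is what lets the scheduler account for multi-GPU training. A rough arithmetic sketch of the idea follows; the example numbers and the exact scaling rule are assumptions, not shown in this diff.

# Assumption (illustration only): with data-parallel training, each global optimizer
# step consumes num_processes * gradient_accumulation_steps batches, so the step
# count handed to the scheduler is scaled accordingly.
max_train_steps = 1600
num_processes = 2
gradient_accumulation_steps = 1
num_training_steps = max_train_steps * num_processes * gradient_accumulation_steps
print(num_training_steps)  # 3200
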
@@ -2319,35 +2306,35 @@ def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
     lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
-      for arg in args.lr_scheduler_args:
-        key, value = arg.split('=')
+        for arg in args.lr_scheduler_args:
+            key, value = arg.split("=")

-        value = value.split(",")
-        for i in range(len(value)):
-          if value[i].lower() == "true" or value[i].lower() == "false":
-            value[i] = (value[i].lower() == "true")
-          else:
-            value[i] = ast.literal_eval(value[i])
-        if len(value) == 1:
-          value = value[0]
-        else:
-          value = list(value) # some may use list?
+            value = value.split(",")
+            for i in range(len(value)):
+                if value[i].lower() == "true" or value[i].lower() == "false":
+                    value[i] = value[i].lower() == "true"
+                else:
+                    value[i] = ast.literal_eval(value[i])
+            if len(value) == 1:
+                value = value[0]
+            else:
+                value = list(value) # some may use list?

-        lr_scheduler_kwargs[key] = value
+            lr_scheduler_kwargs[key] = value

     # using any lr_scheduler from other library
     if args.lr_scheduler_type:
-      lr_scheduler_type = args.lr_scheduler_type
-      print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
-      if "." not in lr_scheduler_type: # default to use torch.optim
-        lr_scheduler_module = torch.optim.lr_scheduler
-      else:
-        values = lr_scheduler_type.split(".")
-        lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
-        lr_scheduler_type = values[-1]
-      lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
-      lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
-      return lr_scheduler
+        lr_scheduler_type = args.lr_scheduler_type
+        print(f"use {lr_scheduler_type} | {lr_scheduler_kwargs} as lr_scheduler")
+        if "." not in lr_scheduler_type: # default to use torch.optim
+            lr_scheduler_module = torch.optim.lr_scheduler
+        else:
+            values = lr_scheduler_type.split(".")
+            lr_scheduler_module = importlib.import_module(".".join(values[:-1]))
+            lr_scheduler_type = values[-1]
+        lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
+        lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
+        return lr_scheduler

     if name.startswith("adafactor"):
         assert (
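
Taken together, the kwargs parsing and the dynamic class lookup above behave like the following standalone sketch. The helper name build_custom_scheduler is invented for illustration; the actual function is get_scheduler_fix and reads its inputs from args.

import ast
import importlib

import torch


def build_custom_scheduler(optimizer, lr_scheduler_type, lr_scheduler_args):
    # Parse "key=value" strings; comma-separated values become a list, and each
    # piece is turned into a bool or evaluated as a Python literal.
    lr_scheduler_kwargs = {}
    for arg in lr_scheduler_args or []:
        key, value = arg.split("=")
        pieces = value.split(",")
        for i in range(len(pieces)):
            if pieces[i].lower() in ("true", "false"):
                pieces[i] = pieces[i].lower() == "true"
            else:
                pieces[i] = ast.literal_eval(pieces[i])
        lr_scheduler_kwargs[key] = pieces[0] if len(pieces) == 1 else list(pieces)

    # A bare name is looked up in torch.optim.lr_scheduler; a dotted path is
    # imported as module.ClassName.
    if "." not in lr_scheduler_type:
        module = torch.optim.lr_scheduler
    else:
        parts = lr_scheduler_type.split(".")
        module = importlib.import_module(".".join(parts[:-1]))
        lr_scheduler_type = parts[-1]
    scheduler_class = getattr(module, lr_scheduler_type)
    return scheduler_class(optimizer, **lr_scheduler_kwargs)


# Example: "T_max=100" becomes {"T_max": 100} and CosineAnnealingLR is resolved
# inside torch.optim.lr_scheduler.
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = build_custom_scheduler(optimizer, "CosineAnnealingLR", ["T_max=100"])
print(type(scheduler).__name__)  # CosineAnnealingLR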