arbitrary args for diffusers lr scheduler

Kohya S
2023-07-30 14:36:03 +09:00
parent 8856c19c76
commit 496c3f2732


@@ -3393,10 +3393,8 @@ def get_optimizer(args, trainable_params):
     return optimizer_name, optimizer_args, optimizer
 
 
-# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
-# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
-# Which is a newer release of diffusers than currently packaged with sd-scripts
-# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
+# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
+# Add some checking and features to the original function.
 def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
@@ -3413,19 +3411,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
         for arg in args.lr_scheduler_args:
             key, value = arg.split("=")
             value = ast.literal_eval(value)
-
-            # value = value.split(",")
-            # for i in range(len(value)):
-            #     if value[i].lower() == "true" or value[i].lower() == "false":
-            #         value[i] = value[i].lower() == "true"
-            #     else:
-            #         value[i] = ast.literal_eval(value[i])
-            # if len(value) == 1:
-            #     value = value[0]
-            # else:
-            #     value = list(value)  # some may use list?
-
             lr_scheduler_kwargs[key] = value
 
     def wrap_check_needless_num_warmup_steps(return_vals):
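
The kept parsing path is simple: each --lr_scheduler_args entry is a single key=value pair, and ast.literal_eval turns the value string into a typed Python literal. A minimal standalone sketch of that behavior, with a made-up argument list standing in for args.lr_scheduler_args:

import ast

# Made-up input, e.g. from: --lr_scheduler_args "num_cycles=3" "power=0.8" "step_rules='1:100,0.1'"
lr_scheduler_args = ["num_cycles=3", "power=0.8", "step_rules='1:100,0.1'"]

lr_scheduler_kwargs = {}
for arg in lr_scheduler_args:
    key, value = arg.split("=")  # note: a value containing "=" would break this two-way split
    # ast.literal_eval parses the value as a Python literal:
    # "3" -> int, "0.8" -> float, "'1:100,0.1'" -> str (inner quotes required)
    lr_scheduler_kwargs[key] = ast.literal_eval(value)

print(lr_scheduler_kwargs)  # {'num_cycles': 3, 'power': 0.8, 'step_rules': '1:100,0.1'}

String-valued arguments therefore need inner quotes on the command line, so that literal_eval sees a valid Python literal rather than a bare word.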
@@ -3457,15 +3443,19 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
 
     if name == SchedulerType.CONSTANT:
-        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
+        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
+
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
 
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
 
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
 
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
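
The new PIECEWISE_CONSTANT branch passes only the forwarded kwargs, because that scheduler is configured through step_rules rather than warmup or training-step counts. A minimal sketch of the underlying call, assuming a diffusers release that ships get_piecewise_constant_schedule (the function TYPE_TO_SCHEDULER_FUNCTION resolves this enum to); the optimizer and rule string are made up:

import torch
from diffusers.optimization import get_piecewise_constant_schedule

optimizer = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-3)

# step_rules is a string of multiplier:step pairs ending in a final multiplier;
# here the base lr is scaled by 1.0 for the first 100 steps, by 0.1 until
# step 200, and by 0.01 for all steps after that.
lr_scheduler_kwargs = {"step_rules": "1:100,0.1:200,0.01"}
scheduler = get_piecewise_constant_schedule(optimizer, **lr_scheduler_kwargs)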
@@ -3473,13 +3463,19 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     if name == SchedulerType.COSINE_WITH_RESTARTS:
         return schedule_func(
-            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles,
+            **lr_scheduler_kwargs,
         )
 
     if name == SchedulerType.POLYNOMIAL:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power)
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
+        )
 
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
 
 
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
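
End to end, the change means any extra key=value pair survives all the way into the diffusers schedule function. A minimal sketch of what the COSINE_WITH_RESTARTS branch now amounts to, using diffusers' get_cosine_with_hard_restarts_schedule_with_warmup with made-up step counts and a forwarded last_epoch:

import torch
from diffusers.optimization import get_cosine_with_hard_restarts_schedule_with_warmup

optimizer = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-3)

# e.g. parsed from --lr_scheduler_args "last_epoch=-1"; any kwarg the schedule
# function accepts can be forwarded the same way.
lr_scheduler_kwargs = {"last_epoch": -1}

scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,    # illustrative values only
    num_training_steps=1000,
    num_cycles=3,
    **lr_scheduler_kwargs,   # the arbitrary extras this commit adds
)

Because the kwargs are splatted last, a key that collides with an explicit argument (such as num_cycles here) would raise a TypeError from Python's duplicate-keyword check rather than silently overriding it.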