diff --git a/library/train_util.py b/library/train_util.py
index 46b76df4..021d2ccb 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3393,10 +3393,8 @@ def get_optimizer(args, trainable_params):
     return optimizer_name, optimizer_args, optimizer
 
 
-# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
-# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
-# Which is a newer release of diffusers than currently packaged with sd-scripts
-# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
+# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
+# Add some checking and features to the original function.
 
 
 def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
@@ -3413,19 +3411,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
         for arg in args.lr_scheduler_args:
             key, value = arg.split("=")
-
             value = ast.literal_eval(value)
-            # value = value.split(",")
-            # for i in range(len(value)):
-            #     if value[i].lower() == "true" or value[i].lower() == "false":
-            #         value[i] = value[i].lower() == "true"
-            #     else:
-            #         value[i] = ast.literal_eval(value[i])
-            # if len(value) == 1:
-            #     value = value[0]
-            # else:
-            #     value = list(value)  # some may use list?
-
             lr_scheduler_kwargs[key] = value
 
     def wrap_check_needless_num_warmup_steps(return_vals):
@@ -3457,15 +3443,19 @@
 
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
     if name == SchedulerType.CONSTANT:
-        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
+        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
+
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, **lr_scheduler_kwargs)  # step_rules and last_epoch are given as kwargs
 
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
 
     if name == SchedulerType.CONSTANT_WITH_WARMUP:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
 
     # All other schedulers require `num_training_steps`
     if num_training_steps is None:
@@ -3473,13 +3463,19 @@
 
     if name == SchedulerType.COSINE_WITH_RESTARTS:
         return schedule_func(
-            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            num_cycles=num_cycles,
+            **lr_scheduler_kwargs,
         )
 
     if name == SchedulerType.POLYNOMIAL:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power)
+        return schedule_func(
+            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
+        )
 
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
 
 
 def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
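
Usage note: the lr_scheduler_args handling above splits each "key=value" pair on "=", parses the value with ast.literal_eval, and forwards the result to the scheduler function via **lr_scheduler_kwargs. Below is a minimal, self-contained sketch of that parsing step; the option names step_rules and last_epoch come from the comment in the diff, while the concrete values are illustrative only (the step_rules format follows the documentation of the diffusers piecewise-constant schedule).

    import ast

    # Illustrative --lr_scheduler_args entries; each value must be a valid
    # Python literal, so string values need their own quotes.
    lr_scheduler_args = ["step_rules='1:10,0.1:20,0.01'", "last_epoch=-1"]

    lr_scheduler_kwargs = {}
    for arg in lr_scheduler_args:
        # Same split("=") as in get_scheduler_fix: a value containing "="
        # would raise ValueError here, since exactly two parts are expected.
        key, value = arg.split("=")
        lr_scheduler_kwargs[key] = ast.literal_eval(value)

    print(lr_scheduler_kwargs)
    # {'step_rules': '1:10,0.1:20,0.01', 'last_epoch': -1}

With the PIECEWISE_CONSTANT branch added in this diff, these kwargs then reach the scheduler as schedule_func(optimizer, **lr_scheduler_kwargs), matching the "step_rules and last_epoch are given as kwargs" comment.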