mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
arbitrary args for diffusers lr scheduler
This commit is contained in:
@@ -3393,10 +3393,8 @@ def get_optimizer(args, trainable_params):
|
||||
return optimizer_name, optimizer_args, optimizer
|
||||
|
||||
|
||||
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
||||
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
||||
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||
# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
|
||||
# Add some checking and features to the original function.
|
||||
|
||||
|
||||
def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
@@ -3413,19 +3411,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
||||
for arg in args.lr_scheduler_args:
|
||||
key, value = arg.split("=")
|
||||
|
||||
value = ast.literal_eval(value)
|
||||
# value = value.split(",")
|
||||
# for i in range(len(value)):
|
||||
# if value[i].lower() == "true" or value[i].lower() == "false":
|
||||
# value[i] = value[i].lower() == "true"
|
||||
# else:
|
||||
# value[i] = ast.literal_eval(value[i])
|
||||
# if len(value) == 1:
|
||||
# value = value[0]
|
||||
# else:
|
||||
# value = list(value) # some may use list?
|
||||
|
||||
lr_scheduler_kwargs[key] = value
|
||||
|
||||
def wrap_check_needless_num_warmup_steps(return_vals):
|
||||
@@ -3457,15 +3443,19 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
|
||||
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
||||
|
||||
if name == SchedulerType.PIECEWISE_CONSTANT:
|
||||
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
@@ -3473,13 +3463,19 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||
|
||||
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
||||
optimizer,
|
||||
num_warmup_steps=num_warmup_steps,
|
||||
num_training_steps=num_training_steps,
|
||||
num_cycles=num_cycles,
|
||||
**lr_scheduler_kwargs,
|
||||
)
|
||||
|
||||
if name == SchedulerType.POLYNOMIAL:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power)
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
|
||||
)
|
||||
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
|
||||
|
||||
|
||||
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
|
||||
Reference in New Issue
Block a user