mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
arbitrary args for diffusers lr scheduler
This commit is contained in:
@@ -3393,10 +3393,8 @@ def get_optimizer(args, trainable_params):
|
|||||||
return optimizer_name, optimizer_args, optimizer
|
return optimizer_name, optimizer_args, optimizer
|
||||||
|
|
||||||
|
|
||||||
# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
|
# Modified version of get_scheduler() function from diffusers.optimizer.get_scheduler
|
||||||
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
# Add some checking and features to the original function.
|
||||||
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
|
||||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
||||||
@@ -3413,19 +3411,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0:
|
||||||
for arg in args.lr_scheduler_args:
|
for arg in args.lr_scheduler_args:
|
||||||
key, value = arg.split("=")
|
key, value = arg.split("=")
|
||||||
|
|
||||||
value = ast.literal_eval(value)
|
value = ast.literal_eval(value)
|
||||||
# value = value.split(",")
|
|
||||||
# for i in range(len(value)):
|
|
||||||
# if value[i].lower() == "true" or value[i].lower() == "false":
|
|
||||||
# value[i] = value[i].lower() == "true"
|
|
||||||
# else:
|
|
||||||
# value[i] = ast.literal_eval(value[i])
|
|
||||||
# if len(value) == 1:
|
|
||||||
# value = value[0]
|
|
||||||
# else:
|
|
||||||
# value = list(value) # some may use list?
|
|
||||||
|
|
||||||
lr_scheduler_kwargs[key] = value
|
lr_scheduler_kwargs[key] = value
|
||||||
|
|
||||||
def wrap_check_needless_num_warmup_steps(return_vals):
|
def wrap_check_needless_num_warmup_steps(return_vals):
|
||||||
@@ -3457,15 +3443,19 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
|
|
||||||
name = SchedulerType(name)
|
name = SchedulerType(name)
|
||||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||||
|
|
||||||
if name == SchedulerType.CONSTANT:
|
if name == SchedulerType.CONSTANT:
|
||||||
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
|
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs))
|
||||||
|
|
||||||
|
if name == SchedulerType.PIECEWISE_CONSTANT:
|
||||||
|
return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs
|
||||||
|
|
||||||
# All other schedulers require `num_warmup_steps`
|
# All other schedulers require `num_warmup_steps`
|
||||||
if num_warmup_steps is None:
|
if num_warmup_steps is None:
|
||||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||||
|
|
||||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs)
|
||||||
|
|
||||||
# All other schedulers require `num_training_steps`
|
# All other schedulers require `num_training_steps`
|
||||||
if num_training_steps is None:
|
if num_training_steps is None:
|
||||||
@@ -3473,13 +3463,19 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
|
|
||||||
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||||
return schedule_func(
|
return schedule_func(
|
||||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
optimizer,
|
||||||
|
num_warmup_steps=num_warmup_steps,
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
num_cycles=num_cycles,
|
||||||
|
**lr_scheduler_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if name == SchedulerType.POLYNOMIAL:
|
if name == SchedulerType.POLYNOMIAL:
|
||||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power)
|
return schedule_func(
|
||||||
|
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||||
|
|||||||
Reference in New Issue
Block a user