Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
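Check for a needless `num_warmup_steps` (raise a ValueError when it is set to a non-zero value for a scheduler that does not use it)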
This commit is contained in:
@@ -2460,7 +2460,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
Unified API to get any scheduler from its name.
|
Unified API to get any scheduler from its name.
|
||||||
"""
|
"""
|
||||||
name = args.lr_scheduler
|
name = args.lr_scheduler
|
||||||
num_warmup_steps = args.lr_warmup_steps
|
num_warmup_steps: Optional[int] = args.lr_warmup_steps
|
||||||
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
|
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
|
||||||
num_cycles = args.lr_scheduler_num_cycles
|
num_cycles = args.lr_scheduler_num_cycles
|
||||||
power = args.lr_scheduler_power
|
power = args.lr_scheduler_power
|
||||||
@@ -2484,6 +2484,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
|
|
||||||
lr_scheduler_kwargs[key] = value
|
lr_scheduler_kwargs[key] = value
|
||||||
|
|
||||||
|
def wrap_check_needless_num_warmup_steps(return_vals):
|
||||||
|
if num_warmup_steps is not None and num_warmup_steps != 0:
|
||||||
|
raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
|
||||||
|
return return_vals
|
||||||
|
|
||||||
# using any lr_scheduler from other library
|
# using any lr_scheduler from other library
|
||||||
if args.lr_scheduler_type:
|
if args.lr_scheduler_type:
|
||||||
lr_scheduler_type = args.lr_scheduler_type
|
lr_scheduler_type = args.lr_scheduler_type
|
||||||
@@ -2496,7 +2501,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
lr_scheduler_type = values[-1]
|
lr_scheduler_type = values[-1]
|
||||||
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
|
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
|
||||||
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
|
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
|
||||||
return lr_scheduler
|
return wrap_check_needless_num_warmup_steps(lr_scheduler)
|
||||||
|
|
||||||
if name.startswith("adafactor"):
|
if name.startswith("adafactor"):
|
||||||
assert (
|
assert (
|
||||||
@@ -2504,12 +2509,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
|
|||||||
), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
|
||||||
initial_lr = float(name.split(":")[1])
|
initial_lr = float(name.split(":")[1])
|
||||||
# print("adafactor scheduler init lr", initial_lr)
|
# print("adafactor scheduler init lr", initial_lr)
|
||||||
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
|
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
|
||||||
|
|
||||||
name = SchedulerType(name)
|
name = SchedulerType(name)
|
||||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||||
if name == SchedulerType.CONSTANT:
|
if name == SchedulerType.CONSTANT:
|
||||||
return schedule_func(optimizer)
|
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
|
||||||
|
|
||||||
# All other schedulers require `num_warmup_steps`
|
# All other schedulers require `num_warmup_steps`
|
||||||
if num_warmup_steps is None:
|
if num_warmup_steps is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user