Merge branch 'kohya-ss:main' into weighted_captions

This commit is contained in:
AI-Casanova
2023-04-05 17:07:15 -05:00
committed by GitHub
9 changed files with 547 additions and 159 deletions

View File

@@ -2460,7 +2460,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
Unified API to get any scheduler from its name.
"""
name = args.lr_scheduler
num_warmup_steps = args.lr_warmup_steps
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power
@@ -2484,6 +2484,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
lr_scheduler_kwargs[key] = value
def wrap_check_needless_num_warmup_steps(return_vals):
if num_warmup_steps is not None and num_warmup_steps != 0:
raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
return return_vals
# using any lr_scheduler from other library
if args.lr_scheduler_type:
lr_scheduler_type = args.lr_scheduler_type
@@ -2496,7 +2501,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
lr_scheduler_type = values[-1]
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
return lr_scheduler
return wrap_check_needless_num_warmup_steps(lr_scheduler)
if name.startswith("adafactor"):
assert (
@@ -2504,12 +2509,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
initial_lr = float(name.split(":")[1])
# print("adafactor scheduler init lr", initial_lr)
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None: