diff --git a/library/train_util.py b/library/train_util.py
index 59dbc44c..a195faac 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2460,7 +2460,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps = args.lr_warmup_steps
+    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
@@ -2484,6 +2484,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             lr_scheduler_kwargs[key] = value
 
+    def wrap_check_needless_num_warmup_steps(return_vals):
+        if num_warmup_steps is not None and num_warmup_steps != 0:
+            raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
+        return return_vals
+
     # using any lr_scheduler from other library
     if args.lr_scheduler_type:
         lr_scheduler_type = args.lr_scheduler_type
@@ -2496,7 +2501,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             lr_scheduler_type = values[-1]
         lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
         lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
-        return lr_scheduler
+        return wrap_check_needless_num_warmup_steps(lr_scheduler)
 
     if name.startswith("adafactor"):
         assert (
@@ -2504,12 +2509,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
         ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
         initial_lr = float(name.split(":")[1])
         # print("adafactor scheduler init lr", initial_lr)
-        return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
+        return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
 
     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
 
     if name == SchedulerType.CONSTANT:
-        return schedule_func(optimizer)
+        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
 
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
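
For reference, a minimal standalone sketch of the check this diff introduces (the helper name and the example calls below are illustrative, not part of the repository): schedulers that take no warmup, such as `constant`, adafactor, or an externally supplied `lr_scheduler_type`, now fail fast when a non-zero `--lr_warmup_steps` is passed instead of silently ignoring it. Note that the diff applies the check by wrapping the already-constructed scheduler, so the scheduler is instantiated before the error is raised.

    from typing import Optional

    def check_needless_num_warmup_steps(name: str, num_warmup_steps: Optional[int]) -> None:
        # Mirrors the guard in the diff: anything other than None or 0 is an
        # error for schedulers that do not consume warmup steps.
        if num_warmup_steps is not None and num_warmup_steps != 0:
            raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")

    check_needless_num_warmup_steps("constant", None)  # ok
    check_needless_num_warmup_steps("constant", 0)     # ok
    try:
        check_needless_num_warmup_steps("constant", 100)
    except ValueError as e:
        print(e)  # constant does not require `num_warmup_steps`. Set None or 0.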