Check needless num_warmup_steps

commit 9577a9f38d
parent c93cbbc373
Author: Yuta Hayashibe
Date:   2023-04-01 20:33:20 +09:00


@@ -2460,7 +2460,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
     Unified API to get any scheduler from its name.
     """
     name = args.lr_scheduler
-    num_warmup_steps = args.lr_warmup_steps
+    num_warmup_steps: Optional[int] = args.lr_warmup_steps
     num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
     num_cycles = args.lr_scheduler_num_cycles
     power = args.lr_scheduler_power
@@ -2484,6 +2484,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             lr_scheduler_kwargs[key] = value

+    def wrap_check_needless_num_warmup_steps(return_vals):
+        if num_warmup_steps is not None and num_warmup_steps != 0:
+            raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
+        return return_vals
+
     # using any lr_scheduler from other library
     if args.lr_scheduler_type:
         lr_scheduler_type = args.lr_scheduler_type
@@ -2496,7 +2501,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
             lr_scheduler_type = values[-1]
         lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
         lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
-        return lr_scheduler
+        return wrap_check_needless_num_warmup_steps(lr_scheduler)

     if name.startswith("adafactor"):
         assert (
@@ -2504,12 +2509,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
         ), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
         initial_lr = float(name.split(":")[1])
         # print("adafactor scheduler init lr", initial_lr)
-        return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
+        return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))

     name = SchedulerType(name)
     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

     if name == SchedulerType.CONSTANT:
-        return schedule_func(optimizer)
+        return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))

     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
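
For context, here is a minimal runnable sketch of the pattern this commit introduces, not the full get_scheduler_fix from the repository: the nested wrap_check_needless_num_warmup_steps closure rejects a non-zero num_warmup_steps for scheduler branches that never consume warmup (a constant schedule, an externally supplied scheduler class, or AdafactorSchedule), so the mistake fails loudly instead of being silently ignored. The names get_scheduler_sketch and get_constant_schedule below are illustrative stand-ins, not part of the commit.

from typing import Optional


def get_constant_schedule(optimizer):
    # Stand-in for a real scheduler factory; returns a placeholder object.
    return f"constant schedule for {optimizer!r}"


def get_scheduler_sketch(name: str, optimizer, num_warmup_steps: Optional[int]):
    def wrap_check_needless_num_warmup_steps(return_vals):
        # Same guard as in the commit: warmup is meaningless on this branch,
        # so anything other than None or 0 is treated as a user error.
        if num_warmup_steps is not None and num_warmup_steps != 0:
            raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
        return return_vals

    if name == "constant":
        return wrap_check_needless_num_warmup_steps(get_constant_schedule(optimizer))
    raise NotImplementedError(f"sketch only covers 'constant', got {name!r}")


# num_warmup_steps=None or 0 passes through; a positive value raises ValueError.
print(get_scheduler_sketch("constant", "opt", num_warmup_steps=0))
try:
    get_scheduler_sketch("constant", "opt", num_warmup_steps=100)
except ValueError as e:
    print(e)  # constant does not require `num_warmup_steps`. Set None or 0.

Note that the wrapper runs after the scheduler object has already been constructed: it only validates num_warmup_steps and passes the scheduler through unchanged, so the error surfaces once at scheduler-creation time rather than mid-training.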