fix scheduler steps with gradient accumulation

commit 41d195715d
parent 3db97f8897
Author: Kohya S
Date:   2023-07-16 15:56:29 +09:00


@@ -3348,7 +3348,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
""" """
name = args.lr_scheduler name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
num_cycles = args.lr_scheduler_num_cycles num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power power = args.lr_scheduler_power
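
Context for the change (not part of the commit itself): when the scheduler is prepared by HF Accelerate with gradient accumulation, it only advances on real optimizer steps, so the total number of scheduler steps is max_train_steps optimizer steps times num_processes. Multiplying by gradient_accumulation_steps overcounted the total, so the learning rate would never progress through the full schedule. Below is a minimal sketch of the corrected count, assuming transformers' get_scheduler and hypothetical example values; it is an illustration, not the repository's code.

    # Illustration of the corrected step count; argument names mirror the
    # script's args, values are hypothetical.
    import torch
    from transformers import get_scheduler

    def total_scheduler_steps(max_train_steps: int, num_processes: int) -> int:
        # Each real optimizer step advances the prepared scheduler once per
        # process, regardless of how many micro-batches were accumulated,
        # so the gradient_accumulation_steps multiplier is dropped.
        return max_train_steps * num_processes

    optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=total_scheduler_steps(max_train_steps=10_000, num_processes=2),
    )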