fix multi gpu

This commit is contained in:
Isotr0py
2023-03-10 18:45:53 +08:00
parent c4a596df9e
commit 7544b38635
4 changed files with 5 additions and 5 deletions

View File

@@ -1849,7 +1849,7 @@ def get_optimizer(args, trainable_params):
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(args,optimizer: Optimizer):
def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
"""
Unified API to get any scheduler from its name.
Args:
@@ -1873,7 +1873,7 @@ def get_scheduler_fix(args,optimizer: Optimizer):
name = args.lr_scheduler
num_warmup_steps = args.lr_warmup_steps
num_training_steps = args.max_train_steps * args.gradient_accumulation_steps
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power