From 7544b3863536cf4a3ce8ad886f67dc5a69b796c4 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 10 Mar 2023 18:45:53 +0800
Subject: [PATCH] fix multi gpu

---
 fine_tune.py               | 2 +-
 library/train_util.py      | 4 ++--
 train_network.py           | 2 +-
 train_textual_inversion.py | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 89bc1aa6..94f9747d 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -178,7 +178,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
+  lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
diff --git a/library/train_util.py b/library/train_util.py
index dc0269cc..80739ad5 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1849,7 +1849,7 @@ def get_optimizer(args, trainable_params):
 
 
 # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
-def get_scheduler_fix(args,optimizer: Optimizer):
+def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
   """
   Unified API to get any scheduler from its name.
   Args:
@@ -1873,7 +1873,7 @@ def get_scheduler_fix(args,optimizer: Optimizer):
 
   name = args.lr_scheduler
   num_warmup_steps = args.lr_warmup_steps
-  num_training_steps = args.max_train_steps * args.gradient_accumulation_steps
+  num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
   num_cycles = args.lr_scheduler_num_cycles
   power = args.lr_scheduler_power
 
diff --git a/train_network.py b/train_network.py
index cfc0b15a..a4846161 100644
--- a/train_network.py
+++ b/train_network.py
@@ -179,7 +179,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
+  lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
   # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
   if args.full_fp16:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 9f789517..d158b242 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -235,7 +235,7 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # lr schedulerを用意する
-  lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
+  lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
   # acceleratorがなんかよろしくやってくれるらしい
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
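
Note (editor's addition, not part of the patch): a minimal sketch of why the schedule length is scaled by accelerator.num_processes. get_scheduler_fix is sd-scripts' local copy of diffusers' get_scheduler (per the comment in the hunk above), and once the scheduler has been passed through accelerator.prepare(), Accelerate advances the underlying schedule num_processes times per optimizer step, so the schedule must cover max_train_steps * num_processes * gradient_accumulation_steps steps or the learning rate decays too early on multi-GPU runs. The toy model, optimizer, and the "cosine" schedule below are illustrative assumptions, not code from the patch.

# Illustrative sketch only: demonstrates the scaling the patch introduces.
import torch
from accelerate import Accelerator
from diffusers.optimization import get_scheduler

accelerator = Accelerator()

# Toy model/optimizer, stand-ins for the real training setup (assumption).
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

max_train_steps = 1000            # corresponds to args.max_train_steps (per-process optimizer steps)
gradient_accumulation_steps = 1   # corresponds to args.gradient_accumulation_steps

# After accelerator.prepare(), Accelerate advances the wrapped scheduler
# num_processes times per optimizer step, so the schedule length is multiplied
# by num_processes -- this is what the patched get_scheduler_fix now does.
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=max_train_steps * accelerator.num_processes * gradient_accumulation_steps,
)

model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)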