mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix multi gpu
This commit is contained in:
@@ -178,7 +178,7 @@ def train(args):
|
|||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
|
|||||||
@@ -1849,7 +1849,7 @@ def get_optimizer(args, trainable_params):
|
|||||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler_fix(args,optimizer: Optimizer):
|
def get_scheduler_fix(args,optimizer: Optimizer,num_processes:int):
|
||||||
"""
|
"""
|
||||||
Unified API to get any scheduler from its name.
|
Unified API to get any scheduler from its name.
|
||||||
Args:
|
Args:
|
||||||
@@ -1873,7 +1873,7 @@ def get_scheduler_fix(args,optimizer: Optimizer):
|
|||||||
|
|
||||||
name = args.lr_scheduler
|
name = args.lr_scheduler
|
||||||
num_warmup_steps = args.lr_warmup_steps
|
num_warmup_steps = args.lr_warmup_steps
|
||||||
num_training_steps = args.max_train_steps * args.gradient_accumulation_steps
|
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
|
||||||
num_cycles = args.lr_scheduler_num_cycles
|
num_cycles = args.lr_scheduler_num_cycles
|
||||||
power = args.lr_scheduler_power
|
power = args.lr_scheduler_power
|
||||||
|
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ def train(args):
|
|||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
|
|||||||
@@ -235,7 +235,7 @@ def train(args):
|
|||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
|||||||
Reference in New Issue
Block a user