mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix DDP issues and Support DDP for all training scripts (#448)
* Fix DDP bugs * Fix DDP bugs for finetune and db * refactor model loader * fix DDP network * try to fix DDP network in train unet only * remove unuse DDP import * refactor DDP transform * refactor DDP transform * fix sample images bugs * change DDP tranform location * add autocast to train_db * support DDP in XTI * Clear DDP import
This commit is contained in:
@@ -90,7 +90,7 @@ def train(args):
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
@@ -228,6 +228,9 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
Reference in New Issue
Block a user