mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix DDP issues and Support DDP for all training scripts (#448)
* Fix DDP bugs * Fix DDP bugs for finetune and db * refactor model loader * fix DDP network * try to fix DDP network in train unet only * remove unuse DDP import * refactor DDP transform * refactor DDP transform * fix sample images bugs * change DDP tranform location * add autocast to train_db * support DDP in XTI * Clear DDP import
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
@@ -144,24 +143,7 @@ def train(args):
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
# TODO: modify other training scripts as well
|
||||
if pi == accelerator.state.local_process_index:
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(
|
||||
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
@@ -279,6 +261,9 @@ def train(args):
|
||||
else:
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -288,20 +273,11 @@ def train(args):
|
||||
text_encoder.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder.module.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
unet.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
# support DistributedDataParallel
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder = text_encoder.module
|
||||
unet = unet.module
|
||||
network = network.module
|
||||
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if not cache_latents:
|
||||
|
||||
Reference in New Issue
Block a user