Fix DDP issues and Support DDP for all training scripts (#448)

* Fix DDP bugs

* Fix DDP bugs for finetune and db

* refactor model loader

* fix DDP network

* try to fix DDP network in train unet only

* remove unuse DDP import

* refactor DDP transform

* refactor DDP transform

* fix sample images bugs

* change DDP tranform location

* add autocast to train_db

* support DDP in XTI

* Clear DDP import
This commit is contained in:
Isotr0py
2023-05-03 09:37:47 +08:00
committed by GitHub
parent a7485e4d9e
commit e1143caf38
7 changed files with 58 additions and 37 deletions

View File

@@ -1,4 +1,3 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse
import gc
@@ -144,24 +143,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
for pi in range(accelerator.state.num_processes):
# TODO: modify other training scripts as well
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -279,6 +261,9 @@ def train(args):
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
@@ -288,20 +273,11 @@ def train(args):
text_encoder.train()
# set top parameter requires_grad = True for gradient checkpointing works
if type(text_encoder) == DDP:
text_encoder.module.text_model.embeddings.requires_grad_(True)
else:
text_encoder.text_model.embeddings.requires_grad_(True)
text_encoder.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
text_encoder.eval()
# support DistributedDataParallel
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents: