Fix DDP issues and Support DDP for all training scripts (#448)

* Fix DDP bugs

* Fix DDP bugs for finetune and db

* refactor model loader

* fix DDP network

* try to fix DDP network in train unet only

* remove unuse DDP import

* refactor DDP transform

* refactor DDP transform

* fix sample images bugs

* change DDP tranform location

* add autocast to train_db

* support DDP in XTI

* Clear DDP import
This commit is contained in:
Isotr0py
2023-05-03 09:37:47 +08:00
committed by GitHub
parent a7485e4d9e
commit e1143caf38
7 changed files with 58 additions and 37 deletions

View File

@@ -90,7 +90,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
# verify load/save model formats
if load_stable_diffusion_format:
@@ -228,6 +228,9 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)

View File

@@ -19,6 +19,7 @@ from typing import (
Union,
)
from accelerate import Accelerator
import gc
import glob
import math
import os
@@ -30,6 +31,7 @@ import toml
from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
@@ -2866,7 +2868,7 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
@@ -2895,6 +2897,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
return text_encoder, vae, unet, load_stable_diffusion_format
def transform_DDP(text_encoder, unet, network=None):
# Transform text_encoder, unet and network from DistributedDataParallel
return (encoder.module if type(encoder) == DDP else encoder for encoder in [text_encoder, unet, network])
def load_target_model(args, weight_dtype, accelerator):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder, unet, _ = transform_DDP(text_encoder, unet, network=None)
return text_encoder, vae, unet, load_stable_diffusion_format
def patch_accelerator_for_fp16_training(accelerator):
org_unscale_grads = accelerator.scaler._unscale_grads_

View File

@@ -23,7 +23,7 @@ def interrogate(args):
print(f"loading SD model: {args.sd_model}")
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
print(f"loading LoRA: {args.model}")
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)

View File

@@ -92,7 +92,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
# verify load/save model formats
if load_stable_diffusion_format:
@@ -196,6 +196,9 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -297,6 +300,7 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:

View File

@@ -1,4 +1,3 @@
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse
import gc
@@ -144,24 +143,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
for pi in range(accelerator.state.num_processes):
# TODO: modify other training scripts as well
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -279,6 +261,9 @@ def train(args):
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
@@ -288,20 +273,11 @@ def train(args):
text_encoder.train()
# set top parameter requires_grad = True for gradient checkpointing works
if type(text_encoder) == DDP:
text_encoder.module.text_model.embeddings.requires_grad_(True)
else:
text_encoder.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
text_encoder.eval()
# support DistributedDataParallel
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents:

View File

@@ -98,7 +98,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# Convert the init_word to token_id
if args.init_word is not None:
@@ -280,6 +280,9 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

View File

@@ -104,7 +104,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# Convert the init_word to token_id
if args.init_word is not None:
@@ -314,6 +314,9 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()