Fix DDP issues and Support DDP for all training scripts (#448)

* Fix DDP bugs

* Fix DDP bugs for fine_tune and train_db

* refactor model loader

* fix DDP network

* try to fix DDP network when training U-Net only

* remove unused DDP import

* refactor DDP transform

* refactor DDP transform

* fix sample image bugs

* change DDP transform location

* add autocast to train_db

* support DDP in XTI

* Clear DDP import
Author:    Isotr0py
Date:      2023-05-03 09:37:47 +08:00
Committer: GitHub
Parent:    a7485e4d9e
Commit:    e1143caf38

7 changed files with 58 additions and 37 deletions

fine_tune.py

@@ -90,7 +90,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
     # load the model
-    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
 
     # verify load/save model formats
     if load_stable_diffusion_format:

@@ -228,6 +228,9 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     # experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
     if args.full_fp16:
         train_util.patch_accelerator_for_fp16_training(accelerator)
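
Note: with more than one process, accelerator.prepare wraps each module in DistributedDataParallel, and the original attributes are then only reachable through .module, which is what the new transform_DDP call digs back out. A minimal sketch of that behavior (illustrative, not part of the commit; uses a toy module and assumes a launch like `accelerate launch --num_processes 2`):

    import torch.nn as nn
    from accelerate import Accelerator
    from torch.nn.parallel import DistributedDataParallel as DDP

    accelerator = Accelerator()
    model = accelerator.prepare(nn.Linear(4, 4))

    # with >1 process, prepare returns a DDP wrapper; attributes such as
    # `.weight` now live on model.module, so unwrap before touching them
    if isinstance(model, DDP):
        model = model.module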

library/train_util.py

@@ -19,6 +19,7 @@ from typing import (
     Union,
 )
 from accelerate import Accelerator
+import gc
 import glob
 import math
 import os

@@ -30,6 +31,7 @@ import toml
 from tqdm import tqdm
 import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer

@@ -2866,7 +2868,7 @@ def prepare_dtype(args: argparse.Namespace):
     return weight_dtype, save_dtype
 
 
-def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
+def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
     name_or_path = args.pretrained_model_name_or_path
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
     load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers

@@ -2895,6 +2897,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
     return text_encoder, vae, unet, load_stable_diffusion_format
 
 
+def transform_DDP(text_encoder, unet, network=None):
+    # Transform text_encoder, unet and network from DistributedDataParallel
+    return (encoder.module if type(encoder) == DDP else encoder for encoder in [text_encoder, unet, network])
+
+
+def load_target_model(args, weight_dtype, accelerator):
+    # load models for each process
+    for pi in range(accelerator.state.num_processes):
+        if pi == accelerator.state.local_process_index:
+            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
+
+            text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
+                args, weight_dtype, accelerator.device if args.lowram else "cpu"
+            )
+
+            # work on low-ram device
+            if args.lowram:
+                text_encoder.to(accelerator.device)
+                unet.to(accelerator.device)
+                vae.to(accelerator.device)
+
+            gc.collect()
+            torch.cuda.empty_cache()
+        accelerator.wait_for_everyone()
+
+    text_encoder, unet, _ = transform_DDP(text_encoder, unet, network=None)
+    return text_encoder, vae, unet, load_stable_diffusion_format
+
+
 def patch_accelerator_for_fp16_training(accelerator):
     org_unscale_grads = accelerator.scaler._unscale_grads_
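
Note: the loop in the new load_target_model staggers the checkpoint read across ranks: on each pass exactly one process loads while the rest block on wait_for_everyone, so peak host-RAM and disk pressure stays at one process (the point of --lowram). The bare pattern, as a sketch with a hypothetical expensive_load() standing in for the real checkpoint read:

    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()

    def expensive_load():
        return nn.Linear(4, 4)  # stand-in for a multi-GB checkpoint

    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            model = expensive_load()  # only this rank loads on this pass
        accelerator.wait_for_everyone()  # everyone else waits at the barrier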

networks/lora_interrogator.py

@@ -23,7 +23,7 @@ def interrogate(args):
     print(f"loading SD model: {args.sd_model}")
     args.pretrained_model_name_or_path = args.sd_model
     args.vae = None
-    text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
+    text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
 
     print(f"loading LoRA: {args.model}")
     network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)

train_db.py

@@ -92,7 +92,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
     # load the model
-    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
 
     # verify load/save model formats
     if load_stable_diffusion_format:

@@ -196,6 +196,9 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
 
+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error

@@ -297,6 +300,7 @@ def train(args):
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                with accelerator.autocast():
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
                 if args.v_parameterization:
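
Note: accelerator.autocast() runs the wrapped forward pass in whatever mixed precision the Accelerator was configured with, so the U-Net prediction above executes in fp16/bf16 when mixed precision is on. A toy sketch of the same call (illustrative, not from the repo):

    import torch
    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="fp16")  # or "bf16" / "no"
    net = accelerator.prepare(nn.Linear(8, 8))
    x = torch.randn(2, 8, device=accelerator.device)

    with accelerator.autocast():
        y = net(x)  # forward pass runs under the configured autocast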

train_network.py

@@ -1,4 +1,3 @@
-from torch.nn.parallel import DistributedDataParallel as DDP
 import importlib
 import argparse
 import gc

@@ -144,24 +143,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
     # load the model
-    for pi in range(accelerator.state.num_processes):
-        # TODO: modify other training scripts as well
-        if pi == accelerator.state.local_process_index:
-            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
-
-            text_encoder, vae, unet, _ = train_util.load_target_model(
-                args, weight_dtype, accelerator.device if args.lowram else "cpu"
-            )
-
-            # work on low-ram device
-            if args.lowram:
-                text_encoder.to(accelerator.device)
-                unet.to(accelerator.device)
-                vae.to(accelerator.device)
-
-            gc.collect()
-            torch.cuda.empty_cache()
-        accelerator.wait_for_everyone()
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
 
     # incorporate xformers or memory-efficient attention into the model
     train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

@@ -279,6 +261,9 @@ def train(args):
     else:
         network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
 
+    # transform DDP after prepare (train_network here only)
+    text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
+
     unet.requires_grad_(False)
     unet.to(accelerator.device, dtype=weight_dtype)
     text_encoder.requires_grad_(False)

@@ -288,20 +273,11 @@ def train(args):
         text_encoder.train()
 
         # set top parameter requires_grad = True for gradient checkpointing works
-        if type(text_encoder) == DDP:
-            text_encoder.module.text_model.embeddings.requires_grad_(True)
-        else:
-            text_encoder.text_model.embeddings.requires_grad_(True)
+        text_encoder.text_model.embeddings.requires_grad_(True)
     else:
         unet.eval()
         text_encoder.eval()
 
-    # support DistributedDataParallel
-    if type(text_encoder) == DDP:
-        text_encoder = text_encoder.module
-        unet = unet.module
-        network = network.module
-
     network.prepare_grad_etc(text_encoder, unet)
 
     if not cache_latents:
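
Note: the deleted branches tested type(x) == DDP at every use site; the commit replaces them with one transform_DDP call right after prepare. Accelerate also ships a built-in helper that performs the same unwrap, shown here as a possible alternative (not what the commit uses):

    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(nn.Linear(4, 4))

    # unwrap_model strips the DDP wrapper added by prepare; on a single
    # process it returns the model unchanged
    model = accelerator.unwrap_model(model)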

train_textual_inversion.py

@@ -98,7 +98,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
     # load the model
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
 
     # Convert the init_word to token_id
     if args.init_word is not None:

@@ -280,6 +280,9 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
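
Note: this script already routes text_encoder through a local unwrap_model helper before calling get_input_embeddings, since a DDP wrapper does not expose that method. A plausible shape for such a helper (an assumption; the actual definition lives elsewhere in the script and may differ):

    from torch.nn.parallel import DistributedDataParallel as DDP

    def unwrap_model(model):  # hypothetical; see the script for the real helper
        return model.module if isinstance(model, DDP) else model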

train_textual_inversion_XTI.py

@@ -104,7 +104,7 @@ def train(args):
     weight_dtype, save_dtype = train_util.prepare_dtype(args)
 
     # load the model
-    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
 
     # Convert the init_word to token_id
     if args.init_word is not None:

@@ -314,6 +314,9 @@ def train(args):
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
+    # transform DDP after prepare
+    text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
+
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # print(len(index_no_updates), torch.sum(index_no_updates))
     orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
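
One subtlety of the new helper: transform_DDP returns a generator expression, and the call sites above unpack it into exactly three names, which is what forces evaluation. A self-contained check of that behavior (toy values in place of real modules):

    from torch.nn.parallel import DistributedDataParallel as DDP

    def transform_DDP(text_encoder, unet, network=None):
        # same body as the committed helper
        return (encoder.module if type(encoder) == DDP else encoder for encoder in [text_encoder, unet, network])

    a, b, c = transform_DDP("te", "unet")  # network defaults to None
    assert (a, b, c) == ("te", "unet", None)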