From e1143caf381a60bcee50d4bf76de26093327059a Mon Sep 17 00:00:00 2001 From: Isotr0py <41363108+Isotr0py@users.noreply.github.com> Date: Wed, 3 May 2023 09:37:47 +0800 Subject: [PATCH] Fix DDP issues and Support DDP for all training scripts (#448) * Fix DDP bugs * Fix DDP bugs for finetune and db * refactor model loader * fix DDP network * try to fix DDP network in train unet only * remove unused DDP import * refactor DDP transform * refactor DDP transform * fix sample images bugs * change DDP transform location * add autocast to train_db * support DDP in XTI * Clear DDP import --- fine_tune.py | 5 ++++- library/train_util.py | 34 +++++++++++++++++++++++++++++++- networks/lora_interrogator.py | 2 +- train_db.py | 8 ++++++-- train_network.py | 36 ++++++---------------------------- train_textual_inversion.py | 5 ++++- train_textual_inversion_XTI.py | 5 ++++- 7 files changed, 58 insertions(+), 37 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index b6a8d1d7..db1c8a23 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -90,7 +90,7 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) # verify load/save model formats if load_stable_diffusion_format: @@ -228,6 +228,9 @@ def train(args): else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + # transform DDP after prepare + text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet) + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: train_util.patch_accelerator_for_fp16_training(accelerator) diff --git a/library/train_util.py b/library/train_util.py index 6c064738..1a3b2ed0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -19,6 +19,7 @@ 
from typing import ( Union, ) from accelerate import Accelerator +import gc import glob import math import os @@ -30,6 +31,7 @@ import toml from tqdm import tqdm import torch +from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms from transformers import CLIPTokenizer @@ -2866,7 +2868,7 @@ def prepare_dtype(args: argparse.Namespace): return weight_dtype, save_dtype -def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): +def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): name_or_path = args.pretrained_model_name_or_path name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers @@ -2895,6 +2897,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"): return text_encoder, vae, unet, load_stable_diffusion_format +def transform_DDP(text_encoder, unet, network=None): + # Transform text_encoder, unet and network from DistributedDataParallel + return (encoder.module if type(encoder) == DDP else encoder for encoder in [text_encoder, unet, network]) + + +def load_target_model(args, weight_dtype, accelerator): + # load models for each process + for pi in range(accelerator.state.num_processes): + if pi == accelerator.state.local_process_index: + print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") + + text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model( + args, weight_dtype, accelerator.device if args.lowram else "cpu" + ) + + # work on low-ram device + if args.lowram: + text_encoder.to(accelerator.device) + unet.to(accelerator.device) + vae.to(accelerator.device) + + gc.collect() + torch.cuda.empty_cache() + accelerator.wait_for_everyone() + + text_encoder, unet, _ = transform_DDP(text_encoder, unet, network=None) + + return 
text_encoder, vae, unet, load_stable_diffusion_format + + def patch_accelerator_for_fp16_training(accelerator): org_unscale_grads = accelerator.scaler._unscale_grads_ diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index beb25181..0dc066fd 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -23,7 +23,7 @@ def interrogate(args): print(f"loading SD model: {args.sd_model}") args.pretrained_model_name_or_path = args.sd_model args.vae = None - text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE) + text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE) print(f"loading LoRA: {args.model}") network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet) diff --git a/train_db.py b/train_db.py index 178d5cb4..abe2ecdf 100644 --- a/train_db.py +++ b/train_db.py @@ -92,7 +92,7 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) # verify load/save model formats if load_stable_diffusion_format: @@ -196,6 +196,9 @@ def train(args): else: unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + # transform DDP after prepare + text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet) + if not train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error @@ -297,7 +300,8 @@ def train(args): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with accelerator.autocast(): + noise_pred = unet(noisy_latents, timesteps, 
encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training diff --git a/train_network.py b/train_network.py index 5c4d5ad1..c5ec0ebd 100644 --- a/train_network.py +++ b/train_network.py @@ -1,4 +1,3 @@ -from torch.nn.parallel import DistributedDataParallel as DDP import importlib import argparse import gc @@ -144,24 +143,7 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - for pi in range(accelerator.state.num_processes): - # TODO: modify other training scripts as well - if pi == accelerator.state.local_process_index: - print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}") - - text_encoder, vae, unet, _ = train_util.load_target_model( - args, weight_dtype, accelerator.device if args.lowram else "cpu" - ) - - # work on low-ram device - if args.lowram: - text_encoder.to(accelerator.device) - unet.to(accelerator.device) - vae.to(accelerator.device) - - gc.collect() - torch.cuda.empty_cache() - accelerator.wait_for_everyone() + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) @@ -279,6 +261,9 @@ def train(args): else: network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler) + # transform DDP after prepare (train_network here only) + text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network) + unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.requires_grad_(False) @@ -288,20 +273,11 @@ def train(args): text_encoder.train() # set top parameter requires_grad = True for gradient checkpointing works - if type(text_encoder) == DDP: - text_encoder.module.text_model.embeddings.requires_grad_(True) - else: - text_encoder.text_model.embeddings.requires_grad_(True) 
+ text_encoder.text_model.embeddings.requires_grad_(True) else: unet.eval() text_encoder.eval() - - # support DistributedDataParallel - if type(text_encoder) == DDP: - text_encoder = text_encoder.module - unet = unet.module - network = network.module - + network.prepare_grad_etc(text_encoder, unet) if not cache_latents: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index fb6b6053..c13fcf9f 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -98,7 +98,7 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # Convert the init_word to token_id if args.init_word is not None: @@ -280,6 +280,9 @@ def train(args): text_encoder, optimizer, train_dataloader, lr_scheduler ) + # transform DDP after prepare + text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet) + index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] # print(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 69ec3eb1..67d48023 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -104,7 +104,7 @@ def train(args): weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) # Convert the init_word to token_id if args.init_word is not None: @@ -314,6 +314,9 @@ def train(args): text_encoder, optimizer, train_dataloader, lr_scheduler ) + # transform DDP after prepare + text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet) + 
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] # print(len(index_no_updates), torch.sum(index_no_updates)) orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()