mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix DDP issues and Support DDP for all training scripts (#448)
* Fix DDP bugs * Fix DDP bugs for finetune and db * refactor model loader * fix DDP network * try to fix DDP network in train unet only * remove unuse DDP import * refactor DDP transform * refactor DDP transform * fix sample images bugs * change DDP tranform location * add autocast to train_db * support DDP in XTI * Clear DDP import
This commit is contained in:
@@ -90,7 +90,7 @@ def train(args):
|
|||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
|
|
||||||
# verify load/save model formats
|
# verify load/save model formats
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
@@ -228,6 +228,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
# transform DDP after prepare
|
||||||
|
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
import gc
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -30,6 +31,7 @@ import toml
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
@@ -2866,7 +2868,7 @@ def prepare_dtype(args: argparse.Namespace):
|
|||||||
return weight_dtype, save_dtype
|
return weight_dtype, save_dtype
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||||
name_or_path = args.pretrained_model_name_or_path
|
name_or_path = args.pretrained_model_name_or_path
|
||||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||||
@@ -2895,6 +2897,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
|||||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||||
|
|
||||||
|
|
||||||
|
def transform_DDP(text_encoder, unet, network=None):
|
||||||
|
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||||
|
return (encoder.module if type(encoder) == DDP else encoder for encoder in [text_encoder, unet, network])
|
||||||
|
|
||||||
|
|
||||||
|
def load_target_model(args, weight_dtype, accelerator):
|
||||||
|
# load models for each process
|
||||||
|
for pi in range(accelerator.state.num_processes):
|
||||||
|
if pi == accelerator.state.local_process_index:
|
||||||
|
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
|
|
||||||
|
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||||
|
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# work on low-ram device
|
||||||
|
if args.lowram:
|
||||||
|
text_encoder.to(accelerator.device)
|
||||||
|
unet.to(accelerator.device)
|
||||||
|
vae.to(accelerator.device)
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
text_encoder, unet, _ = transform_DDP(text_encoder, unet, network=None)
|
||||||
|
|
||||||
|
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||||
|
|
||||||
|
|
||||||
def patch_accelerator_for_fp16_training(accelerator):
|
def patch_accelerator_for_fp16_training(accelerator):
|
||||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def interrogate(args):
|
|||||||
print(f"loading SD model: {args.sd_model}")
|
print(f"loading SD model: {args.sd_model}")
|
||||||
args.pretrained_model_name_or_path = args.sd_model
|
args.pretrained_model_name_or_path = args.sd_model
|
||||||
args.vae = None
|
args.vae = None
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
|
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
||||||
|
|
||||||
print(f"loading LoRA: {args.model}")
|
print(f"loading LoRA: {args.model}")
|
||||||
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ def train(args):
|
|||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
|
|
||||||
# verify load/save model formats
|
# verify load/save model formats
|
||||||
if load_stable_diffusion_format:
|
if load_stable_diffusion_format:
|
||||||
@@ -196,6 +196,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
# transform DDP after prepare
|
||||||
|
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
|
||||||
|
|
||||||
if not train_text_encoder:
|
if not train_text_encoder:
|
||||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||||
|
|
||||||
@@ -297,6 +300,7 @@ def train(args):
|
|||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
|
with accelerator.autocast():
|
||||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
import importlib
|
import importlib
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
@@ -144,24 +143,7 @@ def train(args):
|
|||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
for pi in range(accelerator.state.num_processes):
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
# TODO: modify other training scripts as well
|
|
||||||
if pi == accelerator.state.local_process_index:
|
|
||||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
|
||||||
|
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(
|
|
||||||
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
# work on low-ram device
|
|
||||||
if args.lowram:
|
|
||||||
text_encoder.to(accelerator.device)
|
|
||||||
unet.to(accelerator.device)
|
|
||||||
vae.to(accelerator.device)
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
@@ -279,6 +261,9 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
# transform DDP after prepare (train_network here only)
|
||||||
|
text_encoder, unet, network = train_util.transform_DDP(text_encoder, unet, network)
|
||||||
|
|
||||||
unet.requires_grad_(False)
|
unet.requires_grad_(False)
|
||||||
unet.to(accelerator.device, dtype=weight_dtype)
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
text_encoder.requires_grad_(False)
|
text_encoder.requires_grad_(False)
|
||||||
@@ -288,20 +273,11 @@ def train(args):
|
|||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
# set top parameter requires_grad = True for gradient checkpointing works
|
# set top parameter requires_grad = True for gradient checkpointing works
|
||||||
if type(text_encoder) == DDP:
|
|
||||||
text_encoder.module.text_model.embeddings.requires_grad_(True)
|
|
||||||
else:
|
|
||||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||||
else:
|
else:
|
||||||
unet.eval()
|
unet.eval()
|
||||||
text_encoder.eval()
|
text_encoder.eval()
|
||||||
|
|
||||||
# support DistributedDataParallel
|
|
||||||
if type(text_encoder) == DDP:
|
|
||||||
text_encoder = text_encoder.module
|
|
||||||
unet = unet.module
|
|
||||||
network = network.module
|
|
||||||
|
|
||||||
network.prepare_grad_etc(text_encoder, unet)
|
network.prepare_grad_etc(text_encoder, unet)
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ def train(args):
|
|||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
|
|
||||||
# Convert the init_word to token_id
|
# Convert the init_word to token_id
|
||||||
if args.init_word is not None:
|
if args.init_word is not None:
|
||||||
@@ -280,6 +280,9 @@ def train(args):
|
|||||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# transform DDP after prepare
|
||||||
|
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
|
||||||
|
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ def train(args):
|
|||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||||
|
|
||||||
# Convert the init_word to token_id
|
# Convert the init_word to token_id
|
||||||
if args.init_word is not None:
|
if args.init_word is not None:
|
||||||
@@ -314,6 +314,9 @@ def train(args):
|
|||||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# transform DDP after prepare
|
||||||
|
text_encoder, unet, _ = train_util.transform_DDP(text_encoder, unet)
|
||||||
|
|
||||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|||||||
Reference in New Issue
Block a user