mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
make deepspeed_utils
This commit is contained in:
@@ -11,11 +11,12 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import sdxl_model_util
|
||||
from library import deepspeed_utils, sdxl_model_util
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
@@ -97,6 +98,7 @@ def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
deepspeed_utils.prepare_deepspeed_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
assert (
|
||||
@@ -361,7 +363,7 @@ def train(args):
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
@@ -398,41 +400,31 @@ def train(args):
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
|
||||
if args.deepspeed:
|
||||
training_models_dict = {}
|
||||
if train_unet:
|
||||
training_models_dict["unet"] = unet
|
||||
if train_text_encoder1:
|
||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||
training_models_dict["text_encoder1"] = text_encoder1
|
||||
if train_text_encoder2:
|
||||
training_models_dict["text_encoder2"] = text_encoder2
|
||||
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
training_models = [] # override training_models
|
||||
if train_unet:
|
||||
unet = ds_model.models["unet"]
|
||||
training_models.append(unet)
|
||||
if train_text_encoder1:
|
||||
text_encoder1 = ds_model.models["text_encoder1"]
|
||||
training_models.append(text_encoder1)
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = ds_model.models["text_encoder2"]
|
||||
training_models.append(text_encoder2)
|
||||
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||
if train_text_encoder1:
|
||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||
|
||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.deepspeed:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(
|
||||
args,
|
||||
unet=unet if train_unet else None,
|
||||
text_encoder1=text_encoder1 if train_text_encoder1 else None,
|
||||
text_encoder2=text_encoder2 if train_text_encoder2 else None,
|
||||
)
|
||||
ds_model = accelerator.prepare(ds_model)
|
||||
training_models = [ds_model]
|
||||
|
||||
else:
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
if train_text_encoder1:
|
||||
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -446,8 +438,9 @@ def train(args):
|
||||
text_encoder2.to(accelerator.device)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16 and not args.deepspeed:
|
||||
if args.full_fp16:
|
||||
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
|
||||
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
@@ -508,10 +501,10 @@ def train(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
with torch.no_grad(): # why this block differ within train_network.py?
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
@@ -519,7 +512,7 @@ def train(args):
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
@@ -768,6 +761,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
|
||||
Reference in New Issue
Block a user