make deepspeed_utils

This commit is contained in:
Kohya S
2024-02-27 21:30:46 +09:00
parent 0e4a5738df
commit e3ccf8fbf7
6 changed files with 238 additions and 200 deletions

View File

@@ -13,13 +13,14 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
from library import deepspeed_utils, model_util
import library.train_util as train_util
from library.train_util import (
@@ -141,6 +142,7 @@ class NetworkTrainer:
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -357,7 +359,7 @@ class NetworkTrainer:
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -414,22 +416,17 @@ class NetworkTrainer:
# accelerator apparently takes care of the remaining setup for us
if args.deepspeed:
training_models_dict = {}
if train_unet: training_models_dict["unet"] = unet
if train_text_encoder: training_models_dict["text_encoder"] = text_encoders
training_models_dict["network"] = network
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
if train_unet: unet = ds_model.models["unet"]
if train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
if len(ds_model.models["text_encoder"]) > 1:
text_encoders = text_encoder
else:
text_encoders = [text_encoder]
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoders[0] if train_text_encoder else None,
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_model = ds_model
else:
if train_unet:
unet = accelerator.prepare(unet)
@@ -444,7 +441,10 @@ class NetworkTrainer:
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
training_model = network
if args.gradient_checkpointing:
# according to the Textual Inversion (TI) example in Diffusers, train is required
@@ -777,13 +777,13 @@ class NetworkTrainer:
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
with torch.no_grad():
# convert to latents
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
@@ -791,7 +791,7 @@ class NetworkTrainer:
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
latents = latents * self.vae_scale_factor
# get multiplier for each sample
if network_has_multiplier:
@@ -976,6 +976,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)