support deepspeed

This commit is contained in:
BootsofLagrangian
2024-02-04 03:12:42 +09:00
parent cd19df49cd
commit dfe08f395f
5 changed files with 195 additions and 50 deletions

View File

@@ -102,6 +102,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする # mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args) weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む # モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator) text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -152,7 +153,7 @@ def train(args):
# 学習を準備する # 学習を準備する
if cache_latents: if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
@@ -187,7 +188,7 @@ def train(args):
if not cache_latents: if not cache_latents:
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
vae.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models: for m in training_models:
m.requires_grad_(True) m.requires_grad_(True)
@@ -214,7 +215,7 @@ def train(args):
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collator, collate_fn=collator,
num_workers=n_workers, num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
@@ -240,7 +241,27 @@ def train(args):
unet.to(weight_dtype) unet.to(weight_dtype)
text_encoder.to(weight_dtype) text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい if args.deepspeed:
# wrapping model
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder, vae) -> None:
super().__init__()
self.unet = unet
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
self.vae = vae
def get_models(self):
return self.unet, self.text_encoders, self.vae
unet.to(accelerator.device, dtype=weight_dtype)
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
ds_model = DeepSpeedModel(unet, text_encoders, vae)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
            unet, text_encoders, vae = ds_model.get_models() # for compatibility
vae.to(vae_dtype)
text_encoder = text_encoders
else: # acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder: if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler unet, text_encoder, optimizer, train_dataloader, lr_scheduler

View File

@@ -20,6 +20,7 @@ from typing import (
Union, Union,
) )
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from accelerate import DeepSpeedPlugin
import gc import gc
import glob import glob
import math import math
@@ -3124,6 +3125,47 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み" "--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
) )
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument(
"--zero_stage",
type=int, default=2,
choices=[0, 1, 2, 3],
help="Possible options are 0,1,2,3."
)
parser.add_argument(
"--offload_optimizer",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
        " Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
)
def verify_training_args(args: argparse.Namespace): def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2: if args.v_parameterization and not args.v2:
@@ -3912,6 +3954,17 @@ def prepare_accelerator(args: argparse.Namespace):
else None, else None,
) )
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = None
if args.deepspeed:
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
offload_optimizer=args.offload_optimizer, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
accelerator = Accelerator( accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision, mixed_precision=args.mixed_precision,
@@ -3919,6 +3972,7 @@ def prepare_accelerator(args: argparse.Namespace):
project_dir=logging_dir, project_dir=logging_dir,
kwargs_handlers=kwargs_handlers, kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend, dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
) )
return accelerator return accelerator

View File

@@ -354,7 +354,7 @@ def train(args):
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collator, collate_fn=collator,
num_workers=n_workers, num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
@@ -389,7 +389,27 @@ def train(args):
text_encoder1.to(weight_dtype) text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype) text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい if args.deepspeed:
# Wrapping model for DeepSpeed
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder, vae) -> None:
super().__init__()
self.unet = unet
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
self.vae = vae
def get_models(self):
return self.unet, self.text_encoders, self.vae
text_encoders = [text_encoder1, text_encoder2]
unet.to(accelerator.device, dtype=weight_dtype)
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
ds_model = DeepSpeedModel(unet, text_encoders, vae)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
            unet, text_encoders, vae = ds_model.get_models() # for compatibility
vae.to(vae_dtype) # to avoid explicitly half-vae
text_encoder1, text_encoder2 = text_encoders[0], text_encoders[1]
else: # acceleratorがなんかよろしくやってくれるらしい
if train_unet: if train_unet:
unet = accelerator.prepare(unet) unet = accelerator.prepare(unet)
if train_text_encoder1: if train_text_encoder1:
@@ -399,7 +419,6 @@ def train(args):
text_encoder1 = accelerator.prepare(text_encoder1) text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2: if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2) text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する # TextEncoderの出力をキャッシュするときにはCPUへ移動する

View File

@@ -184,7 +184,7 @@ def train(args):
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collator, collate_fn=collator,
num_workers=n_workers, num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
@@ -214,6 +214,27 @@ def train(args):
text_encoder.to(weight_dtype) text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい # acceleratorがなんかよろしくやってくれるらしい
if args.deepspeed:
# wrapping model
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder, vae) -> None:
super().__init__()
self.unet = unet
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
self.vae = vae
def get_models(self):
return self.unet, self.text_encoders, self.vae
unet.to(accelerator.device, dtype=weight_dtype)
[t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
ds_model = DeepSpeedModel(unet, text_encoders, vae)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
            unet, text_encoders, vae = ds_model.get_models() # for compatibility
vae.to(vae_dtype) # to avoid explicitly half-vae
text_encoder = text_encoders
else:
if train_text_encoder: if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler unet, text_encoder, optimizer, train_dataloader, lr_scheduler

View File

@@ -353,12 +353,20 @@ class NetworkTrainer:
batch_size=1, batch_size=1,
shuffle=True, shuffle=True,
collate_fn=collator, collate_fn=collator,
num_workers=n_workers, num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers, persistent_workers=args.persistent_data_loader_workers,
) )
# 学習ステップ数を計算する # 学習ステップ数を計算する
if args.max_train_epochs is not None: if args.max_train_epochs is not None:
if args.deepspeed:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
else:
args.max_train_steps = args.max_train_epochs * math.ceil( args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
) )
@@ -409,6 +417,28 @@ class NetworkTrainer:
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype)) t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
# wrapping model
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder, vae, network) -> None:
super().__init__()
self.unet = unet
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
self.vae = vae
self.network = network
def get_models(self):
return self.unet, self.text_encoders, self.vae, self.network
unet.to(accelerator.device, dtype=unet_weight_dtype)
[t_enc.to(accelerator.device, dtype=te_weight_dtype) for t_enc in text_encoders]
ds_model = DeepSpeedModel(unet, text_encoders, vae, network)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
                unet, text_encoders, vae, network = ds_model.get_models() # for compatibility
vae.to(vae_dtype) # to avoid explicitly half-vae
text_encoder = text_encoders
else:
if train_unet: if train_unet:
unet = accelerator.prepare(unet) unet = accelerator.prepare(unet)
else: else: