support deepspeed

This commit is contained in:
BootsofLagrangian
2024-02-04 03:12:42 +09:00
parent cd19df49cd
commit dfe08f395f
5 changed files with 195 additions and 50 deletions

View File

@@ -353,18 +353,26 @@ class NetworkTrainer:
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
if args.deepspeed:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
else:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -409,20 +417,42 @@ class NetworkTrainer:
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
if args.deepspeed:
# wrapping model
class DeepSpeedModel(torch.nn.Module):
def __init__(self, unet, text_encoder, vae, network) -> None:
super().__init__()
self.unet = unet
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
self.vae = vae
self.network = network
def get_models(self):
return self.unet, self.text_encoders, self.vae, self.network
unet.to(accelerator.device, dtype=unet_weight_dtype)
[t_enc.to(accelerator.device, dtype=te_weight_dtype) for t_enc in text_encoders]
ds_model = DeepSpeedModel(unet, text_encoders, vae, network)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
unet, text_encoders, vae, network = ds_model.get_models() # for compatiblility
vae.to(vae_dtype) # to avoid explicitly half-vae
text_encoder = text_encoders
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
if train_unet:
unet = accelerator.prepare(unet)
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required