Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-08 22:35:09 +00:00
support deepspeed
@@ -354,7 +354,7 @@ def train(args):
         batch_size=1,
         shuffle=True,
         collate_fn=collator,
-        num_workers=n_workers,
+        num_workers=n_workers if not args.deepspeed else 1,  # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
         persistent_workers=args.persistent_data_loader_workers,
     )

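The hunk above caps the DataLoader at a single worker when DeepSpeed is enabled, since multi-worker loading was observed to crash with "DataLoader worker exited unexpectedly". A minimal sketch of the same guard in isolation, assuming a plain PyTorch Dataset and a hypothetical use_deepspeed flag (the real script reads args.deepspeed and its own dataset group and collator):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # stand-in dataset; only here to make the sketch runnable
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(4)

def make_dataloader(dataset, n_workers: int, use_deepspeed: bool) -> DataLoader:
    # under DeepSpeed, fall back to a single worker to avoid the worker crash noted above
    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=1 if use_deepspeed else n_workers,
    )

loader = make_dataloader(ToyDataset(), n_workers=4, use_deepspeed=True)
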
@@ -389,18 +389,37 @@ def train(args):
     text_encoder1.to(weight_dtype)
     text_encoder2.to(weight_dtype)

-    # apparently the accelerator takes care of this for us
-    if train_unet:
-        unet = accelerator.prepare(unet)
-    if train_text_encoder1:
-        # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
-        text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
-        text_encoder1.text_model.final_layer_norm.requires_grad_(False)
-        text_encoder1 = accelerator.prepare(text_encoder1)
-    if train_text_encoder2:
-        text_encoder2 = accelerator.prepare(text_encoder2)
-
-    optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+    if args.deepspeed:
+        # Wrapping model for DeepSpeed
+        class DeepSpeedModel(torch.nn.Module):
+            def __init__(self, unet, text_encoder, vae) -> None:
+                super().__init__()
+                self.unet = unet
+                self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
+                self.vae = vae
+
+            def get_models(self):
+                return self.unet, self.text_encoders, self.vae
+
+        text_encoders = [text_encoder1, text_encoder2]
+        unet.to(accelerator.device, dtype=weight_dtype)
+        [t_enc.to(accelerator.device, dtype=weight_dtype) for t_enc in text_encoders]
+        ds_model = DeepSpeedModel(unet, text_encoders, vae)
+        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
+        # Now, ds_model is an instance of DeepSpeedEngine.
+        unet, text_encoders, vae = ds_model.get_models()  # for compatibility
+        vae.to(vae_dtype)  # to avoid an explicitly half-precision VAE
+        text_encoder1, text_encoder2 = text_encoders[0], text_encoders[1]
+    else:  # apparently the accelerator takes care of this for us
+        if train_unet:
+            unet = accelerator.prepare(unet)
+        if train_text_encoder1:
+            # freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
+            text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
+            text_encoder1.text_model.final_layer_norm.requires_grad_(False)
+            text_encoder1 = accelerator.prepare(text_encoder1)
+        if train_text_encoder2:
+            text_encoder2 = accelerator.prepare(text_encoder2)
+        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     # move to CPU when caching the TextEncoder outputs
     if args.cache_text_encoder_outputs:
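
The hunk above bundles the U-Net, both text encoders, and the VAE into a single torch.nn.Module before calling accelerator.prepare, so that the DeepSpeed path hands Accelerate one model and gets back one DeepSpeedEngine. A minimal, self-contained sketch of that wrapping pattern, using small stand-in modules instead of the real SDXL components (the class and variable names here are illustrative, not the script's):

import torch

class ModelBundle(torch.nn.Module):
    # bundle several sub-models into one module so a single prepare() call sees one model;
    # registering them as attributes / a ModuleList keeps their parameters visible
    def __init__(self, unet, text_encoders, vae):
        super().__init__()
        self.unet = unet
        self.text_encoders = torch.nn.ModuleList(text_encoders)
        self.vae = vae

    def get_models(self):
        # unpack the originals again after prepare(), for code that expects them separately
        return self.unet, self.text_encoders, self.vae

# stand-in components
unet = torch.nn.Linear(8, 8)
text_encoders = [torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)]
vae = torch.nn.Linear(8, 8)

bundle = ModelBundle(unet, text_encoders, vae)
# with Accelerate configured for DeepSpeed, the commit then does roughly:
#     bundle, optimizer, dataloader, scheduler = accelerator.prepare(bundle, optimizer, dataloader, scheduler)
# after which the wrapper is a DeepSpeedEngine; get_models() recovers the wrapped sub-modules
unet, text_encoders, vae = bundle.get_models()

In the commit itself, the VAE is additionally moved back to vae_dtype after prepare so it is not left in half precision.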