Refactored code: moved some functions into train_utils.py

This commit is contained in:
BootsofLagrangian
2024-02-22 16:20:53 +09:00
parent 03f0816f86
commit 4d5186d1cf
5 changed files with 119 additions and 96 deletions

View File

@@ -243,24 +243,19 @@ def train(args):
text_encoder.to(weight_dtype)
if args.deepspeed:
# wrapping model
import deepspeed
if args.offload_optimizer_device is not None:
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
deepspeed.ops.op_builder.CPUAdamBuilder().load()
accelerator.print('[DeepSpeed] building cpu_adam done.')
# Wrapper torch.nn.Module that holds the U-Net and the text encoders as
# attributes of a single module object.
class DeepSpeedModel(torch.nn.Module):
# Stores `unet` as-is and wraps `text_encoder` in a torch.nn.ModuleList.
def __init__(self, unet, text_encoder) -> None:
super().__init__()
self.unet = unet
# One ModuleList object is bound to both attribute names below.
# NOTE(review): presumably `text_encoder` is an iterable of modules
# (ModuleList is built directly from it) - confirm against the caller.
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
# Returns the stored U-Net and the text-encoder ModuleList as a 2-tuple.
def get_models(self):
return self.unet, self.text_encoders
ds_model = DeepSpeedModel(unet, text_encoders)
training_models_dict = {}
training_models_dict["unet"] = unet
if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
# Now, ds_model is an instance of DeepSpeedEngine.
unet, text_encoders = ds_model.get_models() # for compatiblility
text_encoder = text_encoders
training_models = []
unet = ds_model.models["unet"]
training_models.append(unet)
if args.train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
training_models.append(text_encoder)
else: # acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder: