mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
refactored codes, some function moved into train_utils.py
This commit is contained in:
@@ -391,28 +391,29 @@ def train(args):
|
||||
text_encoder2.to(weight_dtype)
|
||||
|
||||
if args.deepspeed:
|
||||
# Wrapping model for DeepSpeed
|
||||
import deepspeed
|
||||
if args.offload_optimizer_device is not None:
|
||||
accelerator.print('[DeepSpeed] start to manually build cpu_adam.')
|
||||
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
||||
accelerator.print('[DeepSpeed] building cpu_adam done.')
|
||||
|
||||
class DeepSpeedModel(torch.nn.Module):
|
||||
def __init__(self, unet, text_encoder) -> None:
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.text_encoders = self.text_encoder = torch.nn.ModuleList(text_encoder)
|
||||
|
||||
def get_models(self):
|
||||
return self.unet, self.text_encoders
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
ds_model = DeepSpeedModel(unet, text_encoders)
|
||||
training_models_dict = {}
|
||||
if train_unet:
|
||||
training_models_dict["unet"] = unet
|
||||
if train_text_encoder1:
|
||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||
training_models_dict["text_encoder1"] = text_encoder1
|
||||
if train_text_encoder2:
|
||||
training_models_dict["text_encoder2"] = text_encoder2
|
||||
ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)
|
||||
# Now, ds_model is an instance of DeepSpeedEngine.
|
||||
unet, text_encoders = ds_model.get_models() # for compatiblility
|
||||
text_encoder1, text_encoder2 = text_encoder = text_encoders
|
||||
training_models = [unet, text_encoder1, text_encoder2]
|
||||
|
||||
training_models = [] # override training_models
|
||||
if train_unet:
|
||||
unet = ds_model.models["unet"]
|
||||
training_models.append(unet)
|
||||
if train_text_encoder1:
|
||||
text_encoder1 = ds_model.models["text_encoder1"]
|
||||
training_models.append(text_encoder1)
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = ds_model.models["text_encoder2"]
|
||||
training_models.append(text_encoder2)
|
||||
|
||||
else: # acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
|
||||
Reference in New Issue
Block a user