Fix: some LoRA modules were not trained when gradient checkpointing was enabled

This commit is contained in:
Kohya S
2023-01-19 20:39:33 +09:00
parent f2f2ce0d7d
commit e6a8c9d269

View File

@@ -166,6 +166,9 @@ def train(args):
# Gradient checkpointing re-runs forward passes inside checkpointed segments,
# which requires the modules to be in train mode even when their own weights
# are frozen (matches the Textual Inversion example in Diffusers).
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
    unet.train()
    text_encoder.train()
    # set top parameter requires_grad = True for gradient checkpointing works
    # NOTE(review): without an input that requires grad at the top of the
    # checkpointed graph, backward is presumably skipped through those
    # segments and the LoRA modules inside receive no gradients — this line
    # is the fix referenced by the commit title.
    text_encoder.text_model.embeddings.requires_grad_(True)
else:
    unet.eval()
    text_encoder.eval()