Fix to work training U-Net only LoRA for SD1/2

This commit is contained in:
Kohya S
2023-10-01 16:37:23 +09:00
parent 6bd6cd9c51
commit 81419f7f32

View File

@@ -426,7 +426,10 @@ class NetworkTrainer:
t_enc.train() t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works # set top parameter requires_grad = True for gradient checkpointing works
t_enc.text_model.embeddings.requires_grad_(True) if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
else:
unet.parameters().__next__().requires_grad_(True)
else: else:
unet.eval() unet.eval()
for t_enc in text_encoders: for t_enc in text_encoders: