Fix gradient handling when Text Encoders are trained

This commit is contained in:
Kohya S
2025-01-27 21:10:52 +09:00
parent 532f5c58a6
commit 86a2f3fd26
3 changed files with 8 additions and 47 deletions

View File

@@ -345,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             t5_attn_mask = None
         # call model
-        with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+        with torch.set_grad_enabled(is_train), accelerator.autocast():
             # TODO support attention mask
             model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)