Fix gradient handling when Text Encoders are trained

This commit is contained in:
Kohya S
2025-01-27 21:10:52 +09:00
parent 532f5c58a6
commit 86a2f3fd26
3 changed files with 8 additions and 47 deletions

View File

@@ -232,7 +232,7 @@ class NetworkTrainer:
t.requires_grad_(True)
# Predict the noise residual
-with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+with torch.set_grad_enabled(is_train), accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
@@ -1405,8 +1405,8 @@ class NetworkTrainer:
text_encoding_strategy,
tokenize_strategy,
is_train=False,
-train_text_encoder=False,
-train_unet=False,
+train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
+train_unet=train_unet,
)
current_loss = loss.detach().item()
@@ -1466,8 +1466,8 @@ class NetworkTrainer:
text_encoding_strategy,
tokenize_strategy,
is_train=False,
-train_text_encoder=False,
-train_unet=False,
+train_text_encoder=train_text_encoder,
+train_unet=train_unet,
)
current_loss = loss.detach().item()