Fix gradient handling when Text Encoders are trained
@@ -232,7 +232,7 @@ class NetworkTrainer:
                 t.requires_grad_(True)

             # Predict the noise residual
-            with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+            with torch.set_grad_enabled(is_train), accelerator.autocast():
                 noise_pred = self.call_unet(
                     args,
                     accelerator,
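Why the first hunk matters: with torch.set_grad_enabled(is_train and train_unet), training only the Text Encoders left the UNet forward without an autograd graph, so no gradient could flow back through the UNet to the Text Encoder outputs. Below is a minimal PyTorch sketch of that failure mode; the two Linear modules are stand-ins for the Text Encoder and UNet, not sd-scripts code.

# Minimal sketch (toy modules, not sd-scripts code): disabling grad around
# the downstream module severs the autograd graph, so upstream trainable
# parameters receive no gradient.
import torch
import torch.nn as nn

text_encoder = nn.Linear(4, 4)   # trainable, stands in for a Text Encoder
unet = nn.Linear(4, 1)           # frozen, stands in for the UNet
unet.requires_grad_(False)

x = torch.randn(2, 4)
cond = text_encoder(x)           # conditioning produced by the Text Encoder

# Before the fix: grad was off whenever train_unet was False, which cut the
# graph here and made loss.backward() fail when only the TEs were trained.
# with torch.set_grad_enabled(False):
#     pred = unet(cond)

# After the fix: grad follows is_train alone; the frozen UNet still records
# the graph, so gradients reach the Text Encoder parameters.
with torch.set_grad_enabled(True):
    pred = unet(cond)

pred.sum().backward()
print(text_encoder.weight.grad is not None)  # True
print(unet.weight.grad is None)              # True: the UNet stays frozen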
@@ -1405,8 +1405,8 @@ class NetworkTrainer:
                     text_encoding_strategy,
                     tokenize_strategy,
                     is_train=False,
-                    train_text_encoder=False,
-                    train_unet=False,
+                    train_text_encoder=train_text_encoder,  # this is needed for validation because Text Encoders must be called if train_text_encoder is True
+                    train_unet=train_unet,
                 )

                 current_loss = loss.detach().item()
@@ -1466,8 +1466,8 @@ class NetworkTrainer:
                     text_encoding_strategy,
                     tokenize_strategy,
                     is_train=False,
-                    train_text_encoder=False,
-                    train_unet=False,
+                    train_text_encoder=train_text_encoder,
+                    train_unet=train_unet,
                 )

                 current_loss = loss.detach().item()
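The two validation hunks make the same point from the other side: is_train stays False (no gradients during validation), but train_text_encoder and train_unet are now forwarded instead of hardcoded to False, because a Text Encoder whose weights are being trained must actually be run at validation time; embeddings cached before training would be stale. A hedged sketch of the flag semantics follows; process_batch and the toy modules here are stand-ins, not the sd-scripts implementation.

# Hypothetical sketch of the flag semantics (names are stand-ins, not
# sd-scripts internals): is_train gates gradients, train_text_encoder
# decides whether the encoder is called instead of reusing cached output.
import torch
import torch.nn as nn

text_encoder = nn.Linear(4, 4)
unet = nn.Linear(4, 1)
cached_embeddings = torch.randn(2, 4)  # computed once, before training

def process_batch(tokens, is_train, train_text_encoder):
    if train_text_encoder:
        # The encoder's weights change during training, so the cache is
        # stale; call it even when is_train is False (validation).
        cond = text_encoder(tokens)
    else:
        cond = cached_embeddings
    with torch.set_grad_enabled(is_train):  # the fix from the first hunk
        pred = unet(cond)
    return pred.pow(2).mean()

train_loss = process_batch(torch.randn(2, 4), True, True)
val_loss = process_batch(torch.randn(2, 4), False, True)
print(train_loss.requires_grad, val_loss.requires_grad)  # True False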