fix Text Encoder only LoRA training

This commit is contained in:
Kohya S
2025-01-27 22:03:42 +09:00
parent 59b3b94faf
commit 0778dd9b1d
3 changed files with 3 additions and 3 deletions

View File

@@ -378,7 +378,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
         # if not args.split_mode:
         # normal forward
-        with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+        with torch.set_grad_enabled(is_train), accelerator.autocast():
             # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
             model_pred = unet(
                 img=img,

View File

@@ -345,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
             t5_attn_mask = None
         # call model
-        with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+        with torch.set_grad_enabled(is_train), accelerator.autocast():
             # TODO support attention mask
             model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

View File

@@ -233,7 +233,7 @@ class NetworkTrainer:
                     t.requires_grad_(True)
             # Predict the noise residual
-            with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+            with torch.set_grad_enabled(is_train), accelerator.autocast():
                 noise_pred = self.call_unet(
                     args,
                     accelerator,