Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
fix Text Encoder only LoRA training
This commit is contained in:
@@ -378,7 +378,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
# if not args.split_mode:
|
||||
# normal forward
|
||||
- with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+ with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
|
||||
@@ -345,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
t5_attn_mask = None
|
||||
|
||||
# call model
|
||||
- with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+ with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# TODO support attention mask
|
||||
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
||||
|
||||
|
||||
@@ -233,7 +233,7 @@ class NetworkTrainer:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
- with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+ with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
|
||||
Reference in New Issue
Block a user