mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Fix gradient handling when Text Encoders are trained
This commit is contained in:
@@ -376,9 +376,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
t5_attn_mask = None
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
# if not args.split_mode:
|
||||
# normal forward
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
@@ -390,44 +389,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
"""
|
||||
else:
|
||||
# split forward to reduce memory usage
|
||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
||||
with accelerator.autocast():
|
||||
# move flux lower to cpu, and then move flux upper to gpu
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
self.flux_upper.to(accelerator.device)
|
||||
|
||||
# upper model does not require grad
|
||||
with torch.no_grad():
|
||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# move flux upper back to cpu, and then move flux lower to gpu
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# lower model requires grad
|
||||
intermediate_img.requires_grad_(True)
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet):
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
"""
|
||||
|
||||
return model_pred
|
||||
|
||||
model_pred = call_dit(
|
||||
|
||||
@@ -345,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
t5_attn_mask = None
|
||||
|
||||
# call model
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# TODO support attention mask
|
||||
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
||||
|
||||
|
||||
@@ -232,7 +232,7 @@ class NetworkTrainer:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
@@ -1405,8 +1405,8 @@ class NetworkTrainer:
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False,
|
||||
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
@@ -1466,8 +1466,8 @@ class NetworkTrainer:
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
|
||||
Reference in New Issue
Block a user