Fix gradient handling when Text Encoders are trained

This commit is contained in:
Kohya S
2025-01-27 21:10:52 +09:00
parent 532f5c58a6
commit 86a2f3fd26
3 changed files with 8 additions and 47 deletions

View File

@@ -376,9 +376,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
-with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
+with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
@@ -390,44 +389,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred
model_pred = call_dit(

View File

@@ -345,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
t5_attn_mask = None
# call model
-with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+with torch.set_grad_enabled(is_train), accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)

View File

@@ -232,7 +232,7 @@ class NetworkTrainer:
t.requires_grad_(True)
# Predict the noise residual
-with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
+with torch.set_grad_enabled(is_train), accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
@@ -1405,8 +1405,8 @@ class NetworkTrainer:
text_encoding_strategy,
tokenize_strategy,
is_train=False,
-train_text_encoder=False,
-train_unet=False,
+train_text_encoder=train_text_encoder,  # this is needed for validation because Text Encoders must be called if train_text_encoder is True
+train_unet=train_unet,
)
current_loss = loss.detach().item()
@@ -1466,8 +1466,8 @@ class NetworkTrainer:
text_encoding_strategy,
tokenize_strategy,
is_train=False,
-train_text_encoder=False,
-train_unet=False,
+train_text_encoder=train_text_encoder,
+train_unet=train_unet,
)
current_loss = loss.detach().item()