mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix gradient handling when Text Encoders are trained
This commit is contained in:
@@ -376,9 +376,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
t5_attn_mask = None
|
t5_attn_mask = None
|
||||||
|
|
||||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||||
# if not args.split_mode:
|
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||||
# normal forward
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = unet(
|
model_pred = unet(
|
||||||
img=img,
|
img=img,
|
||||||
@@ -390,44 +389,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
txt_attention_mask=t5_attn_mask,
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
else:
|
|
||||||
# split forward to reduce memory usage
|
|
||||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
|
||||||
with accelerator.autocast():
|
|
||||||
# move flux lower to cpu, and then move flux upper to gpu
|
|
||||||
unet.to("cpu")
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
self.flux_upper.to(accelerator.device)
|
|
||||||
|
|
||||||
# upper model does not require grad
|
|
||||||
with torch.no_grad():
|
|
||||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
|
||||||
img=packed_noisy_model_input,
|
|
||||||
img_ids=img_ids,
|
|
||||||
txt=t5_out,
|
|
||||||
txt_ids=txt_ids,
|
|
||||||
y=l_pooled,
|
|
||||||
timesteps=timesteps / 1000,
|
|
||||||
guidance=guidance_vec,
|
|
||||||
txt_attention_mask=t5_attn_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
# move flux upper back to cpu, and then move flux lower to gpu
|
|
||||||
self.flux_upper.to("cpu")
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
unet.to(accelerator.device)
|
|
||||||
|
|
||||||
# lower model requires grad
|
|
||||||
intermediate_img.requires_grad_(True)
|
|
||||||
intermediate_txt.requires_grad_(True)
|
|
||||||
vec.requires_grad_(True)
|
|
||||||
pe.requires_grad_(True)
|
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_train and train_unet):
|
|
||||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
|
||||||
"""
|
|
||||||
|
|
||||||
return model_pred
|
return model_pred
|
||||||
|
|
||||||
model_pred = call_dit(
|
model_pred = call_dit(
|
||||||
|
|||||||
@@ -345,7 +345,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
t5_attn_mask = None
|
t5_attn_mask = None
|
||||||
|
|
||||||
# call model
|
# call model
|
||||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
# TODO support attention mask
|
# TODO support attention mask
|
||||||
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
|
||||||
|
|
||||||
|
|||||||
@@ -232,7 +232,7 @@ class NetworkTrainer:
|
|||||||
t.requires_grad_(True)
|
t.requires_grad_(True)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
noise_pred = self.call_unet(
|
noise_pred = self.call_unet(
|
||||||
args,
|
args,
|
||||||
accelerator,
|
accelerator,
|
||||||
@@ -1405,8 +1405,8 @@ class NetworkTrainer:
|
|||||||
text_encoding_strategy,
|
text_encoding_strategy,
|
||||||
tokenize_strategy,
|
tokenize_strategy,
|
||||||
is_train=False,
|
is_train=False,
|
||||||
train_text_encoder=False,
|
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
|
||||||
train_unet=False,
|
train_unet=train_unet,
|
||||||
)
|
)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
@@ -1466,8 +1466,8 @@ class NetworkTrainer:
|
|||||||
text_encoding_strategy,
|
text_encoding_strategy,
|
||||||
tokenize_strategy,
|
tokenize_strategy,
|
||||||
is_train=False,
|
is_train=False,
|
||||||
train_text_encoder=False,
|
train_text_encoder=train_text_encoder,
|
||||||
train_unet=False,
|
train_unet=train_unet,
|
||||||
)
|
)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
|
|||||||
Reference in New Issue
Block a user