Update T5 attention mask handling in FLUX

This commit is contained in:
Kohya S
2024-08-21 08:02:33 +09:00
parent 6ab48b09d8
commit 7e459c00b2
7 changed files with 101 additions and 50 deletions

View File

@@ -233,11 +233,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.flux_lower = flux_lower
self.target_device = device
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None):
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
self.flux_lower.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_upper.to(self.target_device)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
self.flux_upper.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_lower.to(self.target_device)
@@ -300,10 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance_vec.requires_grad_(True)
# Predict the noise residual
l_pooled, t5_out, txt_ids = text_encoder_conds
# print(
# f"model_input: {noisy_model_input.shape}, img_ids: {img_ids.shape}, t5_out: {t5_out.shape}, txt_ids: {txt_ids.shape}, l_pooled: {l_pooled.shape}, timesteps: {timesteps.shape}, guidance_vec: {guidance_vec.shape}"
# )
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
if not args.split_mode:
# normal forward
@@ -317,6 +316,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
else:
# split forward to reduce memory usage
@@ -337,6 +337,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu