Update T5 attention mask handling in FLUX

Kohya S committed 2024-08-21 08:02:33 +09:00
parent 6ab48b09d8
commit 7e459c00b2
7 changed files with 101 additions and 50 deletions


@@ -610,7 +610,10 @@ def train(args):
             guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
             # call model
-            l_pooled, t5_out, txt_ids = text_encoder_conds
+            l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
+            if not args.apply_t5_attn_mask:
+                t5_attn_mask = None
             with accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = flux(
@@ -621,6 +624,7 @@ def train(args):
                     y=l_pooled,
                     timesteps=timesteps / 1000,
                     guidance=guidance_vec,
+                    txt_attention_mask=t5_attn_mask,
                 )
             # unpack latents
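
For context, a minimal sketch of where such a mask typically comes from: the T5 tokenizer returns an attention_mask alongside the input ids, with 1 for real tokens and 0 for padding, and the flag gating below mirrors the args.apply_t5_attn_mask check in the diff above. The checkpoint name, token limit, and prompt are illustrative assumptions, not part of this commit.

    # Hypothetical sketch, not from this commit: producing the T5 attention
    # mask that the training loop above unpacks from text_encoder_conds.
    from transformers import T5TokenizerFast

    # FLUX commonly pairs with T5-XXL v1.1; the exact checkpoint is an assumption.
    tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

    tokens = tokenizer(
        ["a photo of an astronaut riding a horse"],
        max_length=512,  # assumed token limit
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    t5_attn_mask = tokens["attention_mask"]  # 1 = real token, 0 = padding

    # Mirrors the gating in the diff: when the flag is off, passing None
    # lets the transformer attend to padding tokens, as before this change.
    apply_t5_attn_mask = False  # stand-in for args.apply_t5_attn_mask
    if not apply_t5_attn_mask:
        t5_attn_mask = None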