mirror of https://github.com/kohya-ss/sd-scripts.git
Update T5 attention mask handling in FLUX
@@ -610,7 +610,10 @@ def train(args):
             guidance_vec = torch.full((bsz,), args.guidance_scale, device=accelerator.device)
 
             # call model
-            l_pooled, t5_out, txt_ids = text_encoder_conds
+            l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
+            if not args.apply_t5_attn_mask:
+                t5_attn_mask = None
+
             with accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = flux(
@@ -621,6 +624,7 @@ def train(args):
                 y=l_pooled,
                 timesteps=timesteps / 1000,
                 guidance=guidance_vec,
+                txt_attention_mask=t5_attn_mask,
             )
 
             # unpack latents
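For context: text_encoder_conds now carries a fourth element, the T5 attention (padding) mask, and the training loop forwards it to the model only when --apply_t5_attn_mask is set; otherwise it is replaced with None. Below is a minimal sketch of how such a mask is typically produced alongside the T5 encoder output, assuming a Hugging Face T5 tokenizer and encoder (the names, checkpoint, and 512-token length are illustrative assumptions, not the script's actual code):

    from transformers import T5EncoderModel, T5TokenizerFast

    # Hypothetical sketch: tokenize prompts with padding and keep the
    # attention mask so padded positions can later be excluded from
    # attention. Checkpoint and max length are assumptions.
    tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")

    tokens = tokenizer(
        ["a photo of a cat"],
        padding="max_length",
        max_length=512,  # assumed sequence length, for illustration only
        truncation=True,
        return_tensors="pt",
    )
    t5_attn_mask = tokens["attention_mask"]  # (bsz, seq_len); 1 = real token, 0 = padding
    t5_out = text_encoder(tokens["input_ids"], attention_mask=t5_attn_mask).last_hidden_state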
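Inside the model, a padding mask like this is typically converted into a key mask over the joint text+image token sequence before attention, so queries never attend to padded text positions; when t5_attn_mask is None, attention runs unmasked. A rough sketch under the assumption that text tokens precede image tokens in the joint sequence (the function and its arguments are hypothetical, not FLUX's actual internals):

    import torch
    import torch.nn.functional as F

    def attention_with_txt_mask(q, k, v, txt_attention_mask, img_len):
        # Hypothetical sketch: q/k/v have shape (bsz, heads, txt_len + img_len,
        # head_dim), with text tokens first. Padded text positions (mask == 0)
        # are removed from the attendable keys; image tokens always attend.
        bsz = txt_attention_mask.shape[0]
        img_mask = torch.ones(bsz, img_len, dtype=torch.bool, device=q.device)
        key_mask = torch.cat([txt_attention_mask.bool(), img_mask], dim=1)  # (bsz, L)
        # Broadcast to (bsz, 1, 1, L): every query attends only to unmasked keys.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask[:, None, None, :])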