Support attention masks for the CLIP-L/G and T5-XXL text encoders

This commit is contained in:
Kohya S
2024-08-05 20:51:34 +09:00
parent 231df197dd
commit da4d0fe016
4 changed files with 107 additions and 24 deletions

View File

@@ -146,6 +146,8 @@ if __name__ == "__main__":
     parser.add_argument("--clip_l", type=str, required=False)
     parser.add_argument("--t5xxl", type=str, required=False)
     parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
+    parser.add_argument("--apply_lg_attn_mask", action="store_true")
+    parser.add_argument("--apply_t5_attn_mask", action="store_true")
     parser.add_argument("--prompt", type=str, default="A photo of a cat")
     # parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
     parser.add_argument("--negative_prompt", type=str, default="")
@@ -323,15 +325,15 @@ if __name__ == "__main__":
     logger.info("Encoding prompts...")
     encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
-    l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
+    tokens_and_masks = tokenize_strategy.tokenize(args.prompt)
     lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
-        tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
+        tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
     )
     cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
-    l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
+    tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt)
     lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
-        tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
+        tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
     )
     neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)