mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support attn mask for l+g/t5
This commit is contained in:
@@ -146,6 +146,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--clip_l", type=str, required=False)
|
||||
parser.add_argument("--t5xxl", type=str, required=False)
|
||||
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
|
||||
parser.add_argument("--apply_lg_attn_mask", action="store_true")
|
||||
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
|
||||
parser.add_argument("--negative_prompt", type=str, default="")
|
||||
@@ -323,15 +325,15 @@ if __name__ == "__main__":
|
||||
logger.info("Encoding prompts...")
|
||||
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(args.prompt)
|
||||
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
||||
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
|
||||
)
|
||||
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
|
||||
tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt)
|
||||
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
||||
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
|
||||
)
|
||||
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user