Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions

View File

@@ -24,7 +24,7 @@ import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils
from library import sd3_models, sd3_utils, strategy_sd3
def get_noise(seed, latent):
@@ -145,6 +145,7 @@ if __name__ == "__main__":
parser.add_argument("--clip_g", type=str, required=False)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
@@ -247,7 +248,7 @@ if __name__ == "__main__":
# load tokenizers
logger.info("Loading tokenizers...")
tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
# load models
# logger.info("Create MMDiT from SD3 checkpoint...")
@@ -320,12 +321,19 @@ if __name__ == "__main__":
# prepare embeddings
logger.info("Encoding prompts...")
# embeds, pooled_embed
lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl)
cond = torch.cat([lg_out, t5_out], dim=-2), pooled
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl)
neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# generate image
logger.info("Generating image...")