mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Refactor caching mechanism for latents and text encoder outputs, etc.
This commit is contained in:
@@ -24,7 +24,7 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import sd3_models, sd3_utils
|
||||
from library import sd3_models, sd3_utils, strategy_sd3
|
||||
|
||||
|
||||
def get_noise(seed, latent):
|
||||
@@ -145,6 +145,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--clip_g", type=str, required=False)
|
||||
parser.add_argument("--clip_l", type=str, required=False)
|
||||
parser.add_argument("--t5xxl", type=str, required=False)
|
||||
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
|
||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
||||
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
|
||||
parser.add_argument("--negative_prompt", type=str, default="")
|
||||
@@ -247,7 +248,7 @@ if __name__ == "__main__":
|
||||
|
||||
# load tokenizers
|
||||
logger.info("Loading tokenizers...")
|
||||
tokenizer = sd3_models.SD3Tokenizer(use_t5xxl) # combined tokenizer
|
||||
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
|
||||
|
||||
# load models
|
||||
# logger.info("Create MMDiT from SD3 checkpoint...")
|
||||
@@ -320,12 +321,19 @@ if __name__ == "__main__":
|
||||
|
||||
# prepare embeddings
|
||||
logger.info("Encoding prompts...")
|
||||
# embeds, pooled_embed
|
||||
lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl)
|
||||
cond = torch.cat([lg_out, t5_out], dim=-2), pooled
|
||||
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
||||
|
||||
lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl)
|
||||
neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.prompt)
|
||||
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
||||
)
|
||||
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(args.negative_prompt)
|
||||
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [clip_l, clip_g, t5xxl], [l_tokens, g_tokens, t5_tokens]
|
||||
)
|
||||
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
# generate image
|
||||
logger.info("Generating image...")
|
||||
|
||||
Reference in New Issue
Block a user