sd3 training

This commit is contained in:
Kohya S
2024-06-23 23:38:20 +09:00
parent a518e3c819
commit d53ea22b2a
8 changed files with 1909 additions and 44 deletions

View File

@@ -320,8 +320,11 @@ if __name__ == "__main__":
# prepare embeddings
logger.info("Encoding prompts...")
# embeds, pooled_embed
cond = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl)
neg_cond = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl)
lg_out, t5_out, pooled = sd3_utils.get_cond(args.prompt, tokenizer, clip_l, clip_g, t5xxl)
cond = torch.cat([lg_out, t5_out], dim=-2), pooled
lg_out, t5_out, pooled = sd3_utils.get_cond(args.negative_prompt, tokenizer, clip_l, clip_g, t5xxl)
neg_cond = torch.cat([lg_out, t5_out], dim=-2), pooled
# generate image
logger.info("Generating image...")