support SD3 LoRA

kohya-ss
2024-10-25 21:58:31 +09:00
parent f52fb66e8f
commit d2c549d7b2
7 changed files with 1334 additions and 67 deletions


@@ -220,12 +220,7 @@ def train(args):
     sd3_state_dict = None

     # load tokenizer and prepare tokenize strategy
-    if args.t5xxl_max_token_length is None:
-        t5xxl_max_token_length = 256  # default value for T5XXL
-    else:
-        t5xxl_max_token_length = args.t5xxl_max_token_length
-    sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(t5xxl_max_token_length)
+    sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
     strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)

     # load clip_l, clip_g, t5xxl for caching text encoder outputs
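Note: the inline if/else default is dropped because Sd3TokenizeStrategy now accepts None directly. A minimal sketch of the constructor behavior this implies (the in-class 256 fallback is an assumption, not shown in this diff):

    class Sd3TokenizeStrategy:
        def __init__(self, t5xxl_max_length=None):
            # assumed: the 256-token default previously hardcoded in train()
            # now lives inside the strategy itself
            self.t5xxl_max_length = 256 if t5xxl_max_length is None else t5xxl_max_length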
@@ -876,6 +871,9 @@ def train(args):
             lg_out = None
             t5_out = None
             lg_pooled = None
+            l_attn_mask = None
+            g_attn_mask = None
+            t5_attn_mask = None

             if lg_out is None:
                 # not cached or training, so get from text encoders
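Note: seeding the three masks with None keeps the cached-embeddings path well-defined: when lg_out/t5_out come from the cache, no mask tensors exist, so later code can simply test for None. A sketch of that downstream guard (the zero-embedding application is inferred from the --apply_*_attn_mask help text removed below; shapes are assumptions):

    # hypothetical consumer of the masks initialized above: zero out
    # embeddings at padded positions only when a mask is actually present
    if t5_attn_mask is not None:
        t5_out = t5_out * t5_attn_mask.to(t5_out.dtype).unsqueeze(-1)  # (B, L, D) * (B, L, 1)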
@@ -885,7 +883,7 @@ def train(args):
                 # text models in sd3_models require "cpu" for input_ids
                 input_ids_clip_l = input_ids_clip_l.to("cpu")
                 input_ids_clip_g = input_ids_clip_g.to("cpu")
-                lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
+                lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens(
                     sd3_tokenize_strategy,
                     [clip_l, clip_g, None],
                     [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None],
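Note: encode_tokens now returns a 6-tuple rather than a 3-tuple. This call keeps only the CLIP-L/G masks and discards the T5 slot with _; the T5-only call in the next hunk does the reverse. A sketch of a strategy honoring that contract, with the actual encoding stubbed out (ordering inferred from the two call sites; the class name is a placeholder):

    class Sd3TextEncodingStrategySketch:
        # inferred contract: slots whose encoder is passed as None come back
        # as None, so every caller unpacks the same six names and keeps only
        # what it asked for
        def encode_tokens(self, tokenize_strategy, models, tokens_and_masks):
            clip_l, clip_g, t5xxl = models
            lg_out = lg_pooled = l_attn_mask = g_attn_mask = None
            t5_out = t5_attn_mask = None
            if clip_l is not None and clip_g is not None:
                ...  # run CLIP-L/G here, filling lg_out, lg_pooled and both masks
            if t5xxl is not None:
                ...  # run T5-XXL here, filling t5_out and t5_attn_mask
            return lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask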
@@ -895,7 +893,7 @@ def train(args):
             _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
             with torch.no_grad():
                 input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
-                _, t5_out, _ = text_encoding_strategy.encode_tokens(
+                _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
                     sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
                 )
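Note: with both encoder paths unified, the outputs feed the MMDiT as usual. Not part of this hunk, but for orientation, the standard SD3 conditioning assembly looks roughly like this (dimensions per the SD3 design: CLIP-L 768 + CLIP-G 1280 = 2048, zero-padded to T5's 4096; the helper name is hypothetical):

    import torch
    import torch.nn.functional as F

    def build_sd3_cond(lg_out, t5_out, lg_pooled):
        # pad the concatenated CLIP features (B, Llg, 2048) up to the
        # T5 channel width (4096), then join along the sequence axis
        lg_out = F.pad(lg_out, (0, t5_out.shape[-1] - lg_out.shape[-1]))
        context = torch.cat([lg_out, t5_out], dim=-2)  # (B, Llg + Lt5, 4096)
        return context, lg_pooled                      # lg_pooled: (B, 2048) vector cond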
@@ -1104,22 +1102,6 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする"
     )
-    parser.add_argument(
-        "--t5xxl_max_token_length",
-        type=int,
-        default=None,
-        help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
-    )
-    parser.add_argument(
-        "--apply_lg_attn_mask",
-        action="store_true",
-        help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスクゼロ埋めを適用する",
-    )
-    parser.add_argument(
-        "--apply_t5_attn_mask",
-        action="store_true",
-        help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する",
-    )
     parser.add_argument(
         "--learning_rate_te1",