Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-18 01:30:02 +00:00)
support SD3 LoRA
This commit is contained in:
sd3_train.py: 30 changed lines (6 additions, 24 deletions)
@@ -220,12 +220,7 @@ def train(args):
  sd3_state_dict = None

  # load tokenizer and prepare tokenize strategy
- if args.t5xxl_max_token_length is None:
- t5xxl_max_token_length = 256 # default value for T5XXL
- else:
- t5xxl_max_token_length = args.t5xxl_max_token_length
-
- sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(t5xxl_max_token_length)
+ sd3_tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length)
  strategy_base.TokenizeStrategy.set_strategy(sd3_tokenize_strategy)

  # load clip_l, clip_g, t5xxl for caching text encoder outputs
@@ -876,6 +871,9 @@ def train(args):
  lg_out = None
  t5_out = None
  lg_pooled = None
+ l_attn_mask = None
+ g_attn_mask = None
+ t5_attn_mask = None

  if lg_out is None:
  # not cached or training, so get from text encoders
@@ -885,7 +883,7 @@ def train(args):
  # text models in sd3_models require "cpu" for input_ids
  input_ids_clip_l = input_ids_clip_l.to("cpu")
  input_ids_clip_g = input_ids_clip_g.to("cpu")
- lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
+ lg_out, _, lg_pooled, l_attn_mask, g_attn_mask, _ = text_encoding_strategy.encode_tokens(
  sd3_tokenize_strategy,
  [clip_l, clip_g, None],
  [input_ids_clip_l, input_ids_clip_g, None, l_attn_mask, g_attn_mask, None],
@@ -895,7 +893,7 @@ def train(args):
  _, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
  with torch.no_grad():
  input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
- _, t5_out, _ = text_encoding_strategy.encode_tokens(
+ _, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
  sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
  )

@@ -1104,22 +1102,6 @@ def setup_parser() -> argparse.ArgumentParser:
  parser.add_argument(
  "--use_t5xxl_cache_only", action="store_true", help="cache T5-XXL outputs only / T5-XXLの出力のみキャッシュする"
  )
- parser.add_argument(
- "--t5xxl_max_token_length",
- type=int,
- default=None,
- help="maximum token length for T5-XXL. 256 if omitted / T5-XXLの最大トークン数。省略時は256",
- )
- parser.add_argument(
- "--apply_lg_attn_mask",
- action="store_true",
- help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
- )
- parser.add_argument(
- "--apply_t5_attn_mask",
- action="store_true",
- help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
- )

  parser.add_argument(
  "--learning_rate_te1",
Reference in New Issue
Block a user