fix sd3 training to work without caching TE outputs #1465

This commit is contained in:
Kohya S
2024-08-17 14:38:34 +09:00
parent e45d3f8634
commit 7367584e67

View File

@@ -759,8 +759,9 @@ def train(args):
input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
# TODO support weighted captions
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
# text models in sd3_models require "cpu" for input_ids
input_ids_clip_l = input_ids_clip_l.to("cpu")
input_ids_clip_g = input_ids_clip_g.to("cpu")
lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy,
[clip_l, clip_g, None],
@@ -770,7 +771,7 @@ def train(args):
if t5_out is None:
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.no_grad():
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
_, t5_out, _ = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
)