mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 08:21:46 +00:00
fix sd3 training to work without caching TE outputs #1465
This commit is contained in:
@@ -759,8 +759,9 @@ def train(args):
|
||||
input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"]
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# TODO support weighted captions
|
||||
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
|
||||
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
|
||||
# text models in sd3_models require "cpu" for input_ids
|
||||
input_ids_clip_l = input_ids_clip_l.to("cpu")
|
||||
input_ids_clip_g = input_ids_clip_g.to("cpu")
|
||||
lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
|
||||
sd3_tokenize_strategy,
|
||||
[clip_l, clip_g, None],
|
||||
@@ -770,7 +771,7 @@ def train(args):
|
||||
if t5_out is None:
|
||||
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
|
||||
with torch.no_grad():
|
||||
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
|
||||
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
|
||||
_, t5_out, _ = text_encoding_strategy.encode_tokens(
|
||||
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user