fix error on pool_workaround in sdxl TE training ref #994

This commit is contained in:
Kohya S
2023-12-10 09:18:33 +09:00
parent 912dca8f65
commit 42750f7846
4 changed files with 30 additions and 25 deletions

View File

@@ -505,6 +505,7 @@ def train(args):
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
@@ -514,6 +515,7 @@ def train(args):
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)