fix error on pool_workaround in sdxl TE training ref #994

This commit is contained in:
Kohya S
2023-12-10 09:18:33 +09:00
parent 912dca8f65
commit 42750f7846
4 changed files with 30 additions and 25 deletions

View File

@@ -1,9 +1,12 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -123,6 +126,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)