Update sdxl_train_util.py

Try force stay FP16 on text encoders
This commit is contained in:
DKnight54
2025-01-23 23:41:49 +08:00
committed by GitHub
parent 6e3c1d0b58
commit fd3a14c9d7

View File

@@ -107,10 +107,10 @@ def _load_target_model(
text_encoder2 = pipe.text_encoder_2
# convert to fp32 for cache text_encoders outputs
if text_encoder1.dtype != torch.float32:
text_encoder1 = text_encoder1.to(dtype=torch.float32)
if text_encoder2.dtype != torch.float32:
text_encoder2 = text_encoder2.to(dtype=torch.float32)
#if text_encoder1.dtype != torch.float32:
# text_encoder1 = text_encoder1.to(dtype=torch.float32)
#if text_encoder2.dtype != torch.float32:
# text_encoder2 = text_encoder2.to(dtype=torch.float32)
vae = pipe.vae
unet = pipe.unet