diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 1c4e1965..b74bea91 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -107,10 +107,10 @@ def _load_target_model( text_encoder2 = pipe.text_encoder_2 # convert to fp32 for cache text_encoders outputs - #if text_encoder1.dtype != torch.float32: - # text_encoder1 = text_encoder1.to(dtype=torch.float32) - #if text_encoder2.dtype != torch.float32: - # text_encoder2 = text_encoder2.to(dtype=torch.float32) + if text_encoder1.dtype != torch.float32: + text_encoder1 = text_encoder1.to(dtype=torch.float32) + if text_encoder2.dtype != torch.float32: + text_encoder2 = text_encoder2.to(dtype=torch.float32) vae = pipe.vae unet = pipe.unet