From fd3a14c9d7d277165a06bd60af10a2a6ddb3e006 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Thu, 23 Jan 2025 23:41:49 +0800 Subject: [PATCH] Update sdxl_train_util.py Try to force the text encoders to stay in FP16 --- library/sdxl_train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index b74bea91..1c4e1965 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -107,10 +107,10 @@ def _load_target_model( text_encoder2 = pipe.text_encoder_2 # convert to fp32 for cache text_encoders outputs - if text_encoder1.dtype != torch.float32: - text_encoder1 = text_encoder1.to(dtype=torch.float32) - if text_encoder2.dtype != torch.float32: - text_encoder2 = text_encoder2.to(dtype=torch.float32) + #if text_encoder1.dtype != torch.float32: + # text_encoder1 = text_encoder1.to(dtype=torch.float32) + #if text_encoder2.dtype != torch.float32: + # text_encoder2 = text_encoder2.to(dtype=torch.float32) vae = pipe.vae unet = pipe.unet