From 3bb80ebf20e8d1cc2f2d29789d258f8ca75976e0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Jul 2023 19:02:34 +0900 Subject: [PATCH] fix sampling gen fails in lora training --- library/sdxl_lpw_stable_diffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 99b0bc8d..a65a1d96 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -1005,6 +1005,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: # predict the noise residual noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) + noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance if do_classifier_free_guidance: