From 696dd7f66802091f973ef0ce5ffa1e8002e90789 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 22 Jan 2024 12:43:37 +0900 Subject: [PATCH] Fix dtype issue in PyTorch 2.0 for generating samples in training sdxl network --- library/sdxl_lpw_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 0562e88a..03b18256 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -925,7 +925,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: unet_dtype = self.unet.dtype dtype = unet_dtype - if dtype.itemsize == 1: # fp8 + if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8 dtype = torch.float16 self.unet.to(dtype)