Fix dtype issue in PyTorch 2.0 for generating samples in training sdxl network

This commit is contained in:
Kohya S
2024-01-22 12:43:37 +09:00
parent e0a3c69223
commit 696dd7f668

View File

@@ -925,7 +925,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
unet_dtype = self.unet.dtype
dtype = unet_dtype
if dtype.itemsize == 1: # fp8
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)