From d65f46c297cf4c7d56cf12b5668ef3a595ca2b53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=95=B7=E5=B6=8B=E5=A4=A7=E5=9C=B0?= Date: Thu, 14 Mar 2024 21:10:20 +0900 Subject: [PATCH] =?UTF-8?q?diffusers=E5=BD=A2=E5=BC=8F=E3=81=AEbf16?= =?UTF-8?q?=E3=81=AE=E3=83=A2=E3=83=87=E3=83=AB=E3=82=92=E3=83=AD=E3=83=BC?= =?UTF-8?q?=E3=83=89=E3=81=A7=E3=81=8D=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB?= =?UTF-8?q?=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/sdxl_train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 1932bf88..05ba7f32 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -81,7 +81,7 @@ def _load_target_model( # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline - variant = "fp16" if weight_dtype == torch.float16 else None + variant = "fp16" if weight_dtype == torch.float16 else "bf16" if weight_dtype == torch.bfloat16 else None logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: