diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f78d9424..5176b536 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -84,7 +84,7 @@ def _load_target_model( # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline - variant = "fp16" if weight_dtype == torch.float16 else None + variant = "fp16" if weight_dtype == torch.float16 else "bf16" if weight_dtype == torch.bfloat16 else None logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: