diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index e559e718..8f15e2c4 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -83,7 +83,7 @@ def _load_target_model( # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline - variant = "fp16" if weight_dtype == torch.float16 else None + variant = "fp16" if weight_dtype == torch.float16 else "bf16" if weight_dtype == torch.bfloat16 else None logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: