From 50544b78055b2b0f71a12d2af68081cfff581b9c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 27 Jul 2023 23:16:58 +0800 Subject: [PATCH] fix pipeline dtype --- library/sdxl_train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index f37cadab..ecd2db96 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -74,7 +74,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}") try: try: - pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None) + pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None) except EnvironmentError as ex: if variant is not None: print("try to load fp32 model")