fix pipeline dtype

This commit is contained in:
Isotr0py
2023-07-27 23:16:58 +08:00
parent eec6aaddda
commit 50544b7805

View File

@@ -74,7 +74,7 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtyp
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None)
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None)
except EnvironmentError as ex:
if variant is not None:
print("try to load fp32 model")