Update sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-01-24 00:30:06 +08:00
committed by GitHub
parent 8daa8b3283
commit 6231883ef6

View File

@@ -1491,11 +1491,14 @@ def main(args):
if len(files) == 1:
args.ckpt = files[0]
device = get_preferred_device()
logger.info(f"preferred device: {device}")
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device
)
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
mem_eff = not (args.xformers or args.sdpa)
@@ -1649,12 +1652,9 @@ def main(args):
if args.no_half_vae:
logger.info("set vae_dtype to float32")
vae_dtype = torch.float32
#vae.to(vae_dtype).to(device)
vae.to(vae_dtype).to(device)
vae.eval()
#text_encoder1.to(dtype).to(device)
#text_encoder2.to(dtype).to(device)
#unet.to(dtype).to(device)
text_encoder1.eval()
text_encoder2.eval()
unet.eval()