diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index ba853bf2..f325ecd6 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -1491,11 +1491,14 @@ def main(args):
         if len(files) == 1:
             args.ckpt = files[0]
     device = get_preferred_device()
+    logger.info(f"preferred device: {device}")
     (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
         args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device
     )
     unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
-
+    text_encoder1.to(dtype).to(device)
+    text_encoder2.to(dtype).to(device)
+    unet.to(dtype).to(device)
     # xformers、Hypernetwork対応
     if not args.diffusers_xformers:
         mem_eff = not (args.xformers or args.sdpa)
@@ -1649,12 +1652,9 @@ def main(args):
     if args.no_half_vae:
         logger.info("set vae_dtype to float32")
         vae_dtype = torch.float32
-    #vae.to(vae_dtype).to(device)
+    vae.to(vae_dtype).to(device)
     vae.eval()
 
-    #text_encoder1.to(dtype).to(device)
-    #text_encoder2.to(dtype).to(device)
-    #unet.to(dtype).to(device)
     text_encoder1.eval()
     text_encoder2.eval()
     unet.eval()