From 8a3548a380afe34e0495e5164bb17b6dfc849dcd Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 25 Jan 2025 01:20:10 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 7ed35620..60c1b6e6 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1487,9 +1487,9 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" # モデルを読み込む - logger.info("preparing accelerator") - accelerator = train_util.prepare_accelerator(args) - is_main_process = accelerator.is_main_process + logger.info("preparing pipes") + #accelerator = train_util.prepare_accelerator(args) + #is_main_process = accelerator.is_main_process distributed_state = PartialState() device = distributed_state.device @@ -1499,13 +1499,12 @@ def main(args): args.ckpt = files[0] #device = get_preferred_device() logger.info(f"preferred device: {device}") - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util.load_target_model( - args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) - text_encoder1.to(dtype).to(device) - text_encoder2.to(dtype).to(device) - unet.to(dtype).to(device) + # xformers、Hypernetwork対応 if not args.diffusers_xformers: mem_eff = not (args.xformers or args.sdpa) @@ -1662,6 +1661,9 @@ def main(args): vae.to(vae_dtype).to(device) vae.eval() + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) text_encoder1.eval() text_encoder2.eval() unet.eval() @@ -2861,9 +2863,9 @@ def main(args): def setup_parser() -> argparse.ArgumentParser: - parser = sdxl_train_network.setup_parser() + parser = argparse.ArgumentParser() #sdxl_train_util.add_sdxl_training_arguments(parser) - #add_logging_arguments(parser) + add_logging_arguments(parser) parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") parser.add_argument(