diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index bf6c8907..7948efaa 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1486,15 +1486,20 @@ def main(args): # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" # モデルを読み込む + logger.info("preparing accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + distributed_state = PartialState() + device = distributed_state.device + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う files = glob.glob(args.ckpt) if len(files) == 1: args.ckpt = files[0] - device = get_preferred_device() + #device = get_preferred_device() logger.info(f"preferred device: {device}") - model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) text_encoder1.to(dtype).to(device) @@ -1819,7 +1824,7 @@ def main(args): args.clip_skip, ) pipe.set_control_nets(control_nets) - logger.info("pipeline is ready.") + logger.info(f"pipeline on {device} is ready.") if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention()