From b8d3c687df25c735e157e9451bc384bc08e6fc51 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Wed, 12 Feb 2025 11:43:05 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d6af05ad..586a59ea 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -1513,11 +1513,12 @@ def main(args): logger.info(f"preferred device: {device}, {distributed_state.is_main_process}") clean_memory_on_device(device) model_dtype = sdxl_train_util.match_mixed_precision(args, dtype) - - logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype - ) + for pi in range(distributed_state.state.num_processes): + if pi == distributed_state.state.local_process_index: + logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}") + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype + ) unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet) distributed_state.wait_for_everyone()