Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-02-12 11:43:05 +08:00
committed by GitHub
parent 586e89a3ab
commit b8d3c687df

View File

@@ -1513,11 +1513,12 @@ def main(args):
logger.info(f"preferred device: {device}, {distributed_state.is_main_process}")
clean_memory_on_device(device)
model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}")
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype
)
for pi in range(distributed_state.state.num_processes):
if pi == distributed_state.state.local_process_index:
logger.info(f"loading model for process {distributed_state.local_process_index+1}/{distributed_state.num_processes}")
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype
)
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
distributed_state.wait_for_everyone()