Update accel_sdxl_gen_img.py

draft accel image gen
This commit is contained in:
DKnight54
2025-01-24 17:39:01 +08:00
committed by GitHub
parent 8f85024917
commit 8b241c4c8b

View File

@@ -1486,15 +1486,20 @@ def main(args):
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
# モデルを読み込む
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
distributed_state = PartialState()
device = distributed_state.device
if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
files = glob.glob(args.ckpt)
if len(files) == 1:
args.ckpt = files[0]
device = get_preferred_device()
#device = get_preferred_device()
logger.info(f"preferred device: {device}")
model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device, model_dtype
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
)
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
text_encoder1.to(dtype).to(device)
@@ -1819,7 +1824,7 @@ def main(args):
args.clip_skip,
)
pipe.set_control_nets(control_nets)
logger.info("pipeline is ready.")
logger.info(f"pipeline on {device} is ready.")
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()