Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-01-25 01:20:10 +08:00
committed by GitHub
parent f1fc65e61d
commit 8a3548a380

View File

@@ -1487,9 +1487,9 @@ def main(args):
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
# モデルを読み込む
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
logger.info("preparing pipes")
#accelerator = train_util.prepare_accelerator(args)
#is_main_process = accelerator.is_main_process
distributed_state = PartialState()
device = distributed_state.device
@@ -1499,13 +1499,12 @@ def main(args):
args.ckpt = files[0]
#device = get_preferred_device()
logger.info(f"preferred device: {device}")
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
model_dtype = sdxl_train_util.match_mixed_precision(args, dtype)
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device if args.lowram else "cpu", model_dtype
)
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
mem_eff = not (args.xformers or args.sdpa)
@@ -1662,6 +1661,9 @@ def main(args):
vae.to(vae_dtype).to(device)
vae.eval()
text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
text_encoder1.eval()
text_encoder2.eval()
unet.eval()
@@ -2861,9 +2863,9 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = sdxl_train_network.setup_parser()
parser = argparse.ArgumentParser()
#sdxl_train_util.add_sdxl_training_arguments(parser)
#add_logging_arguments(parser)
add_logging_arguments(parser)
parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト")
parser.add_argument(