mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
Update sdxl_gen_img.py
This commit is contained in:
@@ -1490,9 +1490,9 @@ def main(args):
|
||||
files = glob.glob(args.ckpt)
|
||||
if len(files) == 1:
|
||||
args.ckpt = files[0]
|
||||
gc.collect()
|
||||
device = get_preferred_device()
|
||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype, device
|
||||
)
|
||||
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
|
||||
|
||||
@@ -1622,7 +1622,7 @@ def main(args):
|
||||
# scheduler.config.clip_sample = True
|
||||
|
||||
# deviceを決定する
|
||||
device = get_preferred_device()
|
||||
|
||||
|
||||
# custom pipelineをコピったやつを生成する
|
||||
if args.vae_slices:
|
||||
@@ -1649,16 +1649,15 @@ def main(args):
|
||||
if args.no_half_vae:
|
||||
logger.info("set vae_dtype to float32")
|
||||
vae_dtype = torch.float32
|
||||
vae.to(vae_dtype).to(device)
|
||||
#vae.to(vae_dtype).to(device)
|
||||
vae.eval()
|
||||
|
||||
text_encoder1.to(dtype).to(device)
|
||||
text_encoder2.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
#text_encoder1.to(dtype).to(device)
|
||||
#text_encoder2.to(dtype).to(device)
|
||||
#unet.to(dtype).to(device)
|
||||
text_encoder1.eval()
|
||||
text_encoder2.eval()
|
||||
unet.eval()
|
||||
gc.collect()
|
||||
# networkを組み込む
|
||||
if args.network_module:
|
||||
networks = []
|
||||
|
||||
Reference in New Issue
Block a user