make to work with PyTorch 1.12

This commit is contained in:
Kohya S
2023-07-20 21:41:16 +09:00
parent 86a8cbd002
commit acf16c063a
6 changed files with 13 additions and 8 deletions

View File

@@ -104,7 +104,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
else:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
vae.set_use_memory_efficient_attention_xformers(args.xformers)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()