support sdpa

This commit is contained in:
ykume
2023-06-11 21:26:15 +09:00
parent 4d0c06e397
commit 9e1683cf2b
9 changed files with 177 additions and 84 deletions

View File

@@ -160,7 +160,7 @@ def train(args):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 差分追加学習のためにモデルを読み込む
import sys