support sdpa

This commit is contained in:
ykume
2023-06-11 21:26:15 +09:00
parent 4d0c06e397
commit 9e1683cf2b
9 changed files with 177 additions and 84 deletions

View File

@@ -141,7 +141,7 @@ def train(args):
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
print("Disable Diffusers' xformers")
set_diffusers_xformers_flag(unet, False)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents: