support diffusers format for SDXL

This commit is contained in:
Kohya S
2023-07-12 21:57:14 +09:00
parent 8df948565a
commit 8fa5fb2816
4 changed files with 290 additions and 23 deletions

View File

@@ -171,7 +171,7 @@ def train(args):
# set_diffusers_xformers_flag(unet, True)
set_diffusers_xformers_flag(vae, True)
else:
# Windows版のxformersはfloatで学習できなかったりxformersを使わない設定も可能にしておく必要がある
# Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある
accelerator.print("Disable Diffusers' xformers")
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
vae.set_use_memory_efficient_attention_xformers(args.xformers)
@@ -271,7 +271,7 @@ def train(args):
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"