add FLUX.1 LoRA training

This commit is contained in:
Kohya S
2024-08-09 22:56:48 +09:00
parent da4d0fe016
commit 36b2e6fc28
10 changed files with 2992 additions and 55 deletions

View File

@@ -52,6 +52,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def get_tokenize_strategy(self, args):