Add LoRA training/generating.

This commit is contained in:
Kohya S
2022-12-25 21:34:59 +09:00
parent 96d695dd83
commit 445b34de1f
4 changed files with 1832 additions and 31 deletions

View File

@@ -1,7 +1,7 @@
# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...
# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue hypernetwork_mul is not working
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working
# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode
# v5: fix clip_sample=True for scheduler, add VGG guidance
# v6: refactor to use model util, load VAE without vae folder, support safe tensors
@@ -333,7 +333,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
def replace_unet_cross_attn_to_memory_efficient():
print("Replace CrossAttention.forward to use Hypernetwork and FlashAttention")
print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, x, context=None, mask=None):
@@ -373,7 +373,7 @@ def replace_unet_cross_attn_to_memory_efficient():
def replace_unet_cross_attn_to_xformers():
print("Replace CrossAttention.forward to use Hypernetwork and xformers")
print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
try:
import xformers.ops
except ImportError:
@@ -1867,25 +1867,6 @@ def main(args):
if not args.diffusers_xformers:
replace_unet_modules(unet, not args.xformers, args.xformers)
# hypernetworkを組み込む
if args.hypernetwork_module is not None:
assert not args.diffusers_xformers, "cannot use hypernetwork with diffusers_xformers / diffusers_xformers指定時はHypernetworkは利用できません"
print("import hypernetwork module:", args.hypernetwork_module)
hyp_module = importlib.import_module(args.hypernetwork_module)
hypernetwork = hyp_module.Hypernetwork(args.hypernetwork_mul)
print("load hypernetwork weights from:", args.hypernetwork_weights)
hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu')
success = hypernetwork.load_from_state_dict(hyp_sd)
assert success, "hypernetwork weights loading failed."
if args.opt_channels_last:
hypernetwork.to(memory_format=torch.channels_last)
else:
hypernetwork = None
# tokenizerを読み込む
print("loading tokenizer")
if use_stable_diffusion_format:
@@ -2000,10 +1981,27 @@ def main(args):
if vgg16_model is not None:
vgg16_model.to(dtype).to(device)
if hypernetwork is not None:
hypernetwork.to(dtype).to(device)
print("apply hypernetwork")
hypernetwork.apply_to_diffusers(vae, text_encoder, unet)
# networkを組み込む
if args.network_module is not None:
# assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません"
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs)
if network is None:
return
print("load network weights from:", args.network_weights)
network.load_weights(args.network_weights)
network.apply_to(text_encoder, unet)
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
else:
network = None
if args.opt_channels_last:
print(f"set optimizing: channels last")
@@ -2012,8 +2010,8 @@ def main(args):
unet.to(memory_format=torch.channels_last)
if clip_model is not None:
clip_model.to(memory_format=torch.channels_last)
if hypernetwork is not None:
hypernetwork.to(memory_format=torch.channels_last)
if network is not None:
network.to(memory_format=torch.channels_last)
if vgg16_model is not None:
vgg16_model.to(memory_format=torch.channels_last)
@@ -2491,9 +2489,11 @@ if __name__ == '__main__':
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用するHypernetwork利用不可')
parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
parser.add_argument("--hypernetwork_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--hypernetwork_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--hypernetwork_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None,
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')