mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add LoRA training/generating.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...
|
||||
|
||||
# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
|
||||
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue hypernetwork_mul is not working
|
||||
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working
|
||||
# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode
|
||||
# v5: fix clip_sample=True for scheduler, add VGG guidance
|
||||
# v6: refactor to use model util, load VAE without vae folder, support safe tensors
|
||||
@@ -333,7 +333,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
print("Replace CrossAttention.forward to use Hypernetwork and FlashAttention")
|
||||
print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, x, context=None, mask=None):
|
||||
@@ -373,7 +373,7 @@ def replace_unet_cross_attn_to_memory_efficient():
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("Replace CrossAttention.forward to use Hypernetwork and xformers")
|
||||
print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
@@ -1867,25 +1867,6 @@ def main(args):
|
||||
if not args.diffusers_xformers:
|
||||
replace_unet_modules(unet, not args.xformers, args.xformers)
|
||||
|
||||
# hypernetworkを組み込む
|
||||
if args.hypernetwork_module is not None:
|
||||
assert not args.diffusers_xformers, "cannot use hypernetwork with diffusers_xformers / diffusers_xformers指定時はHypernetworkは利用できません"
|
||||
|
||||
print("import hypernetwork module:", args.hypernetwork_module)
|
||||
hyp_module = importlib.import_module(args.hypernetwork_module)
|
||||
|
||||
hypernetwork = hyp_module.Hypernetwork(args.hypernetwork_mul)
|
||||
|
||||
print("load hypernetwork weights from:", args.hypernetwork_weights)
|
||||
hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu')
|
||||
success = hypernetwork.load_from_state_dict(hyp_sd)
|
||||
assert success, "hypernetwork weights loading failed."
|
||||
|
||||
if args.opt_channels_last:
|
||||
hypernetwork.to(memory_format=torch.channels_last)
|
||||
else:
|
||||
hypernetwork = None
|
||||
|
||||
# tokenizerを読み込む
|
||||
print("loading tokenizer")
|
||||
if use_stable_diffusion_format:
|
||||
@@ -2000,10 +1981,27 @@ def main(args):
|
||||
if vgg16_model is not None:
|
||||
vgg16_model.to(dtype).to(device)
|
||||
|
||||
if hypernetwork is not None:
|
||||
hypernetwork.to(dtype).to(device)
|
||||
print("apply hypernetwork")
|
||||
hypernetwork.apply_to_diffusers(vae, text_encoder, unet)
|
||||
# networkを組み込む
|
||||
if args.network_module is not None:
|
||||
# assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません"
|
||||
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
print("load network weights from:", args.network_weights)
|
||||
network.load_weights(args.network_weights)
|
||||
|
||||
network.apply_to(text_encoder, unet)
|
||||
|
||||
if args.opt_channels_last:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
network.to(dtype).to(device)
|
||||
else:
|
||||
network = None
|
||||
|
||||
if args.opt_channels_last:
|
||||
print(f"set optimizing: channels last")
|
||||
@@ -2012,8 +2010,8 @@ def main(args):
|
||||
unet.to(memory_format=torch.channels_last)
|
||||
if clip_model is not None:
|
||||
clip_model.to(memory_format=torch.channels_last)
|
||||
if hypernetwork is not None:
|
||||
hypernetwork.to(memory_format=torch.channels_last)
|
||||
if network is not None:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
if vgg16_model is not None:
|
||||
vgg16_model.to(memory_format=torch.channels_last)
|
||||
|
||||
@@ -2491,9 +2489,11 @@ if __name__ == '__main__':
|
||||
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
||||
parser.add_argument("--opt_channels_last", action='store_true',
|
||||
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
|
||||
parser.add_argument("--hypernetwork_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||
parser.add_argument("--hypernetwork_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||
parser.add_argument("--hypernetwork_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
parser.add_argument("--network_dim", type=int, default=None,
|
||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||
|
||||
Reference in New Issue
Block a user