Support multiple additional networks

Kohya S
2023-01-04 08:32:22 +09:00
parent bda0e8333c
commit 6d10233a53

@@ -1,38 +1,3 @@
-# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...
-# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
-# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working
-# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode
-# v5: fix clip_sample=True for scheduler, add VGG guidance
-# v6: refactor to use model util, load VAE without vae folder, support safe tensors
-# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size,
-#     Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction)
-# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes,
-# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1
-# v10: fix app crashes when different image size in prompts
-# Copyright 2022 kohya_ss @kohya_ss
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# license of included scripts:
-# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-#     MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.):
-#     ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
 """
 VGG(
   (features): Sequential(
@@ -1982,26 +1947,42 @@ def main(args):
     vgg16_model.to(dtype).to(device)

   # networkを組み込む
-  if args.network_module is not None:
-    # assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません"
-    print("import network module:", args.network_module)
-    network_module = importlib.import_module(args.network_module)
-    network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs)
+  if args.network_module:
+    networks = []
+    for i, network_module in enumerate(args.network_module):
+      print("import network module:", network_module)
+      imported_module = importlib.import_module(network_module)
+
+      network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+      network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
+
+      net_kwargs = {}
+      if args.network_args and i < len(args.network_args):
+        network_args = args.network_args[i]
+        # TODO escape special chars
+        network_args = network_args.split(";")
+        for net_arg in network_args:
+          key, value = net_arg.split("=")
+          net_kwargs[key] = value
+
+      network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
       if network is None:
         return
-    print("load network weights from:", args.network_weights)
-    network.load_weights(args.network_weights)
+      if args.network_weights and i < len(args.network_weights):
+        network_weight = args.network_weights[i]
+        print("load network weights from:", network_weight)
+        network.load_weights(network_weight)
       network.apply_to(text_encoder, unet)
       if args.opt_channels_last:
         network.to(memory_format=torch.channels_last)
       network.to(dtype).to(device)
+      networks.append(network)
   else:
-    network = None
+    networks = []

   if args.opt_channels_last:
     print(f"set optimizing: channels last")
@@ -2010,7 +1991,8 @@ def main(args):
   unet.to(memory_format=torch.channels_last)
   if clip_model is not None:
     clip_model.to(memory_format=torch.channels_last)
-  if network is not None:
-    network.to(memory_format=torch.channels_last)
+  if networks:
+    for network in networks:
+      network.to(memory_format=torch.channels_last)
   if vgg16_model is not None:
     vgg16_model.to(memory_format=torch.channels_last)
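
Because there can now be several networks, the channels-last optimization is applied to each one in a loop. As a small self-contained illustration of the same call (the Conv2d modules below are placeholders, not the script's networks), channels_last stores 4D tensors in NHWC order, which can speed up convolutions on some hardware:

import torch
import torch.nn as nn

networks = [nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3)]  # placeholder modules
if networks:
    for network in networks:
        # convert each module's parameters and buffers to channels-last (NHWC) memory format
        network.to(memory_format=torch.channels_last)
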
@@ -2481,19 +2463,24 @@ if __name__ == '__main__':
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
parser.add_argument("--seed", type=int, default=None, parser.add_argument("--seed", type=int, default=None,
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed") help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使うプロンプト間の差異の比較用') parser.add_argument("--iter_same_seed", action='store_true',
help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使うプロンプト間の差異の比較用')
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する') parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
parser.add_argument("--diffusers_xformers", action='store_true', parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用するHypernetwork利用不可') help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用するHypernetwork利用不可')
parser.add_argument("--opt_channels_last", action='store_true', parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannels lastを指定し最適化する') help='set channels last option to model / モデルにchannles lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') parser.add_argument("--network_module", type=str, default=None, nargs='*',
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_weights", type=str, default=None, nargs='*',
parser.add_argument("--network_dim", type=int, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
parser.add_argument("--max_embeddings_multiples", type=int, default=None, parser.add_argument("--max_embeddings_multiples", type=int, default=None,
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
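
With nargs='*', the network options above each accept a list, and main() pairs the lists up by index, falling back to defaults when a list is shorter than --network_module. A hedged, self-contained sketch of how such a command line parses (module and file names are made up for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--network_module", type=str, default=None, nargs='*')
parser.add_argument("--network_weights", type=str, default=None, nargs='*')
parser.add_argument("--network_mul", type=float, default=None, nargs='*')

# hypothetical invocation with two networks; only the first gets an explicit multiplier
args = parser.parse_args(
    "--network_module networks.lora networks.lora "
    "--network_weights style.safetensors char.safetensors "
    "--network_mul 0.8".split())

for i, module in enumerate(args.network_module):
    mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
    weight = None if args.network_weights is None or len(args.network_weights) <= i else args.network_weights[i]
    print(module, weight, mul)
# networks.lora style.safetensors 0.8
# networks.lora char.safetensors 1.0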