mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Support multiple additional networks
This commit is contained in:
@@ -1,38 +1,3 @@
|
|||||||
# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...
|
|
||||||
|
|
||||||
# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
|
|
||||||
# v3: add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue where network_mul was not working
|
|
||||||
# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by Diffusers 0.9.0, no_preview in interactive mode
|
|
||||||
# v5: fix clip_sample=True for scheduler, add VGG guidance
|
|
||||||
# v6: refactor to use model util, load VAE without vae folder, support safe tensors
|
|
||||||
# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size,
|
|
||||||
# Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all schedulers in v-prediction)
|
|
||||||
# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug where the app crashes because a PIL image sometimes doesn't have a filename attr
|
|
||||||
# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1
|
|
||||||
# v10: fix app crashes when different image size in prompts
|
|
||||||
|
|
||||||
# Copyright 2022 kohya_ss @kohya_ss
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# license of included scripts:
|
|
||||||
|
|
||||||
# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
|
||||||
# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
|
||||||
|
|
||||||
# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.):
|
|
||||||
# ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
VGG(
|
VGG(
|
||||||
(features): Sequential(
|
(features): Sequential(
|
||||||
@@ -517,7 +482,7 @@ class PipelineLike():
|
|||||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||||
|
|
||||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||||
def enable_xformers_memory_efficient_attention(self):
|
def enable_xformers_memory_efficient_attention(self):
|
||||||
r"""
|
r"""
|
||||||
Enable memory efficient attention as implemented in xformers.
|
Enable memory efficient attention as implemented in xformers.
|
||||||
@@ -1982,26 +1947,42 @@ def main(args):
|
|||||||
vgg16_model.to(dtype).to(device)
|
vgg16_model.to(dtype).to(device)
|
||||||
|
|
||||||
# networkを組み込む
|
# networkを組み込む
|
||||||
if args.network_module is not None:
|
if args.network_module:
|
||||||
# assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません"
|
networks = []
|
||||||
|
for i, network_module in enumerate(args.network_module):
|
||||||
|
print("import network module:", network_module)
|
||||||
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
print("import network module:", args.network_module)
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
network_module = importlib.import_module(args.network_module)
|
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
|
||||||
|
|
||||||
network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs)
|
net_kwargs = {}
|
||||||
|
if args.network_args and i < len(args.network_args):
|
||||||
|
network_args = args.network_args[i]
|
||||||
|
# TODO escape special chars
|
||||||
|
network_args = network_args.split(";")
|
||||||
|
for net_arg in network_args:
|
||||||
|
key, value = net_arg.split("=")
|
||||||
|
net_kwargs[key] = value
|
||||||
|
|
||||||
|
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
|
||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
print("load network weights from:", args.network_weights)
|
if args.network_weights and i < len(args.network_weights):
|
||||||
network.load_weights(args.network_weights)
|
network_weight = args.network_weights[i]
|
||||||
|
print("load network weights from:", network_weight)
|
||||||
|
network.load_weights(network_weight)
|
||||||
|
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
network.to(memory_format=torch.channels_last)
|
network.to(memory_format=torch.channels_last)
|
||||||
network.to(dtype).to(device)
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
|
networks.append(network)
|
||||||
else:
|
else:
|
||||||
network = None
|
networks = []
|
||||||
|
|
||||||
if args.opt_channels_last:
|
if args.opt_channels_last:
|
||||||
print(f"set optimizing: channels last")
|
print(f"set optimizing: channels last")
|
||||||
@@ -2010,7 +1991,8 @@ def main(args):
|
|||||||
unet.to(memory_format=torch.channels_last)
|
unet.to(memory_format=torch.channels_last)
|
||||||
if clip_model is not None:
|
if clip_model is not None:
|
||||||
clip_model.to(memory_format=torch.channels_last)
|
clip_model.to(memory_format=torch.channels_last)
|
||||||
if network is not None:
|
if networks:
|
||||||
|
for network in networks:
|
||||||
network.to(memory_format=torch.channels_last)
|
network.to(memory_format=torch.channels_last)
|
||||||
if vgg16_model is not None:
|
if vgg16_model is not None:
|
||||||
vgg16_model.to(memory_format=torch.channels_last)
|
vgg16_model.to(memory_format=torch.channels_last)
|
||||||
@@ -2481,19 +2463,24 @@ if __name__ == '__main__':
|
|||||||
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
||||||
parser.add_argument("--seed", type=int, default=None,
|
parser.add_argument("--seed", type=int, default=None,
|
||||||
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
|
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
|
||||||
parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
|
parser.add_argument("--iter_same_seed", action='store_true',
|
||||||
|
help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
|
||||||
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
|
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
|
||||||
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
|
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
|
||||||
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
|
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
|
||||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||||
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
||||||
parser.add_argument("--opt_channels_last", action='store_true',
|
parser.add_argument("--opt_channels_last", action='store_true',
|
||||||
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
|
||||||
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
||||||
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
|
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||||
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||||
parser.add_argument("--network_dim", type=int, default=None,
|
help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||||
|
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||||
|
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
|
||||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||||
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||||
|
|||||||
Reference in New Issue
Block a user