diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 75f14afa..208b1b70 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1,38 +1,3 @@ -# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc... - -# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix -# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working -# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode -# v5: fix clip_sample=True for scheduler, add VGG guidance -# v6: refactor to use model util, load VAE without vae folder, support safe tensors -# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size, -# Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction) -# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes, -# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1 -# v10: fix app crashes when different image size in prompts - -# Copyright 2022 kohya_ss @kohya_ss -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# license of included scripts: - -# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.): -# ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE - """ VGG( (features): Sequential( @@ -517,7 +482,7 @@ class PipelineLike(): self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) -# region xformersとか使う部分:独自に書き換えるので関係なし + # region xformersとか使う部分:独自に書き換えるので関係なし def enable_xformers_memory_efficient_attention(self): r""" Enable memory efficient attention as implemented in xformers. @@ -1982,26 +1947,42 @@ def main(args): vgg16_model.to(dtype).to(device) # networkを組み込む - if args.network_module is not None: - # assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません" + if args.network_module: + networks = [] + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i] - network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs) - if network is None: - return + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value - print("load network weights from:", args.network_weights) - network.load_weights(args.network_weights) + network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs) + if network is None: + return - network.apply_to(text_encoder, unet) + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + network.load_weights(network_weight) - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) + network.apply_to(text_encoder, unet) + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) else: - network = None + networks = [] if args.opt_channels_last: print(f"set optimizing: channels last") @@ -2010,8 +1991,9 @@ def main(args): unet.to(memory_format=torch.channels_last) if clip_model is not None: clip_model.to(memory_format=torch.channels_last) - if network is not None: - network.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) if vgg16_model is not None: vgg16_model.to(memory_format=torch.channels_last) @@ -2053,7 +2035,7 @@ def main(args): print(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) - + return images def resize_images(imgs, size): @@ -2481,19 +2463,24 @@ if __name__ == '__main__': # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") parser.add_argument("--seed", type=int, default=None, help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed") - parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)') + parser.add_argument("--iter_same_seed", action='store_true', + help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)') parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する') parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') + help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') parser.add_argument("--opt_channels_last", action='store_true', - help='set channels last option to model / モデルにchannels lastを指定し最適化する') - parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') - parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') - parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') - parser.add_argument("--network_dim", type=int, default=None, + help='set channels last option to model / モデルにchannles lastを指定し最適化する') + parser.add_argument("--network_module", type=str, default=None, nargs='*', + help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') + parser.add_argument("--network_weights", type=str, default=None, nargs='*', + help='Hypernetwork weights to load / Hypernetworkの重み') + parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') + parser.add_argument("--network_dim", type=int, default=None, nargs='*', help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') + parser.add_argument("--network_args", type=str, default=None, nargs='*', + help='additional argmuments for network (key=value) / ネットワークへの追加の引数') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--max_embeddings_multiples", type=int, default=None, help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')