From 6d10233a53a133ca69b24a5bc6afac8fcd132756 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 4 Jan 2023 08:32:22 +0900 Subject: [PATCH 1/2] Support multiple additional networks --- gen_img_diffusers.py | 107 +++++++++++++++++++------------------------ 1 file changed, 47 insertions(+), 60 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 75f14afa..208b1b70 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -1,38 +1,3 @@ -# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc... - -# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix -# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working -# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode -# v5: fix clip_sample=True for scheduler, add VGG guidance -# v6: refactor to use model util, load VAE without vae folder, support safe tensors -# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size, -# Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction) -# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes, -# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1 -# v10: fix app crashes when different image size in prompts - -# Copyright 2022 kohya_ss @kohya_ss -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# license of included scripts: - -# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.): -# ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE - """ VGG( (features): Sequential( @@ -517,7 +482,7 @@ class PipelineLike(): self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers) self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD) -# region xformersとか使う部分:独自に書き換えるので関係なし + # region xformersとか使う部分:独自に書き換えるので関係なし def enable_xformers_memory_efficient_attention(self): r""" Enable memory efficient attention as implemented in xformers. @@ -1982,26 +1947,42 @@ def main(args): vgg16_model.to(dtype).to(device) # networkを組み込む - if args.network_module is not None: - # assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません" + if args.network_module: + networks = [] + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) - print("import network module:", args.network_module) - network_module = importlib.import_module(args.network_module) + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i] - network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs) - if network is None: - return + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value - print("load network weights from:", args.network_weights) - network.load_weights(args.network_weights) + network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs) + if network is None: + return - network.apply_to(text_encoder, unet) + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + network.load_weights(network_weight) - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) + network.apply_to(text_encoder, unet) + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + networks.append(network) else: - network = None + networks = [] if args.opt_channels_last: print(f"set optimizing: channels last") @@ -2010,8 +1991,9 @@ def main(args): unet.to(memory_format=torch.channels_last) if clip_model is not None: clip_model.to(memory_format=torch.channels_last) - if network is not None: - network.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) if vgg16_model is not None: vgg16_model.to(memory_format=torch.channels_last) @@ -2053,7 +2035,7 @@ def main(args): print(f"convert image to RGB from {image.mode}: {p}") image = image.convert("RGB") images.append(image) - + return images def resize_images(imgs, size): @@ -2481,19 +2463,24 @@ if __name__ == '__main__': # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") parser.add_argument("--seed", type=int, default=None, help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed") - parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)') + parser.add_argument("--iter_same_seed", action='store_true', + help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)') parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する') parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') + help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') parser.add_argument("--opt_channels_last", action='store_true', - help='set channels last option to model / モデルにchannels lastを指定し最適化する') - parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') - parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') - parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') - parser.add_argument("--network_dim", type=int, default=None, + help='set channels last option to model / モデルにchannles lastを指定し最適化する') + parser.add_argument("--network_module", type=str, default=None, nargs='*', + help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') + parser.add_argument("--network_weights", type=str, default=None, nargs='*', + help='Hypernetwork weights to load / Hypernetworkの重み') + parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') + parser.add_argument("--network_dim", type=int, default=None, nargs='*', help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') + parser.add_argument("--network_args", type=str, default=None, nargs='*', + help='additional argmuments for network (key=value) / ネットワークへの追加の引数') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--max_embeddings_multiples", type=int, default=None, help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる') From 4dd22f4dc86f948012f196736419d2afc48a0c24 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 6 Jan 2023 21:36:01 +0900 Subject: [PATCH 2/2] add script to approximate diff of two models --- networks/extract_lora_from_models.py | 158 +++++++++++++++++++++++++++ train_network_README-ja.md | 42 ++++++- 2 files changed, 195 insertions(+), 5 deletions(-) create mode 100644 networks/extract_lora_from_models.py diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py new file mode 100644 index 00000000..c882e88f --- /dev/null +++ b/networks/extract_lora_from_models.py @@ -0,0 +1,158 @@ +# extract approximating LoRA by svd from two SD models +# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py +# Thanks to cloneofsimo! + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file +from tqdm import tqdm +import library.model_util as model_util +import lora + + +CLAMP_QUANTILE = 0.99 +MIN_DIFF = 1e-6 + + +def save_to_file(file_name, model, state_dict, dtype): + if dtype is not None: + for key in list(state_dict.keys()): + if type(state_dict[key]) == torch.Tensor: + state_dict[key] = state_dict[key].to(dtype) + + if os.path.splitext(file_name)[1] == '.safetensors': + save_file(model, file_name) + else: + torch.save(model, file_name) + + +def svd(args): + def str_to_dtype(p): + if p == 'float': + return torch.float + if p == 'fp16': + return torch.float16 + if p == 'bf16': + return torch.bfloat16 + return None + + save_dtype = str_to_dtype(args.save_precision) + + print(f"loading SD model : {args.model_org}") + text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org) + print(f"loading SD model : {args.model_tuned}") + text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) + + # create LoRA network to extract weights + lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o) + lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t) + assert len(lora_network_o.text_encoder_loras) == len( + lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) " + + # get diffs + diffs = {} + text_encoder_different = False + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + + # Text Encoder might be same + if torch.max(torch.abs(diff)) > MIN_DIFF: + text_encoder_different = True + + diff = diff.float() + diffs[lora_name] = diff + + if not text_encoder_different: + print("Text encoder is same. Extract U-Net only.") + lora_network_o.text_encoder_loras = [] + diffs = {} + + for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)): + lora_name = lora_o.lora_name + module_o = lora_o.org_module + module_t = lora_t.org_module + diff = module_t.weight - module_o.weight + diff = diff.float() + + if args.device: + diff = diff.to(args.device) + + diffs[lora_name] = diff + + # make LoRA with svd + print("calculating by svd") + rank = args.dim + lora_weights = {} + with torch.no_grad(): + for lora_name, mat in tqdm(list(diffs.items())): + conv2d = (len(mat.size()) == 4) + if conv2d: + mat = mat.squeeze() + + U, S, Vh = torch.linalg.svd(mat) + + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + lora_weights[lora_name] = (U, Vh) + + # make state dict for LoRA + lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict + lora_sd = lora_network_o.state_dict() + print(f"LoRA has {len(lora_sd)} weights.") + + for key in list(lora_sd.keys()): + lora_name = key.split('.')[0] + i = 0 if "lora_up" in key else 1 + + weights = lora_weights[lora_name][i] + # print(key, i, weights.size(), lora_sd[key].size()) + if len(lora_sd[key].size()) == 4: + weights = weights.unsqueeze(2).unsqueeze(3) + + assert weights.size() == lora_sd[key].size() + lora_sd[key] = weights + + # load state dict to LoRA and save it + info = lora_network_o.load_state_dict(lora_sd) + print(f"Loading extracted LoRA weights: {info}") + + dir_name = os.path.dirname(args.save_to) + if dir_name and not os.path.exists(dir_name): + os.makedirs(dir_name, exist_ok=True) + + lora_network_o.save_weights(args.save_to, save_dtype) + print(f"LoRA weights are saved to: {args.save_to}") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む') + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat") + parser.add_argument("--model_org", type=str, default=None, + help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors") + parser.add_argument("--model_tuned", type=str, default=None, + help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors") + parser.add_argument("--save_to", type=str, default=None, + help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") + parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)") + parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う") + + args = parser.parse_args() + svd(args) diff --git a/train_network_README-ja.md b/train_network_README-ja.md index bba4293d..77ef4c17 100644 --- a/train_network_README-ja.md +++ b/train_network_README-ja.md @@ -10,9 +10,7 @@ cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。 -WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルに、このリポジトリ内のスクリプトであらかじめマージしておく必要があります。マージ後のモデルファイルはLoRAの学習結果が反映されたものになります。 - -なお当リポジトリ内の画像生成スクリプトで生成する場合はマージ不要です。 +WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。 ## 学習方法 @@ -24,7 +22,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正 ### DreamBoothの手法を用いる場合 -note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594) を参照してデータを用意してください。 +[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。 学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。 @@ -110,7 +108,7 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt ### 複数のLoRAのモデルをマージする -結局のところSDモデルにマージしないと推論できないのであまり使い道はないかもしれません。ただ、複数のLoRAモデルをひとつずつSDモデルにマージしていく場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。 +複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。 たとえば以下のようなコマンドラインになります。 @@ -144,6 +142,40 @@ gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim --network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。 +## 二つのモデルの差分からLoRAモデルを作成する + +[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。 + +二つのモデル(たとえばfine tuningの元モデルとfine tuning後のモデル)の差分を、LoRAで近似します。 + +### スクリプトの実行方法 + +以下のように指定してください。 +``` +python networks\extract_lora_from_models.py --model_org base-model.ckpt + --model_tuned fine-tuned-model.ckpt + --save_to lora-weights.safetensors --dim 4 +``` + +--model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。 + +--model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。 + +--save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。 + +生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。 + +Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。 + +### その他のオプション + +- --v2 + - v2.xのStable Diffusionモデルを使う場合に指定してください。 +- --device + - ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。 +- --save_precision + - LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。 + ## 追加情報 ### cloneofsimo氏のリポジトリとの違い