From a4857fa764effdbdb099fbb6bd54c6d1b46b8238 Mon Sep 17 00:00:00 2001 From: alexds9 Date: Thu, 5 Oct 2023 21:26:09 +0300 Subject: [PATCH 01/10] Add append_captions feature to wd14 tagger This feature allows for appending new tags to the existing content of caption files. If the caption file for an image already exists, the tags generated from the current run are appended to the existing ones. Duplicate tags are checked and avoided. --- finetune/tag_images_by_wd14_tagger.py | 31 ++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 91e4f573..dde586c7 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -165,12 +165,35 @@ def main(args): if len(character_tag_text) > 0: character_tag_text = character_tag_text[2:] + caption_file = os.path.splitext(image_path)[0] + args.caption_extension + tag_text = ", ".join(combined_tags) - with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f: + if args.append_captions: + # Check if file exists + if os.path.exists(caption_file): + + with open(caption_file, "rt", encoding="utf-8") as f: + + # Read file and remove new lines + existing_content = f.read().strip("\n") # Remove trailing comma, whitespace, and newlines + + # Split the content into tags and store them in a list + existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] + + # Check and remove repeating tags in tag_text + tag_text = ", ".join([tag for tag in combined_tags if tag not in existing_tags]) + + # If the file has content, prepend a comma to tag_text + if existing_content.strip() and tag_text: + tag_text = ", ".join(existing_tags) + ", " + tag_text + + + with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + print( + f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -282,7 +305,9 @@ def setup_parser() -> argparse.ArgumentParser: default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) - parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument("--frequency_tags", action="store_true", + help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument("--append_captions", action="store_true", help="Append captions instead of overwriting") return parser From 9378da3c8266c0a87d893a2145196ec6efeb76a0 Mon Sep 17 00:00:00 2001 From: alexds9 Date: Thu, 5 Oct 2023 21:29:46 +0300 Subject: [PATCH 02/10] Fix comment --- finetune/tag_images_by_wd14_tagger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index dde586c7..e2ac5c1d 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -176,7 +176,7 @@ def main(args): with open(caption_file, "rt", encoding="utf-8") as f: # Read file and remove new lines - existing_content = f.read().strip("\n") # Remove trailing comma, whitespace, and newlines + existing_content = f.read().strip("\n") # Remove newlines # Split the content into tags and store them in a list existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] From 70fe7e18bea63bb2ddc3c8dfdb3a2367d55cb348 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 8 Oct 2023 20:31:10 +0800 Subject: [PATCH 03/10] add onnx to wd14 tagger --- finetune/tag_images_by_wd14_tagger.py | 55 +++++++++++++++++++++------ requirements.txt | 4 +- 2 files changed, 47 insertions(+), 12 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 91e4f573..816aaddb 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -2,16 +2,15 @@ import argparse import csv import glob import os - -from PIL import Image -import cv2 -from tqdm import tqdm -import numpy as np -from tensorflow.keras.models import load_model -from huggingface_hub import hf_hub_download -import torch from pathlib import Path +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from tqdm import tqdm + import library.train_util as train_util # from wd14 tagger @@ -81,6 +80,8 @@ def main(args): # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + if args.onnx: + FILES.append("model.onnx") for file in FILES: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) for file in SUB_DIR_FILES: @@ -96,7 +97,35 @@ def main(args): print("using existing wd14 tagger model") # 画像を読み込む - model = load_model(args.model_dir) + if args.onnx: + import onnx + import onnxruntime as ort + + onnx_path = f"{args.model_dir}/model.onnx" + print("Running wd14 tagger with onnx") + print(f"loading onnx model: {onnx_path}") + model = onnx.load(onnx_path) + input_name = model.graph.input[0].name + try: + batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value + except: + batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param + if args.batch_size != batch_size and type(batch_size) != str: + # some rebatch model may use 'N' as dynamic axes + print( + f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" + ) + args.batch_size = batch_size + ort_sess = ort.InferenceSession( + model.SerializeToString(), + providers=["CUDAExecutionProvider"] + if "CUDAExecutionProvider" in ort.get_available_providers() + else ["CPUExecutionProvider"], + ) + else: + from tensorflow.keras.models import load_model + + model = load_model(f"{args.model_dir}") # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv") # 依存ライブラリを増やしたくないので自力で読むよ @@ -124,8 +153,11 @@ def main(args): def run_batch(path_imgs): imgs = np.array([im for _, im in path_imgs]) - probs = model(imgs, training=False) - probs = probs.numpy() + if args.onnx: + probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy + else: + probs = model(imgs, training=False) + probs = probs.numpy() for (image_path, _), prob in zip(path_imgs, probs): # 最初の4つはratingなので無視する @@ -283,6 +315,7 @@ def setup_parser() -> argparse.ArgumentParser: help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument("--onnx", action="store_true", help="use onnx model for inference") return parser diff --git a/requirements.txt b/requirements.txt index 4ca393f5..fa6005ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,8 +19,10 @@ huggingface-hub==0.15.1 # requests==2.28.2 # timm==0.6.12 # fairscale==0.4.13 -# for WD14 captioning +# for WD14 captioning (tensroflow or onnx) # tensorflow==2.10.1 +# onnx==1.14.1 +# onnxruntime==1.16.0 # open clip for SDXL open-clip-torch==2.20.0 # for kohya_ss library From b8b84021e54b34ed04800e21a18fc67e6e9ce1c1 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 8 Oct 2023 20:49:03 +0800 Subject: [PATCH 04/10] fix a typo --- finetune/tag_images_by_wd14_tagger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index 816aaddb..6b33af51 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -117,7 +117,7 @@ def main(args): ) args.batch_size = batch_size ort_sess = ort.InferenceSession( - model.SerializeToString(), + onnx_path, providers=["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"], @@ -154,7 +154,7 @@ def main(args): imgs = np.array([im for _, im in path_imgs]) if args.onnx: - probs = ort_sess.run(None, {input_name: imgs}) # onnx output numpy + probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy else: probs = model(imgs, training=False) probs = probs.numpy() From d6f458fcb3cda470486a9d0ea3a2dad0c72b46db Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sun, 8 Oct 2023 23:51:18 +0800 Subject: [PATCH 05/10] fix dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index fa6005ac..75de48cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ huggingface-hub==0.15.1 # for WD14 captioning (tensroflow or onnx) # tensorflow==2.10.1 # onnx==1.14.1 +# onnxruntime-gpu==1.16.0 # onnxruntime==1.16.0 # open clip for SDXL open-clip-torch==2.20.0 From 0d4e8b50d0ce23437a16d4735f785190a4457af3 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 15:09:54 +0900 Subject: [PATCH 06/10] change option to append_tags, minor update --- finetune/tag_images_by_wd14_tagger.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index e2ac5c1d..31ee93bc 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -169,31 +169,26 @@ def main(args): tag_text = ", ".join(combined_tags) - if args.append_captions: + if args.append_tags: # Check if file exists if os.path.exists(caption_file): - with open(caption_file, "rt", encoding="utf-8") as f: - # Read file and remove new lines existing_content = f.read().strip("\n") # Remove newlines - # Split the content into tags and store them in a list - existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] + # Split the content into tags and store them in a list + existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()] # Check and remove repeating tags in tag_text - tag_text = ", ".join([tag for tag in combined_tags if tag not in existing_tags]) - - # If the file has content, prepend a comma to tag_text - if existing_content.strip() and tag_text: - tag_text = ", ".join(existing_tags) + ", " + tag_text + new_tags = [tag for tag in combined_tags if tag not in existing_tags] + # Create new tag_text + tag_text = ", ".join(existing_tags + new_tags) with open(caption_file, "wt", encoding="utf-8") as f: f.write(tag_text + "\n") if args.debug: - print( - f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") + print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}") # 読み込みの高速化のためにDataLoaderを使うオプション if args.max_data_loader_n_workers is not None: @@ -305,15 +300,15 @@ def setup_parser() -> argparse.ArgumentParser: default="", help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) - parser.add_argument("--frequency_tags", action="store_true", - help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") - parser.add_argument("--append_captions", action="store_true", help="Append captions instead of overwriting") + parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") + parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") return parser + if __name__ == "__main__": parser = setup_parser() - + args = parser.parse_args() # スペルミスしていたオプションを復元する From 406511c333d99286f19e9a5bf2de55bccfd5302b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 17:08:58 +0900 Subject: [PATCH 07/10] add error message if model.onnx doesn't exist --- finetune/tag_images_by_wd14_tagger.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index ffe94e7d..965edd7e 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -1,6 +1,5 @@ import argparse import csv -import glob import os from pathlib import Path @@ -19,6 +18,7 @@ IMAGE_SIZE = 448 # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2 DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"] +FILES_ONNX = ["model.onnx"] SUB_DIR = "variables" SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"] CSV_FILE = FILES[-1] @@ -80,9 +80,10 @@ def main(args): # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22 if not os.path.exists(args.model_dir) or args.force_download: print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}") + files = FILES if args.onnx: - FILES.append("model.onnx") - for file in FILES: + files += FILES_ONNX + for file in files: hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file) for file in SUB_DIR_FILES: hf_hub_download( @@ -104,18 +105,29 @@ def main(args): onnx_path = f"{args.model_dir}/model.onnx" print("Running wd14 tagger with onnx") print(f"loading onnx model: {onnx_path}") + + if not os.path.exists(onnx_path): + raise Exception( + f"onnx model not found: {onnx_path}, please redownload the model with --force_download" + + " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください" + ) + model = onnx.load(onnx_path) input_name = model.graph.input[0].name try: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value except: batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param + if args.batch_size != batch_size and type(batch_size) != str: # some rebatch model may use 'N' as dynamic axes print( f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}" ) args.batch_size = batch_size + + del model + ort_sess = ort.InferenceSession( onnx_path, providers=["CUDAExecutionProvider"] @@ -154,7 +166,10 @@ def main(args): imgs = np.array([im for _, im in path_imgs]) if args.onnx: + if len(imgs) < args.batch_size: + imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0) probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy + probs = probs[: len(path_imgs)] else: probs = model(imgs, training=False) probs = probs.numpy() @@ -333,7 +348,7 @@ def setup_parser() -> argparse.ArgumentParser: help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト", ) parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する") - parser.add_argument("--onnx", action="store_true", help="use onnx model for inference") + parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する") parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する") return parser From 66741c035c0ee443399361b50414e1c1e2e8b23e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 17:59:24 +0900 Subject: [PATCH 08/10] add OFT --- networks/oft.py | 430 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 networks/oft.py diff --git a/networks/oft.py b/networks/oft.py new file mode 100644 index 00000000..ba05885c --- /dev/null +++ b/networks/oft.py @@ -0,0 +1,430 @@ +# OFT network module + +import math +import os +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +import re + + +RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") + + +class OFTModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + ): + """ + dim -> num blocks + alpha -> constrait + """ + super().__init__() + self.oft_name = oft_name + + self.num_blocks = dim + + if "Linear" in org_module.__class__.__name__: + out_dim = org_module.out_features + elif "Conv" in org_module.__class__.__name__: + out_dim = org_module.out_channels + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + self.block_size = out_dim // self.num_blocks + self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) + + self.out_dim = out_dim + self.shape = org_module.weight.shape + + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I + R = torch.block_diag(*block_R_weighted) + + return R + + def forward(self, x, scale=None): + x = self.org_forward(x) + if self.multiplier == 0.0: + return x + + R = self.get_weight().to(x.device, dtype=x.dtype) + if x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = torch.matmul(x, R) + x = x.permute(0, 3, 1, 2) + else: + x = torch.matmul(x, R) + return x + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None, sign=1): + R = self.get_weight(multiplier) * sign + + # get org weight + org_sd = self.org_module[0].state_dict() + org_weight = org_sd["weight"] + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) + else: + weight = torch.einsum("oi, op -> pi", org_weight, R) + + # set weight to org_module + org_sd["weight"] = weight + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + enable_all_linear = kwargs.get("enable_all_linear", None) + enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + if enable_conv is not None: + enable_conv = bool(enable_conv) + + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=network_dim, + alpha=network_alpha, + enable_all_linear=enable_all_linear, + enable_conv=enable_conv, + varbose=True, + ) + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # check dim, alpha and if weights have for conv2d + dim = None + alpha = None + has_conv2d = None + all_linear = None + for name, param in weights_sd.items(): + if name.endswith(".alpha"): + if alpha is None: + alpha = param.item() + else: + if dim is None: + dim = param.size()[0] + if has_conv2d is None and param.dim() == 4: + has_conv2d = True + if all_linear is None: + if param.dim() == 3 and "attn" not in name: + all_linear = True + if dim is not None and alpha is not None and has_conv2d is not None: + break + if has_conv2d is None: + has_conv2d = False + if all_linear is None: + all_linear = False + + module_class = OFTInfModule if for_inference else OFTModule + network = OFTNetwork( + text_encoder, + unet, + multiplier=multiplier, + dim=dim, + alpha=alpha, + enable_all_linear=all_linear, + enable_conv=has_conv2d, + module_class=module_class, + ) + return network, weights_sd + + +class OFTNetwork(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] + UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな + + def __init__( + self, + text_encoder: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + dim: int = 4, + alpha: float = 1, + enable_all_linear: Optional[bool] = False, + enable_conv: Optional[bool] = False, + module_class: Type[object] = OFTModule, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.dim = dim + self.alpha = alpha + + print( + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" + ) + + # create module instances + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[OFTModule]: + prefix = self.OFT_PREFIX_UNET + ofts = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = "Linear" in child_module.__class__.__name__ + is_conv2d = "Conv2d" in child_module.__class__.__name__ + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv): + oft_name = prefix + "." + name + "." + child_name + oft_name = oft_name.replace(".", "_") + # print(oft_name) + + oft = module_class( + oft_name, + child_module, + self.multiplier, + dim, + alpha, + ) + ofts.append(oft) + return ofts + + # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights + if enable_all_linear: + target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + else: + target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY + if enable_conv: + target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules) + print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.") + + # assertion + names = set() + for oft in self.unet_ofts: + assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}" + names.add(oft.oft_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for oft in self.unet_ofts: + oft.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + assert apply_unet, "apply_unet must be True" + + for oft in self.unet_ofts: + oft.apply_to() + self.add_module(oft.oft_name, oft) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + print("enable OFT for U-Net") + + for oft in self.unet_ofts: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(oft.oft_name): + sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key] + oft.load_state_dict(sd_for_lora, False) + oft.merge_to() + + print(f"weights are merged") + + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(ofts): + params = [] + for oft in ofts: + params.extend(oft.parameters()) + + # print num of params + num_params = 0 + for p in params: + num_params += p.numel() + print(f"OFT params: {num_params}") + return params + + param_data = {"params": enumerate_params(self.unet_ofts)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + ofts: List[OFTInfModule] = self.unet_ofts + for oft in ofts: + org_module = oft.org_module[0] + oft.merge_to() + # sd = org_module.state_dict() + # org_weight = sd["weight"] + # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype) + # sd["weight"] = org_weight + lora_weight + # assert sd["weight"].shape == org_weight.shape + # org_module.load_state_dict(sd) + + org_module._lora_restored = False + oft.enabled = False From cf49e912fc24a83d7bfd5b10c2831fce88756f90 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 17:59:31 +0900 Subject: [PATCH 09/10] update readme --- README.md | 34 ++++++++++++++++++++++++++++++++++ requirements.txt | 5 ++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index dc8e25ad..974aeaea 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,40 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum ## Change History +### Oct 9. 2023 / 2023/10/9 + +- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py! + - `--onnx` option is added. If you use Onnx, specify `--onnx` option. + - Please install Onnx and other required packages. + 1. Uninstall TensorFlow. + 1. `pip install tensorboard==2.14.1` This is required for the specified version of protobuf. + 1. `pip install protobuf==3.20.3` This is required for Onnx. + 1. `pip install onnx==1.14.1` + 1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0` +- `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9! +- [OFT](https://oft.wyliu.com/) is now supported. + - You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`. Some options are not supported. + - `sdxl_gen_img.py` also supports OFT as `--network_module`. + - OFT only supports SDXL currently. Because current OFT tweaks Q/K/V and O in the transformer, and SD1/2 have extremely fewer transformers than SDXL. + - The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf! +- Other bug fixes and improvements. + +- `tag_images_by_wd_14_tagger.py` が Onnx をサポートしました。Onnx を使用する場合は TensorFlow は不要です。[#864](https://github.com/kohya-ss/sd-scripts/pull/864) Isotr0py氏に感謝します。 + - Onnxを使用する場合は、`--onnx` オプションを指定してください。 + - Onnx とその他の必要なパッケージをインストールしてください。 + 1. TensorFlow をアンインストールしてください。 + 1. `pip install tensorboard==2.14.1` protobufの指定バージョンにこれが必要。 + 1. `pip install protobuf==3.20.3` Onnxのために必要。 + 1. `pip install onnx==1.14.1` + 1. `pip install onnxruntime-gpu==1.16.0` または `pip install onnxruntime==1.16.0` +- `tag_images_by_wd_14_tagger.py` に `--append_tags` オプションが追加されました。このオプションを指定すると、既存のタグに上書きするのではなく、新しいタグのみが既存のタグに追加されます。 [#858](https://github.com/kohya-ss/sd-scripts/pull/858) a-l-e-x-d-s-9氏に感謝します。 +- [OFT](https://oft.wyliu.com/) をサポートしました。 + - `sdxl_train_network.py` の`--network_module`に `networks.oft` を指定してください。使用方法は `networks.lora` と同様ですが一部のオプションは未サポートです。 + - `sdxl_gen_img.py` でも同様に OFT を指定できます。 + - OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transfomer の数が SDXL よりも極端に少ないためです。 + - 実装は laksjdjf 氏の [OFT実装](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py) を多くの部分で参考にしています。laksjdjf 氏に感謝します。 +- その他のバグ修正と改善。 + ### Oct 1. 2023 / 2023/10/1 - SDXL training is now available in the main branch. The sdxl branch is merged into the main branch. diff --git a/requirements.txt b/requirements.txt index 75de48cb..c27131cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,11 +19,14 @@ huggingface-hub==0.15.1 # requests==2.28.2 # timm==0.6.12 # fairscale==0.4.13 -# for WD14 captioning (tensroflow or onnx) +# for WD14 captioning (tensorflow) # tensorflow==2.10.1 +# for WD14 captioning (onnx) # onnx==1.14.1 # onnxruntime-gpu==1.16.0 # onnxruntime==1.16.0 +# this is for onnx: +# protobuf==3.20.3 # open clip for SDXL open-clip-torch==2.20.0 # for kohya_ss library From 8b79e3b06c1f18d353c37706667de3224bca4f1c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 9 Oct 2023 18:00:45 +0900 Subject: [PATCH 10/10] fix typos --- README.md | 2 +- networks/oft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 974aeaea..5da6181b 100644 --- a/README.md +++ b/README.md @@ -279,7 +279,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum - [OFT](https://oft.wyliu.com/) をサポートしました。 - `sdxl_train_network.py` の`--network_module`に `networks.oft` を指定してください。使用方法は `networks.lora` と同様ですが一部のオプションは未サポートです。 - `sdxl_gen_img.py` でも同様に OFT を指定できます。 - - OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transfomer の数が SDXL よりも極端に少ないためです。 + - OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transformer の数が SDXL よりも極端に少ないためです。 - 実装は laksjdjf 氏の [OFT実装](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py) を多くの部分で参考にしています。laksjdjf 氏に感謝します。 - その他のバグ修正と改善。 diff --git a/networks/oft.py b/networks/oft.py index ba05885c..1d088f87 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -28,7 +28,7 @@ class OFTModule(torch.nn.Module): ): """ dim -> num blocks - alpha -> constrait + alpha -> constraint """ super().__init__() self.oft_name = oft_name