diff --git a/README.md b/README.md
index dc8e25ad..5da6181b 100644
--- a/README.md
+++ b/README.md
@@ -249,6 +249,24 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
 ## Change History
 
+### Oct 9. 2023 / 2023/10/9
+
+- `tag_images_by_wd14_tagger.py` now supports ONNX. When ONNX is used, TensorFlow is no longer required. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
+  - An `--onnx` option has been added; specify it to use ONNX.
+  - Please install ONNX and the other required packages:
+    1. Uninstall TensorFlow.
+    1. `pip install tensorboard==2.14.1` This is required by the pinned protobuf version below.
+    1. `pip install protobuf==3.20.3` This is required for ONNX.
+    1. `pip install onnx==1.14.1`
+    1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
+- An `--append_tags` option has been added to `tag_images_by_wd14_tagger.py`. It appends the new tags to the existing tags instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
+- [OFT](https://oft.wyliu.com/) is now supported.
+  - Specify `networks.oft` as the `--network_module` in `sdxl_train_network.py`. The usage is the same as `networks.lora`; some options are not supported.
+  - `sdxl_gen_img.py` also supports OFT as `--network_module`.
+  - OFT currently supports SDXL only, because the current implementation modifies the Q/K/V and O projections in the transformer blocks, and SD1/2 have far fewer transformer blocks than SDXL.
+  - The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
+- Other bug fixes and improvements.
+
 ### Oct 1. 2023 / 2023/10/1
 
 - SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.
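The OFT entry above is terse, so a toy illustration may help. Unlike LoRA's additive low-rank update, OFT multiplies each targeted layer's output by a learned block-diagonal orthogonal matrix obtained from the Cayley transform of a skew-symmetric matrix; this mirrors `get_weight` in the new `networks/oft.py` below. A minimal sketch with toy dimensions (the names here are illustrative, not part of the scripts):

```python
import torch

def cayley_block_diag(blocks: torch.Tensor) -> torch.Tensor:
    """blocks: (num_blocks, b, b) free parameters -> one block-diagonal orthogonal matrix R."""
    Q = blocks - blocks.transpose(1, 2)  # skew-symmetric, so the Cayley transform is orthogonal
    I = torch.eye(blocks.shape[1]).expand_as(Q)
    block_R = torch.matmul(I + Q, torch.linalg.inv(I - Q))  # Cayley transform, per block
    return torch.block_diag(*block_R)

num_blocks, block_size = 4, 8  # an out_dim of 32 split into 4 blocks (network_dim=4)
params = 0.01 * torch.randn(num_blocks, block_size, block_size)
R = cayley_block_diag(params)
print(torch.allclose(R @ R.T, torch.eye(32), atol=1e-5))  # True: R is orthogonal
```

The trainable blocks are zero-initialized in `networks/oft.py`, which makes `R` exactly the identity, so an untrained OFT network leaves the base model's outputs unchanged.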
diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index 91e4f573..965edd7e 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -1,17 +1,15 @@
 import argparse
 import csv
-import glob
 import os
-
-from PIL import Image
-import cv2
-from tqdm import tqdm
-import numpy as np
-from tensorflow.keras.models import load_model
-from huggingface_hub import hf_hub_download
-import torch
 from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+from PIL import Image
+from tqdm import tqdm
+
 import library.train_util as train_util
 
 # from wd14 tagger
@@ -20,6 +18,7 @@ IMAGE_SIZE = 448
 # wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
 DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
 FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
+FILES_ONNX = ["model.onnx"]
 SUB_DIR = "variables"
 SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
 CSV_FILE = FILES[-1]
@@ -81,7 +80,10 @@ def main(args):
     # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
     if not os.path.exists(args.model_dir) or args.force_download:
         print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
-        for file in FILES:
+        files = FILES.copy()  # copy so the ONNX entry does not leak into the global list
+        if args.onnx:
+            files += FILES_ONNX
+        for file in files:
             hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
         for file in SUB_DIR_FILES:
             hf_hub_download(
@@ -96,7 +98,46 @@ def main(args):
         print("using existing wd14 tagger model")
 
     # load the model
-    model = load_model(args.model_dir)
+    if args.onnx:
+        import onnx
+        import onnxruntime as ort
+
+        onnx_path = f"{args.model_dir}/model.onnx"
+        print("Running wd14 tagger with onnx")
+        print(f"loading onnx model: {onnx_path}")
+
+        if not os.path.exists(onnx_path):
+            raise Exception(
+                f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
+                + " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
+            )
+
+        model = onnx.load(onnx_path)
+        input_name = model.graph.input[0].name
+        try:
+            batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
+        except Exception:
+            batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
+
+        if args.batch_size != batch_size and not isinstance(batch_size, str):
+            # some re-batched models may use a symbol such as 'N' as the dynamic batch axis
+            print(
+                f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
+            )
+            args.batch_size = batch_size
+
+        del model
+
+        ort_sess = ort.InferenceSession(
+            onnx_path,
+            providers=["CUDAExecutionProvider"]
+            if "CUDAExecutionProvider" in ort.get_available_providers()
+            else ["CPUExecutionProvider"],
+        )
+    else:
+        from tensorflow.keras.models import load_model
+
+        model = load_model(f"{args.model_dir}")
 
     # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
     # read the CSV ourselves to avoid adding a pandas dependency
@@ -124,8 +165,14 @@ def main(args):
     def run_batch(path_imgs):
        imgs = np.array([im for _, im in path_imgs])
 
-        probs = model(imgs, training=False)
-        probs = probs.numpy()
+        if args.onnx:
+            if len(imgs) < args.batch_size:
+                imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.float32)], axis=0)
+            probs = ort_sess.run(None, {input_name: imgs})[0]  # onnx output is already numpy
+            probs = probs[: len(path_imgs)]
+        else:
+            probs = model(imgs, training=False)
+            probs = probs.numpy()
 
        for (image_path, _), prob in zip(path_imgs, probs):
            # the first 4 entries are ratings, so skip them
@@ -165,9 +212,27 @@ def main(args):
            if len(character_tag_text) > 0:
                character_tag_text = character_tag_text[2:]
 
+            caption_file = os.path.splitext(image_path)[0] + args.caption_extension
+
            tag_text = ", ".join(combined_tags)
 
-            with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
+            if args.append_tags:
+                # Check if the caption file already exists
+                if os.path.exists(caption_file):
+                    with open(caption_file, "rt", encoding="utf-8") as f:
+                        # Read the file and strip trailing newlines
+                        existing_content = f.read().strip("\n")
+
+                    # Split the content into tags and store them in a list
+                    existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
+
+                    # Keep only the tags that are not already present
+                    new_tags = [tag for tag in combined_tags if tag not in existing_tags]
+
+                    # Create the new tag_text
+                    tag_text = ", ".join(existing_tags + new_tags)
+
+            with open(caption_file, "wt", encoding="utf-8") as f:
                f.write(tag_text + "\n")
                if args.debug:
                    print(f"\n{image_path}:\n  Character tags: {character_tag_text}\n  General tags: {general_tag_text}")
@@ -283,12 +348,15 @@ def setup_parser() -> argparse.ArgumentParser:
        help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
    )
    parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
+    parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
+    parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
 
    return parser
 
+
 if __name__ == "__main__":
    parser = setup_parser()
-
+
    args = parser.parse_args()
 
    # restore the misspelled option for backward compatibility
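As context for the change above: the ONNX path builds an `onnxruntime` session preferring CUDA when available, pads the final batch up to the model's fixed batch size, and drops the padded rows from the output. Below is a condensed, self-contained sketch of that flow, assuming a WD14 ONNX model has already been downloaded; the model path is a placeholder, and the real script additionally reads `selected_tags.csv` and applies thresholds:

```python
import numpy as np
import onnxruntime as ort

IMAGE_SIZE = 448  # WD14 taggers expect 448x448 float32 input in NHWC layout

# prefer CUDA when the GPU build of onnxruntime is installed, else fall back to CPU
providers = (
    ["CUDAExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers()
    else ["CPUExecutionProvider"]
)
sess = ort.InferenceSession("wd14_tagger/model.onnx", providers=providers)  # path is illustrative
input_name = sess.get_inputs()[0].name

def run_batch(imgs: np.ndarray, batch_size: int) -> np.ndarray:
    """imgs: (n, 448, 448, 3) float32 with n <= batch_size. Returns per-tag probabilities."""
    n = len(imgs)
    if n < batch_size:  # fixed-batch models need the last batch padded
        pad = np.zeros((batch_size - n, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.float32)
        imgs = np.concatenate([imgs, pad], axis=0)
    probs = sess.run(None, {input_name: imgs})[0]  # onnxruntime returns numpy arrays
    return probs[:n]  # drop the padded rows

# example: one dummy image through a model with a fixed batch size of 4
print(run_batch(np.zeros((1, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.float32), 4).shape)
```

Keeping both the images and the padding in float32 matters here: `np.zeros` defaults to float64, and onnxruntime rejects inputs whose dtype differs from the model's declared input type.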
diff --git a/networks/oft.py b/networks/oft.py
new file mode 100644
index 00000000..1d088f87
--- /dev/null
+++ b/networks/oft.py
@@ -0,0 +1,430 @@
+# OFT network module
+
+import math
+import os
+import re
+from typing import Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np
+import torch
+from diffusers import AutoencoderKL
+from transformers import CLIPTextModel
+
+RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
+
+
+class OFTModule(torch.nn.Module):
+    """
+    Replaces the forward method of the original Linear, instead of replacing the original Linear module itself.
+    """
+
+    def __init__(
+        self,
+        oft_name,
+        org_module: torch.nn.Module,
+        multiplier=1.0,
+        dim=4,
+        alpha=1,
+    ):
+        """
+        dim -> num blocks
+        alpha -> constraint
+        """
+        super().__init__()
+        self.oft_name = oft_name
+
+        self.num_blocks = dim
+
+        if "Linear" in org_module.__class__.__name__:
+            out_dim = org_module.out_features
+        elif "Conv" in org_module.__class__.__name__:
+            out_dim = org_module.out_channels
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().numpy()
+        self.constraint = alpha * out_dim
+        self.register_buffer("alpha", torch.tensor(alpha))
+
+        self.block_size = out_dim // self.num_blocks
+        self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
+
+        self.out_dim = out_dim
+        self.shape = org_module.weight.shape
+
+        self.multiplier = multiplier
+        self.org_module = [org_module]  # wrapped in a list so it is not registered as a submodule
+
+    def apply_to(self):
+        self.org_forward = self.org_module[0].forward
+        self.org_module[0].forward = self.forward
+
+    def get_weight(self, multiplier=None):
+        if multiplier is None:
+            multiplier = self.multiplier
+
+        # skew-symmetric Q, with its norm clamped to the constraint
+        block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
+        norm_Q = torch.norm(block_Q.flatten())
+        new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
+        block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+        I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
+        block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())  # Cayley transform
+
+        # interpolate between the identity and the rotation by the multiplier
+        block_R_weighted = multiplier * block_R + (1 - multiplier) * I
+        R = torch.block_diag(*block_R_weighted)
+
+        return R
+
+    def forward(self, x, scale=None):
+        x = self.org_forward(x)
+        if self.multiplier == 0.0:
+            return x
+
+        R = self.get_weight().to(x.device, dtype=x.dtype)
+        if x.dim() == 4:
+            x = x.permute(0, 2, 3, 1)
+            x = torch.matmul(x, R)
+            x = x.permute(0, 3, 1, 2)
+        else:
+            x = torch.matmul(x, R)
+        return x
+ """ + + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + ): + """ + dim -> num blocks + alpha -> constraint + """ + super().__init__() + self.oft_name = oft_name + + self.num_blocks = dim + + if "Linear" in org_module.__class__.__name__: + out_dim = org_module.out_features + elif "Conv" in org_module.__class__.__name__: + out_dim = org_module.out_channels + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().numpy() + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + self.block_size = out_dim // self.num_blocks + self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) + + self.out_dim = out_dim + self.shape = org_module.weight.shape + + self.multiplier = multiplier + self.org_module = [org_module] # moduleにならないようにlistに入れる + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I + R = torch.block_diag(*block_R_weighted) + + return R + + def forward(self, x, scale=None): + x = self.org_forward(x) + if self.multiplier == 0.0: + return x + + R = self.get_weight().to(x.device, dtype=x.dtype) + if x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = torch.matmul(x, R) + x = x.permute(0, 3, 1, 2) + else: + x = torch.matmul(x, R) + return x + + +class OFTInfModule(OFTModule): + def __init__( + self, + oft_name, + org_module: torch.nn.Module, + multiplier=1.0, + dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(oft_name, org_module, multiplier, dim, alpha) + self.enabled = True + self.network: OFTNetwork = None + + def set_network(self, network): + self.network = network + + def forward(self, x, scale=None): + if not self.enabled: + return self.org_forward(x) + return super().forward(x, scale) + + def merge_to(self, multiplier=None, sign=1): + R = self.get_weight(multiplier) * sign + + # get org weight + org_sd = self.org_module[0].state_dict() + org_weight = org_sd["weight"] + R = R.to(org_weight.device, dtype=org_weight.dtype) + + if org_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", org_weight, R) + else: + weight = torch.einsum("oi, op -> pi", org_weight, R) + + # set weight to org_module + org_sd["weight"] = weight + self.org_module[0].load_state_dict(org_sd) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae: AutoencoderKL, + text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + enable_all_linear = kwargs.get("enable_all_linear", None) + enable_conv = kwargs.get("enable_conv", None) + if enable_all_linear is not None: + enable_all_linear = bool(enable_all_linear) + if enable_conv is not None: + enable_conv = bool(enable_conv) + + network = 
+
+
+# Creates a network from saved weights, for inference; the weights are not loaded here (because they can be merged)
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # check dim, alpha, and whether the weights include conv2d
+    dim = None
+    alpha = None
+    has_conv2d = None
+    all_linear = None
+    for name, param in weights_sd.items():
+        if name.endswith(".alpha"):
+            if alpha is None:
+                alpha = param.item()
+        else:
+            if dim is None:
+                dim = param.size()[0]
+            if has_conv2d is None and param.dim() == 4:
+                has_conv2d = True
+            if all_linear is None:
+                if param.dim() == 3 and "attn" not in name:
+                    all_linear = True
+        if dim is not None and alpha is not None and has_conv2d is not None:
+            break
+    if has_conv2d is None:
+        has_conv2d = False
+    if all_linear is None:
+        all_linear = False
+
+    module_class = OFTInfModule if for_inference else OFTModule
+    network = OFTNetwork(
+        text_encoder,
+        unet,
+        multiplier=multiplier,
+        dim=dim,
+        alpha=alpha,
+        enable_all_linear=all_linear,
+        enable_conv=has_conv2d,
+        module_class=module_class,
+    )
+    return network, weights_sd
+
+
+class OFTNetwork(torch.nn.Module):
+    UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
+    UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
+    OFT_PREFIX_UNET = "oft_unet"  # probably better not to change this
+
+    def __init__(
+        self,
+        text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
+        unet,
+        multiplier: float = 1.0,
+        dim: int = 4,
+        alpha: float = 1,
+        enable_all_linear: Optional[bool] = False,
+        enable_conv: Optional[bool] = False,
+        module_class: Type[object] = OFTModule,
+        verbose: Optional[bool] = False,
+    ) -> None:
+        super().__init__()
+        self.multiplier = multiplier
+
+        self.dim = dim
+        self.alpha = alpha
+
+        print(
+            f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
+        )
+
+        # create module instances
+        def create_modules(
+            root_module: torch.nn.Module,
+            target_replace_modules: List[str],
+        ) -> List[OFTModule]:
+            prefix = self.OFT_PREFIX_UNET
+            ofts = []
+            for name, module in root_module.named_modules():
+                if module.__class__.__name__ in target_replace_modules:
+                    for child_name, child_module in module.named_modules():
+                        is_linear = "Linear" in child_module.__class__.__name__
+                        is_conv2d = "Conv2d" in child_module.__class__.__name__
+                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
+
+                        if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
+                            oft_name = prefix + "." + name + "." + child_name
+                            oft_name = oft_name.replace(".", "_")
+                            # print(oft_name)
+
+                            oft = module_class(
+                                oft_name,
+                                child_module,
+                                self.multiplier,
+                                dim,
+                                alpha,
+                            )
+                            ofts.append(oft)
+            return ofts
+
+        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
+        if enable_all_linear:
+            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
+        else:
+            target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
+        if enable_conv:
+            # build a new list so the class-level list is not mutated
+            target_modules = target_modules + OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+
+        self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
+        print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
+
+        # assert that module names are unique
+        names = set()
+        for oft in self.unet_ofts:
+            assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
+            names.add(oft.oft_name)
+
+    def set_multiplier(self, multiplier):
+        self.multiplier = multiplier
+        for oft in self.unet_ofts:
+            oft.multiplier = self.multiplier
+
+    def load_weights(self, file):
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+        info = self.load_state_dict(weights_sd, False)
+        return info
+
+    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
+        assert apply_unet, "apply_unet must be True"
+
+        for oft in self.unet_ofts:
+            oft.apply_to()
+            self.add_module(oft.oft_name, oft)
+
+    # returns whether the weights can be merged
+    def is_mergeable(self):
+        return True
+
+    # TODO refactor to common function with apply_to
+    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
+        print("enable OFT for U-Net")
+
+        for oft in self.unet_ofts:
+            sd_for_lora = {}
+            for key in weights_sd.keys():
+                if key.startswith(oft.oft_name):
+                    sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
+            oft.load_state_dict(sd_for_lora, False)
+            oft.merge_to()
+
+        print("weights are merged")
+
+    # it may be good to allow separate learning rates for the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+        self.requires_grad_(True)
+        all_params = []
+
+        def enumerate_params(ofts):
+            params = []
+            for oft in ofts:
+                params.extend(oft.parameters())
+
+            # print the number of params
+            num_params = 0
+            for p in params:
+                num_params += p.numel()
+            print(f"OFT params: {num_params}")
+            return params
+
+        param_data = {"params": enumerate_params(self.unet_ofts)}
+        if unet_lr is not None:
+            param_data["lr"] = unet_lr
+        all_params.append(param_data)
+
+        return all_params
+
+    def enable_gradient_checkpointing(self):
+        # not supported
+        pass
+
+    def prepare_grad_etc(self, text_encoder, unet):
+        self.requires_grad_(True)
+
+    def on_epoch_start(self, text_encoder, unet):
+        self.train()
+
+    def get_trainable_params(self):
+        return self.parameters()
+
+    def save_weights(self, file, dtype, metadata):
+        if metadata is not None and len(metadata) == 0:
+            metadata = None
+
+        state_dict = self.state_dict()
+
+        if dtype is not None:
+            for key in list(state_dict.keys()):
+                v = state_dict[key]
+                v = v.detach().clone().to("cpu").to(dtype)
+                state_dict[key] = v
+
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import save_file
+
+            from library import train_util
+
+            # Precalculate model hashes to save time on indexing
+            if metadata is None:
+                metadata = {}
+            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+            metadata["sshs_model_hash"] = model_hash
+            metadata["sshs_legacy_hash"] = legacy_hash
+
+            save_file(state_dict, file, metadata)
+        else:
+            torch.save(state_dict, file)
+
+    def backup_weights(self):
+        # back up the original weights
+        ofts: List[OFTInfModule] = self.unet_ofts
+        for oft in ofts:
+            org_module = oft.org_module[0]
+            if not hasattr(org_module, "_lora_org_weight"):
+                sd = org_module.state_dict()
+                org_module._lora_org_weight = sd["weight"].detach().clone()
+                org_module._lora_restored = True
+
+    def restore_weights(self):
+        # restore the original weights
+        ofts: List[OFTInfModule] = self.unet_ofts
+        for oft in ofts:
+            org_module = oft.org_module[0]
+            if not org_module._lora_restored:
+                sd = org_module.state_dict()
+                sd["weight"] = org_module._lora_org_weight
+                org_module.load_state_dict(sd)
+                org_module._lora_restored = True
+
+    def pre_calculation(self):
+        # merge the weights in advance
+        ofts: List[OFTInfModule] = self.unet_ofts
+        for oft in ofts:
+            org_module = oft.org_module[0]
+            oft.merge_to()
+            # sd = org_module.state_dict()
+            # org_weight = sd["weight"]
+            # lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
+            # sd["weight"] = org_weight + lora_weight
+            # assert sd["weight"].shape == org_weight.shape
+            # org_module.load_state_dict(sd)
+
+            org_module._lora_restored = False
+            oft.enabled = False
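One detail of `OFTInfModule.merge_to` above is worth spelling out: `torch.nn.Linear` computes `x @ W.T`, so applying the rotation after the layer is equivalent to replacing the weight with `R.T @ W`, which is exactly what the einsum `"oi, op -> pi"` produces. A quick self-contained check, where an arbitrary orthogonal `R` stands in for the Cayley-derived one, and a bias-free layer keeps the comparison exact since `merge_to` rewrites only the weight:

```python
import torch

torch.manual_seed(0)
out_dim, in_dim = 8, 6
layer = torch.nn.Linear(in_dim, out_dim, bias=False)

# any orthogonal matrix works for this check; OFT's R comes from the Cayley transform
R, _ = torch.linalg.qr(torch.randn(out_dim, out_dim))

x = torch.randn(2, in_dim)
post_hoc = layer(x) @ R  # rotation applied after the layer, as in OFT's forward

merged = torch.einsum("oi, op -> pi", layer.weight, R)  # same einsum as merge_to
with torch.no_grad():
    layer.weight.copy_(merged)
print(torch.allclose(layer(x), post_hoc, atol=1e-6))  # True: merged layer == layer + rotation
```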
diff --git a/requirements.txt b/requirements.txt
index 4ca393f5..c27131cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,8 +19,14 @@ huggingface-hub==0.15.1
 # requests==2.28.2
 # timm==0.6.12
 # fairscale==0.4.13
-# for WD14 captioning
+# for WD14 captioning (tensorflow)
 # tensorflow==2.10.1
+# for WD14 captioning (onnx)
+# onnx==1.14.1
+# onnxruntime-gpu==1.16.0
+# onnxruntime==1.16.0
+# this is required for onnx:
+# protobuf==3.20.3
 # open clip for SDXL
 open-clip-torch==2.20.0
 # for kohya_ss library