diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md new file mode 100644 index 00000000..45491b73 --- /dev/null +++ b/docs/train_lllite_README-ja.md @@ -0,0 +1,55 @@ +# ConrtolNet-LLLite について + +## 概要 +ConrtolNet-LLLite は、[ConrtolNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。 + +## サンプルの重みファイルと推論 + +こちらにあります: https://huggingface.co/kohya-ss/controlnet-lllite + +ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + +生成サンプルはこのページの末尾にあります。 + +## モデル構造 +ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。 + +推論環境の制限で、現在はCrossAttentionのみ(attn1のq/k/v、attn2のq)に追加されます。 + +## モデルの学習 + +### データセットの準備 +通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。 + +```toml +[[datasets.subsets]] +image_dir = "path/to/image/dir" +caption_extension = ".txt" +conditioning_data_dir = "path/to/conditioning/image/dir" +``` + +現時点の制約として、random_cropは使用できません。 + +### 学習 +スクリプトで生成する場合は、`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 + +conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。 + +(サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。) + +### 推論 + + +スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 + +`--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 + +## サンプル +Canny +![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) + +![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) + +![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) + +![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1) diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md new file mode 100644 index 00000000..d3b63ecf --- /dev/null +++ b/docs/train_lllite_README.md @@ -0,0 +1,59 @@ +# About ConrtolNet-LLLite + +## Overview + +ConrtolNet-LLLite is a lightweight version of [ConrtolNet](https://github.com/lllyasviel/ControlNet). It is a "LoRA Like Lite" that is inspired by LoRA and has a lightweight structure. Currently, only SDXL is supported. + +## Sample weight file and inference + +Sample weight file is available here: https://huggingface.co/kohya-ss/controlnet-lllite + +A custom node for ComfyUI is available: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + +Sample images are at the end of this page. + +## Model structure + +A single LLLite module consists of a conditioning image embedding that maps a conditioning image to a latent space and a small network with a structure similar to LoRA. The LLLite module is added to U-Net's Linear and Conv in the same way as LoRA. Please refer to the source code for details. + +Due to the limitations of the inference environment, only CrossAttention (attn1 q/k/v, attn2 q) is currently added. + +## Model training + +### Preparing the dataset + +In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. + +```toml +[[datasets.subsets]] +image_dir = "path/to/image/dir" +caption_extension = ".txt" +conditioning_data_dir = "path/to/conditioning/image/dir" +``` + +At the moment, random_crop cannot be used. + +### Training + +Run `sdxl_train_control_net_lllite.py`. You can specify the dimension of the conditioning image embedding with `--cond_emb_dim`. You can specify the rank of the LoRA-like module with `--network_dim`. Other options are the same as `sdxl_train_network.py`, but `--network_module` is not required. + +For the sample Canny, the dimension of the conditioning image embedding is 32. The rank of the LoRA-like module is also 64. Adjust according to the features of the conditioning image you are targeting. + +(The sample Canny is probably quite difficult. It may be better to reduce it to about half for depth, etc.) + +### Inference + +If you want to generate images with a script, run `sdxl_gen_img.py`. You can specify the LLLite model file with `--control_net_lllite_models`. The dimension is automatically obtained from the model file. + +Specify the conditioning image to be used for inference with `--guide_image_path`. Since preprocess is not performed, if it is Canny, specify an image processed with Canny (white line on black background). `--control_net_preps`, `--control_net_weights`, and `--control_net_ratios` are not supported. + +## Sample + +Canny +![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) + +![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) + +![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) + +![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1) diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 6ea4bc33..586909bd 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -39,6 +39,7 @@ CONTEXT_DIM: int = 2048 MODEL_CHANNELS: int = 320 TIME_EMBED_DIM = 320 * 4 +USE_REENTRANT = True # region memory effcient attention @@ -322,7 +323,7 @@ class ResnetBlock2D(nn.Module): return custom_forward - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb) + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT) else: x = self.forward_body(x, emb) @@ -356,7 +357,9 @@ class Downsample2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT + ) else: hidden_states = self.forward_body(hidden_states) @@ -641,7 +644,9 @@ class BasicTransformerBlock(nn.Module): return custom_forward - output = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, context, timestep) + output = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT + ) else: output = self.forward_body(hidden_states, context, timestep) @@ -782,7 +787,9 @@ class Upsample2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, output_size) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT + ) else: hidden_states = self.forward_body(hidden_states, output_size) diff --git a/library/train_util.py b/library/train_util.py index 0b40e3ed..ff1b4a33 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1743,6 +1743,9 @@ class ControlNetDataset(BaseDataset): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -1767,17 +1770,26 @@ class ControlNetDataset(BaseDataset): cond_img = load_image(image_info.cond_img_path) if self.dreambooth_dataset_delegate.enable_bucket: - cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - ct, cl = crop_top_left + cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + # TODO support random crop + # 現在サポートしているcropはrandomではなく中央のみ h, w = target_size_hw + ct = (cond_img.shape[0] - h) // 2 + cl = (cond_img.shape[1] - w) // 2 cond_img = cond_img[ct : ct + h, cl : cl + w] else: - assert ( - cond_img.shape[0] == self.height and cond_img.shape[1] == self.width - ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + # assert ( + # cond_img.shape[0] == self.height and cond_img.shape[1] == self.width + # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + # resize to target + if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: + cond_img = cv2.resize( + cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 + ) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index bb8dcd6b..51f581b2 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -5,35 +5,41 @@ from safetensors.torch import load_file def main(file): - print(f"loading: {file}") - if os.path.splitext(file)[1] == '.safetensors': - sd = load_file(file) - else: - sd = torch.load(file, map_location='cpu') + print(f"loading: {file}") + if os.path.splitext(file)[1] == ".safetensors": + sd = load_file(file) + else: + sd = torch.load(file, map_location="cpu") - values = [] + values = [] - keys = list(sd.keys()) - for key in keys: - if 'lora_up' in key or 'lora_down' in key: - values.append((key, sd[key])) - print(f"number of LoRA modules: {len(values)}") + keys = list(sd.keys()) + for key in keys: + if "lora_up" in key or "lora_down" in key: + values.append((key, sd[key])) + print(f"number of LoRA modules: {len(values)}") - for key, value in values: - value = value.to(torch.float32) - print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + if args.show_all_keys: + for key in [k for k in keys if k not in values]: + values.append((key, sd[key])) + print(f"number of all modules: {len(values)}") + + for key, value in values: + value = value.to(torch.float32) + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser = argparse.ArgumentParser() + parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() + args = parser.parse_args() - main(args.file) + main(args.file) diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py new file mode 100644 index 00000000..3140919c --- /dev/null +++ b/networks/control_net_lllite.py @@ -0,0 +1,430 @@ +import os +from typing import Optional, List, Type +import torch +from library import sdxl_original_unet + + +# input_blocksに適用するかどうか / if True, input_blocks are not applied +SKIP_INPUT_BLOCKS = False + +# output_blocksに適用するかどうか / if True, output_blocks are not applied +SKIP_OUTPUT_BLOCKS = True + +# conv2dに適用するかどうか / if True, conv2d are not applied +SKIP_CONV2D = False + +# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない +# if True, only transformer_blocks are applied, and ResBlocks are not applied +TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks + +# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. +ATTN1_2_ONLY = True + +# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified +ATTN_QKV_ONLY = True + +# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 +# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY +ATTN1_ETC_ONLY = False # True + +# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 +# max index of transformer_blocks. if None, apply to all transformer_blocks +TRANSFORMER_MAX_BLOCK_INDEX = None + + +class LLLiteModule(torch.nn.Module): + def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None): + super().__init__() + + self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + self.lllite_name = name + self.cond_emb_dim = cond_emb_dim + self.org_module = [org_module] + self.dropout = dropout + + if self.is_conv2d: + in_dim = org_module.in_channels + else: + in_dim = org_module.in_features + + # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない + # conditioning1 embeds conditioning image. it is not called for each timestep + modules = [] + modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size + if depth == 1: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + elif depth == 2: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) + elif depth == 3: + # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + + self.conditioning1 = torch.nn.Sequential(*modules) + + # downで入力の次元数を削減する。LoRAにヒントを得ていることにする + # midでconditioning image embeddingと入力を結合する + # upで元の次元数に戻す + # これらはtimestepごとに呼ばれる + # reduce the number of input dimensions with down. inspired by LoRA + # combine conditioning image embedding and input with mid + # restore to the original dimension with up + # these are called for each timestep + + if self.is_conv2d: + self.down = torch.nn.Sequential( + torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), + ) + else: + # midの前にconditioningをreshapeすること / reshape conditioning before mid + self.down = torch.nn.Sequential( + torch.nn.Linear(in_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.mid = torch.nn.Sequential( + torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), + torch.nn.ReLU(inplace=True), + ) + self.up = torch.nn.Sequential( + torch.nn.Linear(mlp_dim, in_dim), + ) + + # Zero-Convにする / set to Zero-Conv + torch.nn.init.zeros_(self.up[0].weight) # zero conv + + self.depth = depth # 1~3 + self.cond_emb = None + self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference + self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0 + + # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない + # Controlの種類によっては使えるかも + # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice + # it may be available depending on the type of Control + + def set_cond_image(self, cond_image): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ + # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance + # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") + cx = self.conditioning1(cond_image) + if not self.is_conv2d: + # reshape / b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + self.cond_emb = cx + + def set_batch_cond_only(self, cond_only, zeros): + self.batch_cond_only = cond_only + self.use_zeros_for_batch_uncond = zeros + + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + + def forward(self, x): + r""" + 学習用の便利forward。元のモジュールのforwardを呼び出す + / convenient forward for training. call the forward of the original module + """ + cx = self.cond_emb + + if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only + cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) + if self.use_zeros_for_batch_uncond: + cx[0::2] = 0.0 # uncond is zero + # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") + + # downで入力の次元数を削減し、conditioning image embeddingと結合する + # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している + # down reduces the number of input dimensions and combines it with conditioning image embedding + # we expect that it will mix well by combining in the channel direction instead of adding + + cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2) + cx = self.mid(cx) + + if self.dropout is not None and self.training: + cx = torch.nn.functional.dropout(cx, p=self.dropout) + + cx = self.up(cx) + + # residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward + if self.batch_cond_only: + cx = torch.zeros_like(x)[1::2] + cx + + x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here + return x + + +class ControlNetLLLite(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + + def __init__( + self, + unet: sdxl_original_unet.SdxlUNet2DConditionModel, + cond_emb_dim: int = 16, + mlp_dim: int = 16, + dropout: Optional[float] = None, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + # self.unets = [unet] + + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + module_class: Type[object], + ) -> List[torch.nn.Module]: + prefix = "lllite_unet" + + modules = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + + if is_linear or (is_conv2d and not SKIP_CONV2D): + # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う + # block index to depth: depth is using to calculate conditioning size and channels + block_name, index1, index2 = (name + "." + child_name).split(".")[:3] + index1 = int(index1) + if block_name == "input_blocks": + if SKIP_INPUT_BLOCKS: + continue + depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) + elif block_name == "middle_block": + depth = 3 + elif block_name == "output_blocks": + if SKIP_OUTPUT_BLOCKS: + continue + depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1) + if int(index2) >= 2: + depth -= 1 + else: + raise NotImplementedError() + + lllite_name = prefix + "." + name + "." + child_name + lllite_name = lllite_name.replace(".", "_") + + if TRANSFORMER_MAX_BLOCK_INDEX is not None: + p = lllite_name.find("transformer_blocks") + if p >= 0: + tf_index = int(lllite_name[p:].split("_")[2]) + if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: + continue + + # time embは適用外とする + # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない + # time emb is not applied + # attn2 conditioning (input from CLIP) cannot be applied because the shape is different + if "emb_layers" in lllite_name or ( + "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name) + ): + continue + + if ATTN1_2_ONLY: + if not ("attn1" in lllite_name or "attn2" in lllite_name): + continue + if ATTN_QKV_ONLY: + if "to_out" in lllite_name: + continue + + if ATTN1_ETC_ONLY: + if "proj_out" in lllite_name: + pass + elif "attn1" in lllite_name and ( + "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name + ): + pass + elif "ff_net_2" in lllite_name: + pass + else: + continue + + module = module_class( + depth, + cond_emb_dim, + lllite_name, + child_module, + mlp_dim, + dropout=dropout, + ) + modules.append(module) + return modules + + target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE + if not TRANSFORMER_ONLY: + target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + # create module instances + self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) + print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") + + def forward(self, x): + return x # dummy + + def set_cond_image(self, cond_image): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ + for module in self.unet_modules: + module.set_cond_image(cond_image) + + def set_batch_cond_only(self, cond_only, zeros): + for module in self.unet_modules: + module.set_batch_cond_only(cond_only, zeros) + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self): + print("applying LLLite for U-Net...") + for module in self.unet_modules: + module.apply_to() + self.add_module(module.lllite_name, module) + + # マージできるかどうかを返す + def is_mergeable(self): + return False + + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + raise NotImplementedError() + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_optimizer_params(self): + self.requires_grad_(True) + return self.parameters() + + def prepare_grad_etc(self): + self.requires_grad_(True) + + def on_epoch_start(self): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + +if __name__ == "__main__": + # デバッグ用 / for debug + + # sdxl_original_unet.USE_REENTRANT = False + + # test shape etc + print("create unet") + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + unet.to("cuda").to(torch.float16) + + print("create ControlNet-LLLite") + control_net = ControlNetLLLite(unet, 32, 64) + control_net.apply_to() + control_net.to("cuda") + + print(control_net) + + # print number of parameters + print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + + input() + + unet.set_use_memory_efficient_attention(True, False) + unet.set_gradient_checkpointing(True) + unet.train() # for gradient checkpointing + + control_net.train() + + # # visualize + # import torchviz + # print("run visualize") + # controlnet.set_control(conditioning_image) + # output = unet(x, t, ctx, y) + # print("make_dot") + # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) + # print("render") + # image.format = "svg" # "png" + # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time + # input() + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3) + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + print("start training") + steps = 10 + + sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] + for step in range(steps): + print(f"step {step}") + + batch_size = 1 + conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 + x = torch.randn(batch_size, 4, 128, 128).cuda() + t = torch.randint(low=0, high=10, size=(batch_size,)).cuda() + ctx = torch.randn(batch_size, 77, 2048).cuda() + y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + + with torch.cuda.amp.autocast(enabled=True): + control_net.set_cond_image(conditioning_image) + + output = unet(x, t, ctx, y) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + print(sample_param) + + # from safetensors.torch import save_file + + # save_file(control_net.state_dict(), "logs/control_net.safetensors") diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 69c0bd1d..cb16a781 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -47,10 +47,9 @@ import library.train_util as train_util import library.sdxl_model_util as sdxl_model_util import library.sdxl_train_util as sdxl_train_util from networks.lora import LoRANetwork -import tools.original_control_net as original_control_net -from tools.original_control_net import ControlNetInfo from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -327,7 +326,7 @@ class PipelineLike: self.token_replacements_list.append({}) # ControlNet # not supported yet - self.control_nets: List[ControlNetInfo] = [] + self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない # Textual Inversion @@ -392,6 +391,7 @@ class PipelineLike: is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, img2img_noise=None, + clip_guide_images=None, **kwargs, ): # TODO support secondary prompt @@ -496,11 +496,16 @@ class PipelineLike: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) if self.control_nets: + # ControlNetのhintにguide imageを流用する if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) - # ControlNetのhintにguide imageを流用する - # 前処理はControlNet側で行う + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) # create size embs if original_height is None: @@ -654,35 +659,47 @@ class PipelineLike: num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net in self.control_nets: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net in self.control_nets: + control_net.set_cond_image(None) for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - if self.control_nets and self.control_net_enabled: - if reginonal_network: - num_sub_and_neg_prompts = len(text_embeddings) // batch_size - text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - else: - text_emb_last = text_embeddings + # # disable control net if ratio is set + # if self.control_nets and self.control_net_enabled: + # pass # TODO - # not working yet - noise_pred = original_control_net.call_unet_and_control_net( - i, - num_latent_input, - self.unet, - self.control_nets, - guided_hints, - i / len(timesteps), - latent_model_input, - t, - text_emb_last, - ).sample - else: - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + # predict the noise residual + # TODO Diffusers' ControlNet + # if self.control_nets and self.control_net_enabled: + # if reginonal_network: + # num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + # else: + # text_emb_last = text_embeddings + + # # not working yet + # noise_pred = original_control_net.call_unet_and_control_net( + # i, + # num_latent_input, + # self.unet, + # self.control_nets, + # guided_hints, + # i / len(timesteps), + # latent_model_input, + # t, + # text_emb_last, + # ).sample + # else: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) # perform guidance if do_classifier_free_guidance: @@ -1550,16 +1567,40 @@ def main(args): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] - if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + control_nets: List[ControlNetLLLite] = [] + # if args.control_net_models: + # for i, model in enumerate(args.control_net_models): + # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) + # prep = original_control_net.load_preprocess(prep_type) + # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + print(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim) + control_net.apply_to() + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_net.set_batch_cond_only(False, False) + control_nets.append(control_net) if args.opt_channels_last: print(f"set optimizing: channels last") @@ -1572,8 +1613,9 @@ def main(args): network.to(memory_format=torch.channels_last) for cn in control_nets: - cn.unet.to(memory_format=torch.channels_last) - cn.net.to(memory_format=torch.channels_last) + cn.to(memory_format=torch.channels_last) + # cn.unet.to(memory_format=torch.channels_last) + # cn.net.to(memory_format=torch.channels_last) pipe = PipelineLike( device, @@ -2573,20 +2615,23 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - ) - parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" - ) - parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" ) # parser.add_argument( + # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + # ) + # parser.add_argument( + # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + # ) + # parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率") + # parser.add_argument( + # "--control_net_ratios", + # type=float, + # default=None, + # nargs="*", + # help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + # ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py new file mode 100644 index 00000000..09cf1643 --- /dev/null +++ b/sdxl_train_control_net_lllite.py @@ -0,0 +1,572 @@ +import argparse +import gc +import json +import math +import os +import random +import time +from multiprocessing import Value +from types import SimpleNamespace +import toml + +from tqdm import tqdm +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file +from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) +import networks.control_net_lllite as control_net_lllite + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + # prepare ControlNet + network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout) + network.apply_to() + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(network.prepare_optimizer_params()) + print(f"trainable params count: {len(trainable_params)}") + print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + network.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + unet.to(weight_dtype) + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + network: control_net_lllite.ControlNetLLLite + + # transform DDP after prepare (train_network here only) + unet, network = train_util.transform_models_if_DDP([unet, network]) + + if args.gradient_checkpointing: + unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる + else: + unet.eval() + + network.prepare_grad_etc() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite" + + unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + network.on_epoch_start() # train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(network): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.no_grad(): + # Get the text embedding for conditioning + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet + # 内部でcond_embに変換される / it will be converted to cond_emb inside + network.set_cond_image(controlnet_image) + + # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + if is_main_process: + network = accelerator.unwrap_model(network) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + + print("model saved.") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") + parser.add_argument( + "--network_dropout", + type=float, + default=None, + help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args)