From fef7eb73ad77566759fc74276ffb9fc8d1cafeec Mon Sep 17 00:00:00 2001 From: ykume Date: Sat, 19 Aug 2023 18:44:40 +0900 Subject: [PATCH] rename and update --- ...a_control_net.py => control_net_lllite.py} | 302 +- sdxl_gen_img.py | 145 +- sdxl_gen_img_lora_ctrl_test.py | 2612 ----------------- ...net.py => sdxl_train_control_net_lllite.py | 18 +- 4 files changed, 253 insertions(+), 2824 deletions(-) rename networks/{lora_control_net.py => control_net_lllite.py} (54%) delete mode 100644 sdxl_gen_img_lora_ctrl_test.py rename sdxl_train_lora_control_net.py => sdxl_train_control_net_lllite.py (97%) diff --git a/networks/lora_control_net.py b/networks/control_net_lllite.py similarity index 54% rename from networks/lora_control_net.py rename to networks/control_net_lllite.py index 120ab0ac..e0157080 100644 --- a/networks/lora_control_net.py +++ b/networks/control_net_lllite.py @@ -1,7 +1,6 @@ import os from typing import Optional, List, Type import torch -from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet @@ -21,6 +20,9 @@ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. ATTN1_2_ONLY = True +# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified +ATTN_QKV_ONLY = True + # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY ATTN1_ETC_ONLY = False # True @@ -30,126 +32,159 @@ ATTN1_ETC_ONLY = False # True TRANSFORMER_MAX_BLOCK_INDEX = None -class LoRAModuleControlNet(LoRAModule): - def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None): - super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout) +class LLLiteModule(torch.nn.Module): + def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None): + super().__init__() + self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + self.lllite_name = name self.cond_emb_dim = cond_emb_dim - - # conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない - # それぞれのモジュールで異なる表現を学習することを期待している - # conditioning1 learns conditioning image embedding in each LoRA-like module. this is not called for each timestep - # we expect to learn different representations in each module - - # conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる - # conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う - # conditioning2 combines the output of conditioning1 and the output of LoRA, and adds it to the output of LoRA. this is called for each timestep - # by learning the output of conditioning image embedding and U-Net together, we adjust U-Net according to conditioning + self.org_module = [org_module] + self.dropout = dropout if self.is_conv2d: - self.conditioning1 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), + in_dim = org_module.in_channels + else: + in_dim = org_module.in_features + + # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない + # conditioning1 embeds conditioning image. it is not called for each timestep + modules = [] + modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size + if depth == 1: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + elif depth == 2: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) + elif depth == 3: + # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + + self.conditioning1 = torch.nn.Sequential(*modules) + + # downで入力の次元数を削減する。LoRAにヒントを得ていることにする + # midでconditioning image embeddingと入力を結合する + # upで元の次元数に戻す + # これらはtimestepごとに呼ばれる + # reduce the number of input dimensions with down. inspired by LoRA + # combine conditioning image embedding and input with mid + # restore to the original dimension with up + # these are called for each timestep + + if self.is_conv2d: + self.down = torch.nn.Sequential( + torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), torch.nn.ReLU(inplace=True), ) - self.conditioning2 = torch.nn.Sequential( - torch.nn.Conv2d(lora_dim + cond_emb_dim, cond_emb_dim, kernel_size=1, stride=1, padding=0), - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim, lora_dim, kernel_size=1, stride=1, padding=0), + self.mid = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), torch.nn.ReLU(inplace=True), ) + self.up = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), + ) else: - self.conditioning1 = torch.nn.Sequential( - torch.nn.Linear(cond_emb_dim, cond_emb_dim), - torch.nn.ReLU(inplace=True), - torch.nn.Linear(cond_emb_dim, cond_emb_dim), + # midの前にconditioningをreshapeすること / reshape conditioning before mid + self.down = torch.nn.Sequential( + torch.nn.Linear(in_dim, mlp_dim), torch.nn.ReLU(inplace=True), ) - self.conditioning2 = torch.nn.Sequential( - torch.nn.Linear(lora_dim + cond_emb_dim, cond_emb_dim), - torch.nn.ReLU(inplace=True), - torch.nn.Linear(cond_emb_dim, lora_dim), + self.mid = torch.nn.Sequential( + torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), torch.nn.ReLU(inplace=True), ) + self.up = torch.nn.Sequential( + torch.nn.Linear(mlp_dim, in_dim), + ) - # Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv - # torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv + # Zero-Convにする / set to Zero-Conv + torch.nn.init.zeros_(self.up[0].weight) # zero conv self.depth = depth # 1~3 self.cond_emb = None self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0 - def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない + # Controlの種類によっては使えるかも + # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice + # it may be available depending on the type of Control + + def set_cond_image(self, cond_image): r""" 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む / call the model inside, so if necessary, surround it with torch.no_grad() """ - # conv2dとlinearでshapeが違うので必要な方を選択 / select the required one because the shape is different for conv2d and linear - cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d - - cond_emb = cond_embs[self.depth - 1] - # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance - self.cond_emb = self.conditioning1(cond_emb) + cx = self.conditioning1(cond_image) + if not self.is_conv2d: + # reshape / b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + self.cond_emb = cx def set_batch_cond_only(self, cond_only, zeros): self.batch_cond_only = cond_only self.use_zeros_for_batch_uncond = zeros + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + def forward(self, x): - if self.cond_emb is None: - return self.org_forward(x) - - # LoRA-Down - lx = x - if self.batch_cond_only: - lx = lx[1::2] # cond only in inference - - lx = self.lora_down(lx) - - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) - - # conditioning image embeddingを結合 / combine conditioning image embedding + r""" + 学習用の便利forward。元のモジュールのforwardを呼び出す + / convenient forward for training. call the forward of the original module + """ cx = self.cond_emb - if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only + if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + # print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している + # down reduces the number of input dimensions and combines it with conditioning image embedding # we expect that it will mix well by combining in the channel direction instead of adding - cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) - cx = self.conditioning2(cx) - lx = lx + cx # lxはresidual的に加算される / lx is added residually + cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2) + cx = self.mid(cx) - # LoRA-Up - lx = self.lora_up(lx) + if self.dropout is not None and self.training: + cx = torch.nn.functional.dropout(cx, p=self.dropout) - # call original module - x = self.org_forward(x) + cx = self.up(cx) - # add LoRA + # residualを加算する / add residual if self.batch_cond_only: - x[1::2] += lx * self.multiplier * self.scale + x[1::2] += cx else: - x += lx * self.multiplier * self.scale + # to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone + # RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace. + # This view was created inside a custom Function ... + # x = x.clone() + x += cx + + x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here return x -class LoRAControlNet(torch.nn.Module): +class ControlNetLLLite(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + def __init__( self, unet: sdxl_original_unet.SdxlUNet2DConditionModel, cond_emb_dim: int = 16, - lora_dim: int = 16, - alpha: float = 1, + mlp_dim: int = 16, dropout: Optional[float] = None, varbose: Optional[bool] = False, ) -> None: @@ -161,9 +196,9 @@ class LoRAControlNet(torch.nn.Module): target_replace_modules: List[torch.nn.Module], module_class: Type[object], ) -> List[torch.nn.Module]: - prefix = LoRANetwork.LORA_PREFIX_UNET + prefix = "lllite_unet" - loras = [] + modules = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): @@ -190,13 +225,13 @@ class LoRAControlNet(torch.nn.Module): else: raise NotImplementedError() - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") + lllite_name = prefix + "." + name + "." + child_name + lllite_name = lllite_name.replace(".", "_") if TRANSFORMER_MAX_BLOCK_INDEX is not None: - p = lora_name.find("transformer_blocks") + p = lllite_name.find("transformer_blocks") if p >= 0: - tf_index = int(lora_name[p:].split("_")[2]) + tf_index = int(lllite_name[p:].split("_")[2]) if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: continue @@ -204,104 +239,63 @@ class LoRAControlNet(torch.nn.Module): # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない # time emb is not applied # attn2 conditioning (input from CLIP) cannot be applied because the shape is different - if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): + if "emb_layers" in lllite_name or ( + "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name) + ): continue if ATTN1_2_ONLY: - if not ("attn1" in lora_name or "attn2" in lora_name): + if not ("attn1" in lllite_name or "attn2" in lllite_name): continue + if ATTN_QKV_ONLY: + if "to_out" in lllite_name: + continue if ATTN1_ETC_ONLY: - if "proj_out" in lora_name: + if "proj_out" in lllite_name: pass - elif "attn1" in lora_name and ("to_k" in lora_name or "to_v" in lora_name or "to_out" in lora_name): + elif "attn1" in lllite_name and ( + "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name + ): pass - elif "ff_net_2" in lora_name: + elif "ff_net_2" in lllite_name: pass else: continue - lora = module_class( + module = module_class( depth, cond_emb_dim, - lora_name, + lllite_name, child_module, - 1.0, - lora_dim, - alpha, + mlp_dim, dropout=dropout, ) - loras.append(lora) - return loras + modules.append(module) + return modules - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE if not TRANSFORMER_ONLY: - target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 # create module instances - self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet) - print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.") - - # conditioning image embedding - - # control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、 - # 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ - # ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず - # また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する - # depthに応じて3つのサイズを用意する - - # conditioning image embedding is converted to an appropriate latent space - # because the size and channels of the input to the LoRA-like module are difficult to handle - # we call it conditioning image embedding - # however, the control image itself does not have much information, so the conditioning image embedding should be small - # conditioning image embedding is also learned individually in each LoRA-like module - # prepare three sizes according to depth - - self.cond_block0 = torch.nn.Sequential( - torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent (from VAE) size - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1), - torch.nn.ReLU(inplace=True), - ) - self.cond_block1 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=1), - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), - torch.nn.ReLU(inplace=True), - ) - self.cond_block2 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), - torch.nn.ReLU(inplace=True), - ) + self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) + print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") def forward(self, x): - x = self.cond_block0(x) - x0 = x - x = self.cond_block1(x) - x1 = x - x = self.cond_block2(x) - x2 = x + return x # dummy - x_3d = [] # for Linear - for x0 in [x0, x1, x2]: - # b,c,h,w -> b,h*w,c - n, c, h, w = x0.shape - x0 = x0.view(n, c, h * w).permute(0, 2, 1) - x_3d.append(x0) - - return [x0, x1, x2], x_3d - - def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + def set_cond_image(self, cond_image): r""" 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む / call the model inside, so if necessary, surround it with torch.no_grad() """ - for lora in self.unet_loras: - lora.set_cond_embs(cond_embs_4d, cond_embs_3d) + for module in self.unet_modules: + module.set_cond_image(cond_image) def set_batch_cond_only(self, cond_only, zeros): - for lora in self.unet_loras: - lora.set_batch_cond_only(cond_only, zeros) + for module in self.unet_modules: + module.set_batch_cond_only(cond_only, zeros) def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": @@ -315,10 +309,10 @@ class LoRAControlNet(torch.nn.Module): return info def apply_to(self): - print("applying LoRA for U-Net...") - for lora in self.unet_loras: - lora.apply_to() - self.add_module(lora.lora_name, lora) + print("applying LLLite for U-Net...") + for module in self.unet_modules: + module.apply_to() + self.add_module(module.lllite_name, module) # マージできるかどうかを返す def is_mergeable(self): @@ -367,16 +361,15 @@ class LoRAControlNet(torch.nn.Module): if __name__ == "__main__": # デバッグ用 / for debug - # これを指定しないとエラーが出てcond_blockが学習できない / if not specified, an error occurs and cond_block cannot be learned - sdxl_original_unet.USE_REENTRANT = False + # sdxl_original_unet.USE_REENTRANT = False # test shape etc print("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet.to("cuda").to(torch.float16) - print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 64, 32, 1) + print("create ControlNet-LLLite") + control_net = ControlNetLLLite(unet, 32, 64) control_net.apply_to() control_net.to("cuda") @@ -414,6 +407,7 @@ if __name__ == "__main__": print("start training") steps = 10 + sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] for step in range(steps): print(f"step {step}") @@ -425,8 +419,7 @@ if __name__ == "__main__": y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() with torch.cuda.amp.autocast(enabled=True): - cond_embs_4d, cond_embs_3d = control_net(conditioning_image) - control_net.set_cond_embs(cond_embs_4d, cond_embs_3d) + control_net.set_cond_image(conditioning_image) output = unet(x, t, ctx, y) target = torch.randn_like(output) @@ -436,3 +429,8 @@ if __name__ == "__main__": scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) + print(sample_param) + + # from safetensors.torch import save_file + + # save_file(control_net.state_dict(), "logs/control_net.safetensors") diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 69c0bd1d..cb16a781 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -47,10 +47,9 @@ import library.train_util as train_util import library.sdxl_model_util as sdxl_model_util import library.sdxl_train_util as sdxl_train_util from networks.lora import LoRANetwork -import tools.original_control_net as original_control_net -from tools.original_control_net import ControlNetInfo from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -327,7 +326,7 @@ class PipelineLike: self.token_replacements_list.append({}) # ControlNet # not supported yet - self.control_nets: List[ControlNetInfo] = [] + self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない # Textual Inversion @@ -392,6 +391,7 @@ class PipelineLike: is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, img2img_noise=None, + clip_guide_images=None, **kwargs, ): # TODO support secondary prompt @@ -496,11 +496,16 @@ class PipelineLike: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) if self.control_nets: + # ControlNetのhintにguide imageを流用する if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) - # ControlNetのhintにguide imageを流用する - # 前処理はControlNet側で行う + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) # create size embs if original_height is None: @@ -654,35 +659,47 @@ class PipelineLike: num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net in self.control_nets: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net in self.control_nets: + control_net.set_cond_image(None) for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - if self.control_nets and self.control_net_enabled: - if reginonal_network: - num_sub_and_neg_prompts = len(text_embeddings) // batch_size - text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - else: - text_emb_last = text_embeddings + # # disable control net if ratio is set + # if self.control_nets and self.control_net_enabled: + # pass # TODO - # not working yet - noise_pred = original_control_net.call_unet_and_control_net( - i, - num_latent_input, - self.unet, - self.control_nets, - guided_hints, - i / len(timesteps), - latent_model_input, - t, - text_emb_last, - ).sample - else: - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + # predict the noise residual + # TODO Diffusers' ControlNet + # if self.control_nets and self.control_net_enabled: + # if reginonal_network: + # num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + # else: + # text_emb_last = text_embeddings + + # # not working yet + # noise_pred = original_control_net.call_unet_and_control_net( + # i, + # num_latent_input, + # self.unet, + # self.control_nets, + # guided_hints, + # i / len(timesteps), + # latent_model_input, + # t, + # text_emb_last, + # ).sample + # else: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) # perform guidance if do_classifier_free_guidance: @@ -1550,16 +1567,40 @@ def main(args): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] - if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + control_nets: List[ControlNetLLLite] = [] + # if args.control_net_models: + # for i, model in enumerate(args.control_net_models): + # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) + # prep = original_control_net.load_preprocess(prep_type) + # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + print(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim) + control_net.apply_to() + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_net.set_batch_cond_only(False, False) + control_nets.append(control_net) if args.opt_channels_last: print(f"set optimizing: channels last") @@ -1572,8 +1613,9 @@ def main(args): network.to(memory_format=torch.channels_last) for cn in control_nets: - cn.unet.to(memory_format=torch.channels_last) - cn.net.to(memory_format=torch.channels_last) + cn.to(memory_format=torch.channels_last) + # cn.unet.to(memory_format=torch.channels_last) + # cn.net.to(memory_format=torch.channels_last) pipe = PipelineLike( device, @@ -2573,20 +2615,23 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - ) - parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" - ) - parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" ) # parser.add_argument( + # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + # ) + # parser.add_argument( + # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + # ) + # parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率") + # parser.add_argument( + # "--control_net_ratios", + # type=float, + # default=None, + # nargs="*", + # help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + # ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py deleted file mode 100644 index b1cab576..00000000 --- a/sdxl_gen_img_lora_ctrl_test.py +++ /dev/null @@ -1,2612 +0,0 @@ -# temporary test code for LoRA control net - -import itertools -import json -from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable -import glob -import importlib -import inspect -import time -import zipfile -from diffusers.utils import deprecate -from diffusers.configuration_utils import FrozenDict -import argparse -import math -import os -import random -import re - -import diffusers -import numpy as np -import torch -import torchvision -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - # UNet2DConditionModel, - StableDiffusionPipeline, -) -from einops import rearrange -from tqdm import tqdm -from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig -import PIL -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import library.model_util as model_util -import library.train_util as train_util -import library.sdxl_model_util as sdxl_model_util -import library.sdxl_train_util as sdxl_train_util -from networks.lora import LoRANetwork -from library.sdxl_original_unet import SdxlUNet2DConditionModel -from library.original_unet import FlashAttentionFunction -from networks.lora_control_net import LoRAControlNet - -# scheduler: -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - -# その他の設定 -LATENT_CHANNELS = 4 -DOWNSAMPLING_FACTOR = 8 - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - print("Enable memory efficient attention for U-Net") - - # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い - unet.set_use_memory_efficient_attention(False, True) - elif xformers: - print("Enable xformers for U-Net") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention(True, False) - elif sdpa: - print("Enable SDPA for U-Net") - unet.set_use_memory_efficient_attention(False, False) - unet.set_use_sdpa(True) - - -# TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? - vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う - elif sdpa: - replace_vae_attn_to_sdpa() - - -def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_flash_attn(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_flash_attn - - -def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") - import xformers.ops - - def forward_xformers(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_xformers_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_xformers(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_xformers - - -def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") - - def forward_sdpa(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = torch.nn.functional.scaled_dot_product_attention( - query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - out = rearrange(out, "b n h d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_sdpa_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_sdpa(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_sdpa - - -# endregion - -# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 -# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 - - -class PipelineLike: - def __init__( - self, - device, - vae: AutoencoderKL, - text_encoders: List[CLIPTextModel], - tokenizers: List[CLIPTokenizer], - unet: SdxlUNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - ): - super().__init__() - self.device = device - self.clip_skip = clip_skip - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - self.vae = vae - self.text_encoders = text_encoders - self.tokenizers = tokenizers - self.unet: SdxlUNet2DConditionModel = unet - self.scheduler = scheduler - self.safety_checker = None - - # Textual Inversion - self.token_replacements_list = [] - for _ in range(len(self.text_encoders)): - self.token_replacements_list.append({}) - - # ControlNet # not supported yet - self.control_nets: List[LoRAControlNet] = [] - self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない - - # Textual Inversion - def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): - self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids - - def set_enable_control_net(self, en: bool): - self.control_net_enabled = en - - def get_token_replacer(self, tokenizer): - tokenizer_index = self.tokenizers.index(tokenizer) - token_replacements = self.token_replacements_list[tokenizer_index] - - def replace_tokens(tokens): - # print("replace_tokens", tokens, "=>", token_replacements) - if isinstance(tokens, torch.Tensor): - tokens = tokens.tolist() - - new_tokens = [] - for token in tokens: - if token in token_replacements: - replacement = token_replacements[token] - new_tokens.extend(replacement) - else: - new_tokens.append(token) - return new_tokens - - return replace_tokens - - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 1024, - width: int = 1024, - original_height: int = None, - original_width: int = None, - original_height_negative: int = None, - original_width_negative: int = None, - crop_top: int = 0, - crop_left: int = 0, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_guide_images=None, - **kwargs, - ): - # TODO support secondary prompt - num_images_per_prompt = 1 # fixed because already prompt is repeated - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - reginonal_network = " AND " in prompt[0] - - vae_batch_size = ( - batch_size - if vae_batch_size is None - else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - ) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - tes_text_embs = [] - tes_uncond_embs = [] - tes_real_uncond_embs = [] - # use last pool - for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): - token_replacer = self.get_token_replacer(tokenizer) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( - tokenizer, - text_encoder, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_text_embs.append(text_embeddings) - tes_uncond_embs.append(uncond_embeddings) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - token_replacer, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""] * batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_real_uncond_embs.append(real_uncond_embeddings) - - # concat text encoder outputs - text_embeddings = tes_text_embs[0] - uncond_embeddings = tes_uncond_embs[0] - for i in range(1, len(tes_text_embs)): - text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 - - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - if self.control_nets: - # ControlNetのhintにguide imageを流用する - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - if isinstance(clip_guide_images[0], PIL.Image.Image): - clip_guide_images = [preprocess_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images) - if isinstance(clip_guide_images, list): - clip_guide_images = torch.stack(clip_guide_images) - - clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) - - # create size embs - if original_height is None: - original_height = height - if original_width is None: - original_width = width - if original_height_negative is None: - original_height_negative = original_height - if original_width_negative is None: - original_width_negative = original_width - if crop_top is None: - crop_top = 0 - if crop_left is None: - crop_left = 0 - emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) - uc_emb1 = sdxl_train_util.get_timestep_embedding( - torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 - ) - emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) - emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - - c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = ( - batch_size * num_images_per_prompt, - self.unet.in_channels, - height // 8, - width // 8, - ) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) - - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint - - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[-2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - init_latents = [] - for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): - init_latent_dist = self.vae.encode( - (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( - self.vae.dtype - ) - ).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) - - init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents - - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - if self.control_net_enabled: - for control_net in self.control_nets: - with torch.no_grad(): - cond_embs_4d, cond_embs_3d = control_net(clip_guide_images) - control_net.set_cond_embs(cond_embs_4d, cond_embs_3d) - else: - for control_net in self.control_nets: - control_net.set_cond_image(None) - - for i, t in enumerate(tqdm(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # # disable control net if ratio is set - # if self.control_nets and self.control_net_enabled: - # pass # TODO - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( - num_latent_input - ) # uncond is real uncond - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - - negative_scale * (noise_pred_negative - noise_pred_uncond) - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - if return_latents: - return (latents, False) - - latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents.to(self.vae.dtype)).sample - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append( - self.vae.decode( - (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) - ).sample - ) - image = torch.cat(images) - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] - - return image - - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - # keep break as separate token - text = text.replace("BREAK", "\\BREAK\\") - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - if word.strip() == "BREAK": - # pad until next multiple of tokenizer's max token length - pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") - for i in range(pad_len): - # v2のときEOSをつけるべきかどうかわからないぜ - # if i == 0: - # text_token.append(tokenizer.eos_token_id) - # else: - text_token.append(tokenizer.pad_token_id) - text_weight.append(1.0) - continue - - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - - token = token_replacer(token) # for Textual Inversion - - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - text_encoder: CLIPTextModel, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - pool = None - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - # -2 is same for Text Encoder 1 and 2 - enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-2] - if pool is None: - pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-2] - pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) - return text_embeddings, pool - - -def get_weighted_text_embeddings( - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 1, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, - token_replacer=None, - device=None, - **kwargs, -): - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - # split the prompts with "AND". each prompt must have the same number of splits - new_prompts = [] - for p in prompt: - new_prompts.extend(p.split(" AND ")) - prompt = new_prompts - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) - - # get the embeddings - text_embeddings, text_pool = get_unweighted_text_embeddings( - text_encoder, - prompt_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - if uncond_prompt is not None: - uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( - text_encoder, - uncond_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - # →全体でいいんじゃないかな - if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - if uncond_prompt is not None: - return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens - return text_embeddings, text_pool, None, None, prompt_tokens - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -# regular expression for dynamic prompt: -# starts and ends with "{" and "}" -# contains at least one variant divided by "|" -# optional framgments divided by "$$" at start -# if the first fragment is "E" or "e", enumerate all variants -# if the second fragment is a number or two numbers, repeat the variants in the range -# if the third fragment is a string, use it as a separator - -RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") - - -def handle_dynamic_prompt_variants(prompt, repeat_count): - founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) - if not founds: - return [prompt] - - # make each replacement for each variant - enumerating = False - replacers = [] - for found in founds: - # if "e$$" is found, enumerate all variants - found_enumerating = found.group(2) is not None - enumerating = enumerating or found_enumerating - - separator = ", " if found.group(6) is None else found.group(6) - variants = found.group(7).split("|") - - # parse count range - count_range = found.group(4) - if count_range is None: - count_range = [1, 1] - else: - count_range = count_range.split("-") - if len(count_range) == 1: - count_range = [int(count_range[0]), int(count_range[0])] - elif len(count_range) == 2: - count_range = [int(count_range[0]), int(count_range[1])] - else: - print(f"invalid count range: {count_range}") - count_range = [1, 1] - if count_range[0] > count_range[1]: - count_range = [count_range[1], count_range[0]] - if count_range[0] < 0: - count_range[0] = 0 - if count_range[1] > len(variants): - count_range[1] = len(variants) - - if found_enumerating: - # make function to enumerate all combinations - def make_replacer_enum(vari, cr, sep): - def replacer(): - values = [] - for count in range(cr[0], cr[1] + 1): - for comb in itertools.combinations(vari, count): - values.append(sep.join(comb)) - return values - - return replacer - - replacers.append(make_replacer_enum(variants, count_range, separator)) - else: - # make function to choose random combinations - def make_replacer_single(vari, cr, sep): - def replacer(): - count = random.randint(cr[0], cr[1]) - comb = random.sample(vari, count) - return [sep.join(comb)] - - return replacer - - replacers.append(make_replacer_single(variants, count_range, separator)) - - # make each prompt - if not enumerating: - # if not enumerating, repeat the prompt, replace each variant randomly - prompts = [] - for _ in range(repeat_count): - current = prompt - for found, replacer in zip(founds, replacers): - current = current.replace(found.group(0), replacer()[0], 1) - prompts.append(current) - else: - # if enumerating, iterate all combinations for previous prompts - prompts = [prompt] - - for found, replacer in zip(founds, replacers): - if found.group(2) is not None: - # make all combinations for existing prompts - new_prompts = [] - for current in prompts: - replecements = replacer() - for replecement in replecements: - new_prompts.append(current.replace(found.group(0), replecement, 1)) - prompts = new_prompts - - for found, replacer in zip(founds, replacers): - # make random selection for existing prompts - if found.group(2) is None: - for i in range(len(prompts)): - prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) - - return prompts - - -# endregion - - -# def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") -# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) -# return text_encoder - - -class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any - - -class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - original_width: int - original_height: int - original_width_negative: int - original_height_negative: int - crop_left: int - crop_top: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] - num_sub_prompts: int - - -class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt - - -def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - # モデルを読み込む - if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype - ) - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - mem_eff = not (args.xformers or args.sdpa) - replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) - replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) - - # tokenizerを読み込む - print("loading tokenizer") - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # schedulerを用意する - sched_init_args = {} - has_steps_offset = True - has_clip_sample = True - scheduler_num_noises_per_step = 1 - - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - has_clip_sample = False - elif args.sampler == "lms" or args.sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - has_clip_sample = False - elif args.sampler == "euler" or args.sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - has_clip_sample = False - elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - has_clip_sample = False - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - has_clip_sample = False - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - has_clip_sample = False - has_steps_offset = False - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - has_clip_sample = False - elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - has_clip_sample = False - elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - has_clip_sample = False - - # 警告を出さないようにする - if has_steps_offset: - sched_init_args["steps_offset"] = 1 - if has_clip_sample: - sched_init_args["clip_sample"] = False - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index = 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == "randn": - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, - ) - - # ↓以下は結局PipeでFalseに設定されるので意味がなかった - # # clip_sample=Trueにする - # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - # scheduler.config.clip_sample = True - - # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない - - # custom pipelineをコピったやつを生成する - if args.vae_slices: - from library.slicing_vae import SlicingAutoencoderKL - - sli_vae = SlicingAutoencoderKL( - act_fn="silu", - block_out_channels=(128, 256, 512, 512), - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], - in_channels=3, - latent_channels=4, - layers_per_block=2, - norm_num_groups=32, - out_channels=3, - sample_size=512, - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], - num_slices=args.vae_slices, - ) - sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする - vae = sli_vae - del sli_vae - - vae_dtype = dtype - if args.no_half_vae: - print("set vae_dtype to float32") - vae_dtype = torch.float32 - vae.to(vae_dtype).to(device) - - text_encoder1.to(dtype).to(device) - text_encoder2.to(dtype).to(device) - unet.to(dtype).to(device) - - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - network_pre_calc = args.network_pre_calc - - for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs - ) - else: - raise ValueError("No weight. Weight is required.") - if network is None: - return - - mergeable = network.is_mergeable() - if args.network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") - - if not args.network_merge or not mergeable: - network.apply_to([text_encoder1, text_encoder2], unet) - info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - if network_pre_calc: - print("backup original weights") - network.backup_weights() - - networks.append(network) - else: - network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) - - else: - networks = [] - - # upscalerの指定があれば取得する - upscaler = None - if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) - imported_module = importlib.import_module(args.highres_fix_upscaler) - - us_kwargs = {} - if args.highres_fix_upscaler_args: - for net_arg in args.highres_fix_upscaler_args.split(";"): - key, value = net_arg.split("=") - us_kwargs[key] = value - - print("create upscaler") - upscaler = imported_module.create_upscaler(**us_kwargs) - upscaler.to(dtype).to(device) - - # ControlNetの処理 - control_nets: List[LoRAControlNet] = [] - if args.control_net_models: - for i, model_file in enumerate(args.control_net_models): - print(f"loading control net: {model_file}") - - from safetensors.torch import load_file - - state_dict = load_file(model_file) - lora_rank = None - emb_dim = None - for key, value in state_dict.items(): - if lora_rank is None and "lora_down.weight" in key: - lora_rank = value.shape[0] - elif emb_dim is None and "conditioning1.0" in key: - emb_dim = value.shape[0] - if lora_rank is not None and emb_dim is not None: - break - assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}" - - control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1) - control_net.apply_to() - control_net.load_state_dict(state_dict) - control_net.to(dtype).to(device) - control_net.set_batch_cond_only(False, False) - control_nets.append(control_net) - - if args.opt_channels_last: - print(f"set optimizing: channels last") - text_encoder1.to(memory_format=torch.channels_last) - text_encoder2.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - - for cn in control_nets: - cn.to(memory_format=torch.channels_last) - - pipe = PipelineLike( - device, - vae, - [text_encoder1, text_encoder2], - [tokenizer1, tokenizer2], - unet, - scheduler, - args.clip_skip, - ) - pipe.set_control_nets(control_nets) - print("pipeline is ready.") - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Textual Inversionを処理する - if args.textual_inversion_embeddings: - token_ids_embeds1 = [] - token_ids_embeds2 = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - if "string_to_param" in data: - data = data["string_to_param"] - - embeds1 = data["clip_l"] # text encoder 1 - embeds2 = data["clip_g"] # text encoder 2 - - num_vectors_per_token = embeds1.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens1 = tokenizer1.add_tokens(token_strings) - num_added_tokens2 = tokenizer2.add_tokens(token_strings) - assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( - f"tokenizer has same word to token string (filename): {embeds_file}" - + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" - ) - - token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) - token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") - assert ( - min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 - ), f"token ids1 is not ordered" - assert ( - min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 - ), f"token ids2 is not ordered" - assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" - assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... - pipe.add_token_replacement(1, token_ids2[0], token_ids2) - - token_ids_embeds1.append((token_ids1, embeds1)) - token_ids_embeds2.append((token_ids2, embeds2)) - - text_encoder1.resize_token_embeddings(len(tokenizer1)) - text_encoder2.resize_token_embeddings(len(tokenizer2)) - token_embeds1 = text_encoder1.get_input_embeddings().weight.data - token_embeds2 = text_encoder2.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds1: - for token_id, embed in zip(token_ids, embeds): - token_embeds1[token_id] = embed - for token_ids, embeds in token_ids_embeds2: - for token_id, embed in zip(token_ids, embeds): - token_embeds2[token_id] = embed - - # promptを取得する - if args.from_file is not None: - print(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] - elif args.prompt is not None: - prompt_list = [args.prompt] - else: - prompt_list = [] - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] - else: - paths = ( - glob.glob(os.path.join(path, "*.png")) - + glob.glob(os.path.join(path, "*.jpg")) - + glob.glob(os.path.join(path, "*.jpeg")) - + glob.glob(os.path.join(path, "*.webp")) - ) - paths.sort() - - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) - - return images - - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, "filename"): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized - - if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") - else: - init_images = None - - if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None - - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") - for img in init_images: - if "prompt" in img.text: - prompt = img.text["prompt"] - if "negative-prompt" in img.text: - prompt += " --n " + img.text["negative-prompt"] - prompt_list.append(prompt) - - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l - - if mask_images is not None: - l = [] - for im in mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l - - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not None and args.H is not None: - # highres fix を考慮に入れる - w, h = args.W, args.H - if highres_fix: - w = int(w * args.highres_fix_scale + 0.5) - h = int(h * args.highres_fix_scale + 0.5) - - if init_images is not None: - print(f"resize img2img source images to {w}*{h}") - init_images = resize_images(init_images, (w, h)) - if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") - mask_images = resize_images(mask_images, (w, h)) - - regional_network = False - if networks and mask_images: - # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 - regional_network = True - print("use mask as region") - - size = None - for i, network in enumerate(networks): - if i < 3: - np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] - size = np_mask.shape - else: - np_mask = np.full(size, 255, dtype=np.uint8) - mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) - network.set_region(i, i == len(networks) - 1, mask) - mask_images = None - - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) - - print(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") - guide_images = None - else: - guide_images = None - - # seed指定時はseedを決めておく - if args.seed is not None: - # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed - else: - predefined_seeds = None - - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 1024 - if args.H is None: - args.H = 1024 - - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples - - for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7FFFFFFF) - - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) - - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - - print("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - - def scale_and_round(x): - if x is None: - return None - return int(x * args.highres_fix_scale + 0.5) - - width_1st = scale_and_round(ext.width) - height_1st = scale_and_round(ext.height) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 - - original_width_1st = scale_and_round(ext.original_width) - original_height_1st = scale_and_round(ext.original_height) - original_width_negative_1st = scale_and_round(ext.original_width_negative) - original_height_negative_1st = scale_and_round(ext.original_height_negative) - crop_left_1st = scale_and_round(ext.crop_left) - crop_top_1st = scale_and_round(ext.crop_top) - - strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength - - ext_1st = BatchDataExt( - width_1st, - height_1st, - original_width_1st, - original_height_1st, - original_width_negative_1st, - original_height_negative_1st, - crop_left_1st, - crop_top_1st, - args.highres_fix_steps, - ext.scale, - ext.negative_scale, - strength_1st, - ext.network_muls, - ext.num_sub_prompts, - ) - batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) - - pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) - - # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") - width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height - - if upscaler: - # upscalerを使って画像を拡大する - lowreso_imgs = None if is_1st_latent else images_1st - lowreso_latents = None if not is_1st_latent else images_1st - - # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents - batch_size = len(images_1st) - vae_batch_size = ( - batch_size - if args.vae_batch_size is None - else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) - ) - vae_batch_size = int(vae_batch_size) - images_1st = upscaler.upscale( - vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size - ) - - elif args.highres_fix_latents_upscaling: - # latentを拡大する - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" - ) # , antialias=True) - images_1st = images_1st.to(org_dtype) - - else: - # 画像をLANCZOSで拡大する - images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] - - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd - - if args.highres_fix_disable_control_net: - pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする - - # このバッチの情報を取り出す - ( - return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), - ( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - network_muls, - num_sub_prompts, - ), - ) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) - - prompts = [] - negative_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [ - torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step) - ] - seeds = [] - clip_prompts = [] - - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] - - if mask_image is not None: - mask_images = [] - else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None - - if guide_image is not None: # CLIP image guided? - guide_images = [] - else: - guide_images = None - - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - # 追加ネットワークの処理 - shared = {} - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - if regional_network: - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) - - if not regional_network and network_pre_calc: - for n in networks: - n.restore_weights() - for n in networks: - n.pre_calculation() - print("pre-calculation... done") - - images = pipe( - prompts, - negative_prompts, - init_images, - mask_images, - height, - width, - original_height, - original_width, - original_height_negative, - original_width_negative, - crop_top, - crop_left, - steps, - scale, - negative_scale, - strength, - latents=start_code, - output_type="pil", - max_embeddings_multiples=max_embeddings_multiples, - img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, - return_latents=return_latents, - clip_prompts=clip_prompts, - clip_guide_images=guide_images, - ) - if highres_1st and not args.highres_fix_save_1st: # return images or latents - return images - - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") - - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: - # interactive - valid = False - while not valid: - print("\nType prompt:") - try: - raw_prompt = input() - except EOFError: - break - - valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - raw_prompt = prompt_list[prompt_index] - - # sd-dynamic-prompts like variants: - # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - - if pi == 0 or len(raw_prompts) > 1: - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - print(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - print(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - print(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - print("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed - else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - print(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - print( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? - process_batch(batch_data, highres_fix) - batch_data.clear() - - batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] - batch_data.clear() - - global_step += 1 - - prompt_index += 1 - - if len(batch_data) > 0: - process_batch(batch_data, highres_fix) - batch_data.clear() - - print("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" - ) - parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" - ) - parser.add_argument( - "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" - ) - parser.add_argument( - "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" - ) - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") - parser.add_argument( - "--use_original_file_name", - action="store_true", - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", - ) - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument( - "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" - ) - parser.add_argument( - "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" - ) - parser.add_argument( - "--original_height_negative", - type=int, - default=None, - help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", - ) - parser.add_argument( - "--original_width_negative", - type=int, - default=None, - help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", - ) - parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") - parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument( - "--vae_batch_size", - type=float, - default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", - ) - parser.add_argument( - "--vae_slices", - type=int, - default=None, - help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", - ) - parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument( - "--sampler", - type=str, - default="ddim", - choices=[ - "ddim", - "pndm", - "lms", - "euler", - "euler_a", - "heun", - "dpm_2", - "dpm_2_a", - "dpmsolver", - "dpmsolver++", - "dpmsingle", - "k_lms", - "k_euler", - "k_euler_a", - "k_dpm_2", - "k_dpm_2_a", - ], - help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", - ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") - parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" - ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument( - "--seed", - type=int, - default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", - ) - parser.add_argument( - "--iter_same_seed", - action="store_true", - help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", - ) - parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") - parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") - parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") - parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") - parser.add_argument( - "--diffusers_xformers", - action="store_true", - help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", - ) - parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" - ) - parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" - ) - parser.add_argument( - "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" - ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") - parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" - ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") - parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" - ) - parser.add_argument( - "--textual_inversion_embeddings", - type=str, - default=None, - nargs="*", - help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", - ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") - parser.add_argument( - "--max_embeddings_multiples", - type=int, - default=None, - help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", - ) - parser.add_argument( - "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" - ) - parser.add_argument( - "--highres_fix_scale", - type=float, - default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", - ) - parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" - ) - parser.add_argument( - "--highres_fix_strength", - type=float, - default=None, - help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", - ) - parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" - ) - parser.add_argument( - "--highres_fix_latents_upscaling", - action="store_true", - help="use latents upscaling for highres fix / highres fixでlatentで拡大する", - ) - parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" - ) - parser.add_argument( - "--highres_fix_upscaler_args", - type=str, - default=None, - help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", - ) - parser.add_argument( - "--highres_fix_disable_control_net", - action="store_true", - help="disable ControlNet for highres fix / highres fixでControlNetを使わない", - ) - - parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" - ) - - parser.add_argument( - "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - ) - parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" - ) - parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", - ) - # parser.add_argument( - # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" - # ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - main(args) diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_control_net_lllite.py similarity index 97% rename from sdxl_train_lora_control_net.py rename to sdxl_train_control_net_lllite.py index b6fb1dec..6e49e449 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_control_net_lllite.py @@ -34,7 +34,7 @@ from library.custom_train_functions import ( apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, ) -import networks.lora_control_net as lora_control_net +import networks.control_net_lllite as control_net_lllite # TODO 他のスクリプトと共通化する @@ -176,7 +176,7 @@ def train(args): accelerator.wait_for_everyone() # prepare ControlNet - network = lora_control_net.LoRAControlNet(unet, args.cond_emb_dim, args.network_dim, 1, args.network_dropout) + network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout) network.apply_to() if args.network_weights is not None: @@ -242,7 +242,7 @@ def train(args): unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, network, optimizer, train_dataloader, lr_scheduler ) - network: lora_control_net.LoRAControlNet + network: control_net_lllite.ControlNetLLLite # transform DDP after prepare (train_network here only) unet, network = train_util.transform_models_if_DDP([unet, network]) @@ -311,7 +311,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lora_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs ) loss_list = [] @@ -401,11 +401,9 @@ def train(args): controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) with accelerator.autocast(): - # conditioning image embeddingを計算する / calculate conditioning image embedding - cond_embs_4d, cond_embs_3d = network(controlnet_image) - - # 個別のLoRA的モジュールでさらにembeddingを計算する / calculate embedding in each LoRA-like module - network.set_cond_embs(cond_embs_4d, cond_embs_3d) + # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet + # 内部でcond_embに変換される / it will be converted to cond_emb inside + network.set_cond_image(controlnet_image) # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) @@ -562,7 +560,7 @@ def setup_parser() -> argparse.ArgumentParser: if __name__ == "__main__": - sdxl_original_unet.USE_REENTRANT = False + # sdxl_original_unet.USE_REENTRANT = False parser = setup_parser()