From 983698dd1b20c081c161c1a7a328d193c3d6e25c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 15 Aug 2023 18:23:22 +0900 Subject: [PATCH 01/19] add lora controlnet temporarily --- networks/lora_control_net.py | 294 +++++++++++++++++++++++++++++++++++ 1 file changed, 294 insertions(+) create mode 100644 networks/lora_control_net.py diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py new file mode 100644 index 00000000..4b57e2fd --- /dev/null +++ b/networks/lora_control_net.py @@ -0,0 +1,294 @@ +import os +from typing import Optional, List, Type +import torch +from networks.lora import LoRAModule, LoRANetwork +from library import sdxl_original_unet + + +SKIP_OUTPUT_BLOCKS = False +SKIP_CONV2D = False + + +class LoRAModuleControlNet(LoRAModule): + def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None): + super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout) + self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + + # adjust channels of conditioning image to LoRA channels + ch = 2 ** (depth - 1) * cond_emb_dim + if self.is_conv2d: + self.conditioning = torch.nn.Conv2d(ch, lora_dim, kernel_size=1, stride=1, padding=0) + else: + self.conditioning = torch.nn.Linear(ch, lora_dim) + torch.nn.init.zeros_(self.conditioning.weight) # zero conv/linear layer + + self.depth = depth + self.cond_emb_dim = cond_emb_dim + self.cond_emb = None + + def set_control(self, cond_emb): + self.cond_emb = cond_emb + + def forward(self, x): + # conditioning image embs -> LoRA channels + cx = self.cond_emb + if not self.is_conv2d: + # b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + # print(f"C {self.lora_name}, x.shape={x.shape}, cx.shape={cx.shape}, weight.shape={self.conditioning.weight.shape}") + cx = self.conditioning(cx) + + # LoRA + # print(f"C {self.lora_name}, x.shape={x.shape}, cx.shape={cx.shape}") + lx = self.lora_down(x) + + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # add conditioning + lx = lx + cx + + lx = self.lora_up(lx) + + x = self.org_forward(x) + lx * self.multiplier * self.scale + return x + + +class LoRAControlNet(torch.nn.Module): + def __init__( + self, + unet: sdxl_original_unet.SdxlUNet2DConditionModel, + cond_emb_dim: int = 16, + lora_dim: int = 16, + alpha: float = 1, + dropout: Optional[float] = None, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + # self.unets = [unet] + + def create_modules( + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + module_class: Type[object], + ) -> List[torch.nn.Module]: + prefix = LoRANetwork.LORA_PREFIX_UNET + + loras = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + + if is_linear or (is_conv2d and not SKIP_CONV2D): + # block index to depth: depth is using to calculate conditioning size and channels + block_name, index1, index2 = (name + "." + child_name).split(".")[:3] + index1 = int(index1) + if block_name == "input_blocks": + depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) + elif block_name == "middle_block": + depth = 3 + elif block_name == "output_blocks": + if SKIP_OUTPUT_BLOCKS: + continue + depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1) + if int(index2) >= 2: + depth -= 1 + else: + raise NotImplementedError() + + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + # skip time emb or clip emb + if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): + continue + + lora = module_class( + depth, + cond_emb_dim, + lora_name, + child_module, + 1.0, + lora_dim, + alpha, + dropout=dropout, + ) + loras.append(lora) + return loras + + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + + # create module instances + self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet) + print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.") + + # stem for conditioning image + self.cond_stem = torch.nn.Sequential( + torch.nn.Conv2d(3, cond_emb_dim, kernel_size=4, stride=4, padding=0), + torch.nn.ReLU(inplace=True), + ) + + # embs for each depth + self.cond_block0 = torch.nn.Sequential( + torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), + torch.nn.ReLU(inplace=True), + ) + self.cond_block1 = torch.nn.Sequential( + torch.nn.Conv2d(cond_emb_dim, cond_emb_dim * 2, kernel_size=3, stride=2, padding=1), + torch.nn.ReLU(inplace=True), + ) + self.cond_block2 = torch.nn.Sequential( + torch.nn.Conv2d(cond_emb_dim * 2, cond_emb_dim * 4, kernel_size=3, stride=2, padding=1), + torch.nn.ReLU(inplace=True), + ) + self.cond_block3 = torch.nn.Sequential( + torch.nn.Conv2d(cond_emb_dim * 4, cond_emb_dim * 8, kernel_size=3, stride=2, padding=1), + torch.nn.ReLU(inplace=True), + ) + + # forawrdでなくset_controlに入れてもやはり動かない + def forward(self, x): + cx = self.cond_stem(x) + cx = self.cond_block0(cx) + c0 = cx + cx = self.cond_block1(cx) + c1 = cx + cx = self.cond_block2(cx) + c2 = cx + cx = self.cond_block3(cx) + c3 = cx + return c0, c1, c2, c3 + + def set_control(self, cond_embs): + for lora in self.unet_loras: + lora.set_control(cond_embs[lora.depth - 1]) + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self): + print("applying LoRA for U-Net...") + for lora in self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return False + + def merge_to(self, text_encoder, unet, weights_sd, dtype, device): + raise NotImplementedError() + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_optimizer_params(self): + self.requires_grad_(True) + return self.parameters() + + def prepare_grad_etc(self): + self.requires_grad_(True) + + def on_epoch_start(self): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + +if __name__ == "__main__": + # test shape etc + print("create unet") + unet = sdxl_original_unet.SdxlUNet2DConditionModel() + unet.to("cuda") # , dtype=torch.float16) + + print("create LoRA controlnet") + control_net = LoRAControlNet(unet, 16, 32, 1) + control_net.apply_to() + control_net.to("cuda") + + # print(controlnet) + # input() + + # print number of parameters + print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + + unet.set_use_memory_efficient_attention(True, False) + unet.set_gradient_checkpointing(True) + unet.train() # for gradient checkpointing + + control_net.train() + + # # visualize + # import torchviz + # print("run visualize") + # controlnet.set_control(conditioning_image) + # output = unet(x, t, ctx, y) + # print("make_dot") + # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) + # print("render") + # image.format = "svg" # "png" + # image.render("NeuralNet") + # input() + + import bitsandbytes + + optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3) + + scaler = torch.cuda.amp.GradScaler(enabled=True) + + print("start training") + steps = 10 + + for step in range(steps): + print(f"step {step}") + + batch_size = 1 + conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0 + x = torch.randn(batch_size, 4, 128, 128).cuda() + t = torch.randint(low=0, high=10, size=(batch_size,)).cuda() + ctx = torch.randn(batch_size, 77, 2048).cuda() + y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() + + with torch.cuda.amp.autocast(enabled=True): + cond_embs = control_net(conditioning_image) + control_net.set_control(cond_embs) + output = unet(x, t, ctx, y) + target = torch.randn_like(output) + loss = torch.nn.functional.mse_loss(output, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) From 3f7235c36fc85778b03a60ff7ed3350d50310e91 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 17 Aug 2023 10:08:02 +0900 Subject: [PATCH 02/19] add lora controlnet train/gen temporarily --- library/sdxl_original_unet.py | 15 +- library/train_util.py | 14 +- networks/check_lora_weights.py | 48 +- networks/lora_control_net.py | 163 +- sdxl_gen_img_lora_ctrl_test.py | 2602 ++++++++++++++++++++++++++++++++ sdxl_train_lora_control_net.py | 823 ++++++++++ 6 files changed, 3582 insertions(+), 83 deletions(-) create mode 100644 sdxl_gen_img_lora_ctrl_test.py create mode 100644 sdxl_train_lora_control_net.py diff --git a/library/sdxl_original_unet.py b/library/sdxl_original_unet.py index 6ea4bc33..586909bd 100644 --- a/library/sdxl_original_unet.py +++ b/library/sdxl_original_unet.py @@ -39,6 +39,7 @@ CONTEXT_DIM: int = 2048 MODEL_CHANNELS: int = 320 TIME_EMBED_DIM = 320 * 4 +USE_REENTRANT = True # region memory effcient attention @@ -322,7 +323,7 @@ class ResnetBlock2D(nn.Module): return custom_forward - x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb) + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT) else: x = self.forward_body(x, emb) @@ -356,7 +357,9 @@ class Downsample2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT + ) else: hidden_states = self.forward_body(hidden_states) @@ -641,7 +644,9 @@ class BasicTransformerBlock(nn.Module): return custom_forward - output = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, context, timestep) + output = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT + ) else: output = self.forward_body(hidden_states, context, timestep) @@ -782,7 +787,9 @@ class Upsample2D(nn.Module): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), hidden_states, output_size) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT + ) else: hidden_states = self.forward_body(hidden_states, output_size) diff --git a/library/train_util.py b/library/train_util.py index 0b40e3ed..631f1cb7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1743,6 +1743,9 @@ class ControlNetDataset(BaseDataset): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def __len__(self): return self.dreambooth_dataset_delegate.__len__() @@ -1775,9 +1778,14 @@ class ControlNetDataset(BaseDataset): h, w = target_size_hw cond_img = cond_img[ct : ct + h, cl : cl + w] else: - assert ( - cond_img.shape[0] == self.height and cond_img.shape[1] == self.width - ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + # assert ( + # cond_img.shape[0] == self.height and cond_img.shape[1] == self.width + # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + # resize to target + if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: + cond_img = cv2.resize( + cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 + ) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index bb8dcd6b..51f581b2 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -5,35 +5,41 @@ from safetensors.torch import load_file def main(file): - print(f"loading: {file}") - if os.path.splitext(file)[1] == '.safetensors': - sd = load_file(file) - else: - sd = torch.load(file, map_location='cpu') + print(f"loading: {file}") + if os.path.splitext(file)[1] == ".safetensors": + sd = load_file(file) + else: + sd = torch.load(file, map_location="cpu") - values = [] + values = [] - keys = list(sd.keys()) - for key in keys: - if 'lora_up' in key or 'lora_down' in key: - values.append((key, sd[key])) - print(f"number of LoRA modules: {len(values)}") + keys = list(sd.keys()) + for key in keys: + if "lora_up" in key or "lora_down" in key: + values.append((key, sd[key])) + print(f"number of LoRA modules: {len(values)}") - for key, value in values: - value = value.to(torch.float32) - print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") + if args.show_all_keys: + for key in [k for k in keys if k not in values]: + values.append((key, sd[key])) + print(f"number of all modules: {len(values)}") + + for key, value in values: + value = value.to(torch.float32) + print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}") def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser = argparse.ArgumentParser() + parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル") + parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する") - return parser + return parser -if __name__ == '__main__': - parser = setup_parser() +if __name__ == "__main__": + parser = setup_parser() - args = parser.parse_args() + args = parser.parse_args() - main(args.file) + main(args.file) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 4b57e2fd..0189c632 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -7,51 +7,87 @@ from library import sdxl_original_unet SKIP_OUTPUT_BLOCKS = False SKIP_CONV2D = False +TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored +ATTN1_ETC_ONLY = True class LoRAModuleControlNet(LoRAModule): def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None): super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout) self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + self.cond_emb_dim = cond_emb_dim - # adjust channels of conditioning image to LoRA channels - ch = 2 ** (depth - 1) * cond_emb_dim if self.is_conv2d: - self.conditioning = torch.nn.Conv2d(ch, lora_dim, kernel_size=1, stride=1, padding=0) + self.conditioning1 = torch.nn.Sequential( + torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) + self.conditioning2 = torch.nn.Sequential( + torch.nn.Conv2d(lora_dim + cond_emb_dim, cond_emb_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(cond_emb_dim, lora_dim, kernel_size=1, stride=1, padding=0), + torch.nn.ReLU(inplace=True), + ) else: - self.conditioning = torch.nn.Linear(ch, lora_dim) - torch.nn.init.zeros_(self.conditioning.weight) # zero conv/linear layer + self.conditioning1 = torch.nn.Sequential( + torch.nn.Linear(cond_emb_dim, cond_emb_dim), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(cond_emb_dim, cond_emb_dim), + torch.nn.ReLU(inplace=True), + ) + self.conditioning2 = torch.nn.Sequential( + torch.nn.Linear(lora_dim + cond_emb_dim, cond_emb_dim), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(cond_emb_dim, lora_dim), + torch.nn.ReLU(inplace=True), + ) + # torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv self.depth = depth - self.cond_emb_dim = cond_emb_dim self.cond_emb = None + self.batch_cond_uncond_enabled = False - def set_control(self, cond_emb): - self.cond_emb = cond_emb + def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d + cond_emb = cond_embs[self.depth - 1] + self.cond_emb = self.conditioning1(cond_emb) + + def set_batch_cond_uncond_enabled(self, enabled): + self.batch_cond_uncond_enabled = enabled def forward(self, x): - # conditioning image embs -> LoRA channels - cx = self.cond_emb - if not self.is_conv2d: - # b,c,h,w -> b,h*w,c - n, c, h, w = cx.shape - cx = cx.view(n, c, h * w).permute(0, 2, 1) - # print(f"C {self.lora_name}, x.shape={x.shape}, cx.shape={cx.shape}, weight.shape={self.conditioning.weight.shape}") - cx = self.conditioning(cx) + if self.cond_emb is None: + return self.org_forward(x) # LoRA - # print(f"C {self.lora_name}, x.shape={x.shape}, cx.shape={cx.shape}") - lx = self.lora_down(x) + lx = x + if self.batch_cond_uncond_enabled: + lx = lx[1::2] # cond only + + lx = self.lora_down(lx) if self.dropout is not None and self.training: lx = torch.nn.functional.dropout(lx, p=self.dropout) - # add conditioning - lx = lx + cx + # conditioning image + cx = self.cond_emb + # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) + cx = self.conditioning2(cx) + + lx = lx + cx lx = self.lora_up(lx) - x = self.org_forward(x) + lx * self.multiplier * self.scale + x = self.org_forward(x) + + if self.batch_cond_uncond_enabled: + x[1::2] += lx * self.multiplier * self.scale + else: + x += lx * self.multiplier * self.scale + return x @@ -106,6 +142,16 @@ class LoRAControlNet(torch.nn.Module): if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): continue + if ATTN1_ETC_ONLY: + if "proj_out" in lora_name: + pass + elif "attn1" in lora_name and ("to_k" in lora_name or "to_v" in lora_name or "to_out" in lora_name): + pass + elif "ff_net_2" in lora_name: + pass + else: + continue + lora = module_class( depth, cond_emb_dim, @@ -119,52 +165,56 @@ class LoRAControlNet(torch.nn.Module): loras.append(lora) return loras - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + if not TRANSFORMER_ONLY: + target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 # create module instances self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet) print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.") - # stem for conditioning image - self.cond_stem = torch.nn.Sequential( - torch.nn.Conv2d(3, cond_emb_dim, kernel_size=4, stride=4, padding=0), - torch.nn.ReLU(inplace=True), - ) - - # embs for each depth + # conditioning image embedding self.cond_block0 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), + torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent size + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1), torch.nn.ReLU(inplace=True), ) self.cond_block1 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim * 2, kernel_size=3, stride=2, padding=1), + torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=1), + torch.nn.ReLU(inplace=True), + torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), torch.nn.ReLU(inplace=True), ) self.cond_block2 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim * 2, cond_emb_dim * 4, kernel_size=3, stride=2, padding=1), - torch.nn.ReLU(inplace=True), - ) - self.cond_block3 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim * 4, cond_emb_dim * 8, kernel_size=3, stride=2, padding=1), + torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), torch.nn.ReLU(inplace=True), ) - # forawrdでなくset_controlに入れてもやはり動かない def forward(self, x): - cx = self.cond_stem(x) - cx = self.cond_block0(cx) - c0 = cx - cx = self.cond_block1(cx) - c1 = cx - cx = self.cond_block2(cx) - c2 = cx - cx = self.cond_block3(cx) - c3 = cx - return c0, c1, c2, c3 + x = self.cond_block0(x) + x0 = x + x = self.cond_block1(x) + x1 = x + x = self.cond_block2(x) + x2 = x - def set_control(self, cond_embs): + x_3d = [] + for x0 in [x0, x1, x2]: + # b,c,h,w -> b,h*w,c + n, c, h, w = x0.shape + x0 = x0.view(n, c, h * w).permute(0, 2, 1) + x_3d.append(x0) + + return [x0, x1, x2], x_3d + + def set_cond_embs(self, cond_embs_4d, cond_embs_3d): for lora in self.unet_loras: - lora.set_control(cond_embs[lora.depth - 1]) + lora.set_cond_embs(cond_embs_4d, cond_embs_3d) + + def set_batch_cond_uncond_enabled(self, enabled): + for lora in self.unet_loras: + lora.set_batch_cond_uncond_enabled(enabled) def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": @@ -228,18 +278,20 @@ class LoRAControlNet(torch.nn.Module): if __name__ == "__main__": + sdxl_original_unet.USE_REENTRANT = False + # test shape etc print("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() - unet.to("cuda") # , dtype=torch.float16) + unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 16, 32, 1) + control_net = LoRAControlNet(unet, 128, 32, 1) control_net.apply_to() control_net.to("cuda") - # print(controlnet) - # input() + print(control_net) + input() # print number of parameters print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) @@ -282,8 +334,9 @@ if __name__ == "__main__": y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() with torch.cuda.amp.autocast(enabled=True): - cond_embs = control_net(conditioning_image) - control_net.set_control(cond_embs) + cond_embs_4d, cond_embs_3d = control_net(conditioning_image) + control_net.set_cond_embs(cond_embs_4d, cond_embs_3d) + output = unet(x, t, ctx, y) target = torch.randn_like(output) loss = torch.nn.functional.mse_loss(output, target) diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py new file mode 100644 index 00000000..4820aa3f --- /dev/null +++ b/sdxl_gen_img_lora_ctrl_test.py @@ -0,0 +1,2602 @@ +# temporary test code for LoRA control net + +import itertools +import json +from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable +import glob +import importlib +import inspect +import time +import zipfile +from diffusers.utils import deprecate +from diffusers.configuration_utils import FrozenDict +import argparse +import math +import os +import random +import re + +import diffusers +import numpy as np +import torch +import torchvision +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + DDIMScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + # UNet2DConditionModel, + StableDiffusionPipeline, +) +from einops import rearrange +from tqdm import tqdm +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig +import PIL +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import library.model_util as model_util +import library.train_util as train_util +import library.sdxl_model_util as sdxl_model_util +import library.sdxl_train_util as sdxl_train_util +from networks.lora import LoRANetwork +from library.sdxl_original_unet import SdxlUNet2DConditionModel +from library.original_unet import FlashAttentionFunction +from networks.lora_control_net import LoRAControlNet + +# scheduler: +SCHEDULER_LINEAR_START = 0.00085 +SCHEDULER_LINEAR_END = 0.0120 +SCHEDULER_TIMESTEPS = 1000 +SCHEDLER_SCHEDULE = "scaled_linear" + +# その他の設定 +LATENT_CHANNELS = 4 +DOWNSAMPLING_FACTOR = 8 + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + print("Enable memory efficient attention for U-Net") + + # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い + unet.set_use_memory_efficient_attention(False, True) + elif xformers: + print("Enable xformers for U-Net") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + unet.set_use_memory_efficient_attention(True, False) + elif sdpa: + print("Enable SDPA for U-Net") + unet.set_use_memory_efficient_attention(False, False) + unet.set_use_sdpa(True) + + +# TODO common train_util.py +def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): + if mem_eff_attn: + replace_vae_attn_to_memory_efficient() + elif xformers: + # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? + vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + elif sdpa: + replace_vae_attn_to_sdpa() + + +def replace_vae_attn_to_memory_efficient(): + print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, hidden_states, **kwargs): + q_bucket_size = 512 + k_bucket_size = 1024 + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_flash_attn_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_flash_attn(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_flash_attn + + +def replace_vae_attn_to_xformers(): + print("VAE: Attention.forward has been replaced to xformers") + import xformers.ops + + def forward_xformers(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + query_proj = query_proj.contiguous() + key_proj = key_proj.contiguous() + value_proj = value_proj.contiguous() + out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + + out = rearrange(out, "b h n d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_xformers_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_xformers(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_xformers + + +def replace_vae_attn_to_sdpa(): + print("VAE: Attention.forward has been replaced to sdpa") + + def forward_sdpa(self, hidden_states, **kwargs): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.to_q(hidden_states) + key_proj = self.to_k(hidden_states) + value_proj = self.to_v(hidden_states) + + query_proj, key_proj, value_proj = map( + lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) + ) + + out = torch.nn.functional.scaled_dot_product_attention( + query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + out = rearrange(out, "b n h d -> b n (h d)") + + # compute next hidden_states + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + def forward_sdpa_0_14(self, hidden_states, **kwargs): + if not hasattr(self, "to_q"): + self.to_q = self.query + self.to_k = self.key + self.to_v = self.value + self.to_out = [self.proj_attn, torch.nn.Identity()] + self.heads = self.num_heads + return forward_sdpa(self, hidden_states, **kwargs) + + if diffusers.__version__ < "0.15.0": + diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 + else: + diffusers.models.attention_processor.Attention.forward = forward_sdpa + + +# endregion + +# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 +# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py +# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 + + +class PipelineLike: + def __init__( + self, + device, + vae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + tokenizers: List[CLIPTokenizer], + unet: SdxlUNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + clip_skip: int, + ): + super().__init__() + self.device = device + self.clip_skip = clip_skip + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + self.vae = vae + self.text_encoders = text_encoders + self.tokenizers = tokenizers + self.unet: SdxlUNet2DConditionModel = unet + self.scheduler = scheduler + self.safety_checker = None + + # Textual Inversion + self.token_replacements_list = [] + for _ in range(len(self.text_encoders)): + self.token_replacements_list.append({}) + + # ControlNet # not supported yet + self.control_nets: List[LoRAControlNet] = [] + self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない + + # Textual Inversion + def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): + self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids + + def set_enable_control_net(self, en: bool): + self.control_net_enabled = en + + def get_token_replacer(self, tokenizer): + tokenizer_index = self.tokenizers.index(tokenizer) + token_replacements = self.token_replacements_list[tokenizer_index] + + def replace_tokens(tokens): + # print("replace_tokens", tokens, "=>", token_replacements) + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + + new_tokens = [] + for token in tokens: + if token in token_replacements: + replacement = token_replacements[token] + new_tokens.extend(replacement) + else: + new_tokens.append(token) + return new_tokens + + return replace_tokens + + def set_control_nets(self, ctrl_nets): + self.control_nets = ctrl_nets + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + height: int = 1024, + width: int = 1024, + original_height: int = None, + original_width: int = None, + original_height_negative: int = None, + original_width_negative: int = None, + crop_top: int = 0, + crop_left: int = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_scale: float = None, + strength: float = 0.8, + # num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, + max_embeddings_multiples: Optional[int] = 3, + output_type: Optional[str] = "pil", + vae_batch_size: float = None, + return_latents: bool = False, + # return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + is_cancelled_callback: Optional[Callable[[], bool]] = None, + callback_steps: Optional[int] = 1, + img2img_noise=None, + clip_guide_images=None, + **kwargs, + ): + # TODO support secondary prompt + num_images_per_prompt = 1 # fixed because already prompt is repeated + + if isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + reginonal_network = " AND " in prompt[0] + + vae_batch_size = ( + batch_size + if vae_batch_size is None + else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) + ) + + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." + ) + + # get prompt text embeddings + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if not do_classifier_free_guidance and negative_scale is not None: + print(f"negative_scale is ignored if guidance scalle <= 1.0") + negative_scale = None + + # get unconditional embeddings for classifier free guidance + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + tes_text_embs = [] + tes_uncond_embs = [] + tes_real_uncond_embs = [] + # use last pool + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + token_replacer = self.get_token_replacer(tokenizer) + + text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( + tokenizer, + text_encoder, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_text_embs.append(text_embeddings) + tes_uncond_embs.append(uncond_embeddings) + + if negative_scale is not None: + _, real_uncond_embeddings, _ = get_weighted_text_embeddings( + token_replacer, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + token_replacer=token_replacer, + device=self.device, + **kwargs, + ) + tes_real_uncond_embs.append(real_uncond_embeddings) + + # concat text encoder outputs + text_embeddings = tes_text_embs[0] + uncond_embeddings = tes_uncond_embs[0] + for i in range(1, len(tes_text_embs)): + text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 + uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 + + if do_classifier_free_guidance: + if negative_scale is None: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + else: + text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) + + if self.control_nets: + # ControlNetのhintにguide imageを流用する + if isinstance(clip_guide_images, PIL.Image.Image): + clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) + + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) + + # create size embs + if original_height is None: + original_height = height + if original_width is None: + original_width = width + if original_height_negative is None: + original_height_negative = original_height + if original_width_negative is None: + original_width_negative = original_width + if crop_top is None: + crop_top = 0 + if crop_left is None: + crop_left = 0 + emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) + uc_emb1 = sdxl_train_util.get_timestep_embedding( + torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 + ) + emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) + emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) + c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) + + c_vector = torch.cat([text_pool, c_vector], dim=1) + uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) + + vector_embeddings = torch.cat([uc_vector, c_vector]) + + # set timesteps + self.scheduler.set_timesteps(num_inference_steps, self.device) + + latents_dtype = text_embeddings.dtype + init_latents_orig = None + mask = None + + if init_image is None: + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = ( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height // 8, + width // 8, + ) + + if latents is None: + if self.device.type == "mps": + # randn does not exist on mps + latents = torch.randn( + latents_shape, + generator=generator, + device="cpu", + dtype=latents_dtype, + ).to(self.device) + else: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + dtype=latents_dtype, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + timesteps = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + else: + # image to tensor + if isinstance(init_image, PIL.Image.Image): + init_image = [init_image] + if isinstance(init_image[0], PIL.Image.Image): + init_image = [preprocess_image(im) for im in init_image] + init_image = torch.cat(init_image) + if isinstance(init_image, list): + init_image = torch.stack(init_image) + + # mask image to tensor + if mask_image is not None: + if isinstance(mask_image, PIL.Image.Image): + mask_image = [mask_image] + if isinstance(mask_image[0], PIL.Image.Image): + mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint + + # encode the init image into latents and scale the latents + init_image = init_image.to(device=self.device, dtype=latents_dtype) + if init_image.size()[-2:] == (height // 8, width // 8): + init_latents = init_image + else: + if vae_batch_size >= batch_size: + init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + init_latents = [] + for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): + init_latent_dist = self.vae.encode( + (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( + self.vae.dtype + ) + ).latent_dist + init_latents.append(init_latent_dist.sample(generator=generator)) + init_latents = torch.cat(init_latents) + + init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents + + if len(init_latents) == 1: + init_latents = init_latents.repeat((batch_size, 1, 1, 1)) + init_latents_orig = init_latents + + # preprocess mask + if mask_image is not None: + mask = mask_image.to(device=self.device, dtype=latents_dtype) + if len(mask) == 1: + mask = mask.repeat((batch_size, 1, 1, 1)) + + # check sizes + if not mask.shape == init_latents.shape: + raise ValueError("The mask and init_image should be the same size!") + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + + # add noise to latents using the timesteps + latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) + + t_start = max(num_inference_steps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 + + if self.control_nets: + if self.control_net_enabled: + for control_net in self.control_nets: + with torch.no_grad(): + cond_embs_4d, cond_embs_3d = control_net(clip_guide_images) + control_net.set_cond_embs(cond_embs_4d, cond_embs_3d) + else: + for control_net in self.control_nets: + control_net.set_cond_image(None) + + for i, t in enumerate(tqdm(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # # disable control net if ratio is set + # if self.control_nets and self.control_net_enabled: + # pass # TODO + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + + # perform guidance + if do_classifier_free_guidance: + if negative_scale is None: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + else: + noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( + num_latent_input + ) # uncond is real uncond + noise_pred = ( + noise_pred_uncond + + guidance_scale * (noise_pred_text - noise_pred_uncond) + - negative_scale * (noise_pred_negative - noise_pred_uncond) + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if return_latents: + return (latents, False) + + latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents + if vae_batch_size >= batch_size: + image = self.vae.decode(latents.to(self.vae.dtype)).sample + else: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images = [] + for i in tqdm(range(0, batch_size, vae_batch_size)): + images.append( + self.vae.decode( + (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) + ).sample + ) + image = torch.cat(images) + + image = (image / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + # image = self.numpy_to_pil(image) + image = (image * 255).round().astype("uint8") + image = [Image.fromarray(im) for im in image] + + return image + + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + + +re_attention = re.compile( + r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", + re.X, +) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + # keep break as separate token + text = text.replace("BREAK", "\\BREAK\\") + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + if word.strip() == "BREAK": + # pad until next multiple of tokenizer's max token length + pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) + print(f"BREAK pad_len: {pad_len}") + for i in range(pad_len): + # v2のときEOSをつけるべきかどうかわからないぜ + # if i == 0: + # text_token.append(tokenizer.eos_token_id) + # else: + text_token.append(tokenizer.pad_token_id) + text_weight.append(1.0) + continue + + # tokenize and discard the starting and the ending token + token = tokenizer(word).input_ids[1:-1] + + token = token_replacer(token) # for Textual Inversion + + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + text_encoder: CLIPTextModel, + text_input: torch.Tensor, + chunk_length: int, + clip_skip: int, + eos: int, + pad: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + pool = None + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + if pad == eos: # v1 + text_input_chunk[:, -1] = text_input[0, -1] + else: # v2 + for j in range(len(text_input_chunk)): + if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある + text_input_chunk[j, -1] = eos + if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD + text_input_chunk[j, 1] = eos + + # -2 is same for Text Encoder 1 and 2 + enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) + text_embedding = enc_out["hidden_states"][-2] + if pool is None: + pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-2] + pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) + return text_embeddings, pool + + +def get_weighted_text_embeddings( + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModel, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 1, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, + clip_skip=None, + token_replacer=None, + device=None, + **kwargs, +): + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + # split the prompts with "AND". each prompt must have the same number of splits + new_prompts = [] + for p in prompt: + new_prompts.extend(p.split(" AND ")) + prompt = new_prompts + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) + else: + prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = tokenizer.bos_token_id + eos = tokenizer.eos_token_id + pad = tokenizer.pad_token_id + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) + + # get the embeddings + text_embeddings, text_pool = get_unweighted_text_embeddings( + text_encoder, + prompt_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) + if uncond_prompt is not None: + uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( + text_encoder, + uncond_tokens, + tokenizer.model_max_length, + clip_skip, + eos, + pad, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens + return text_embeddings, text_pool, None, None, prompt_tokens + + +def preprocess_image(image): + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h), resample=PIL.Image.LANCZOS) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def preprocess_mask(mask): + mask = mask.convert("L") + w, h = mask.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + +# regular expression for dynamic prompt: +# starts and ends with "{" and "}" +# contains at least one variant divided by "|" +# optional framgments divided by "$$" at start +# if the first fragment is "E" or "e", enumerate all variants +# if the second fragment is a number or two numbers, repeat the variants in the range +# if the third fragment is a string, use it as a separator + +RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") + + +def handle_dynamic_prompt_variants(prompt, repeat_count): + founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) + if not founds: + return [prompt] + + # make each replacement for each variant + enumerating = False + replacers = [] + for found in founds: + # if "e$$" is found, enumerate all variants + found_enumerating = found.group(2) is not None + enumerating = enumerating or found_enumerating + + separator = ", " if found.group(6) is None else found.group(6) + variants = found.group(7).split("|") + + # parse count range + count_range = found.group(4) + if count_range is None: + count_range = [1, 1] + else: + count_range = count_range.split("-") + if len(count_range) == 1: + count_range = [int(count_range[0]), int(count_range[0])] + elif len(count_range) == 2: + count_range = [int(count_range[0]), int(count_range[1])] + else: + print(f"invalid count range: {count_range}") + count_range = [1, 1] + if count_range[0] > count_range[1]: + count_range = [count_range[1], count_range[0]] + if count_range[0] < 0: + count_range[0] = 0 + if count_range[1] > len(variants): + count_range[1] = len(variants) + + if found_enumerating: + # make function to enumerate all combinations + def make_replacer_enum(vari, cr, sep): + def replacer(): + values = [] + for count in range(cr[0], cr[1] + 1): + for comb in itertools.combinations(vari, count): + values.append(sep.join(comb)) + return values + + return replacer + + replacers.append(make_replacer_enum(variants, count_range, separator)) + else: + # make function to choose random combinations + def make_replacer_single(vari, cr, sep): + def replacer(): + count = random.randint(cr[0], cr[1]) + comb = random.sample(vari, count) + return [sep.join(comb)] + + return replacer + + replacers.append(make_replacer_single(variants, count_range, separator)) + + # make each prompt + if not enumerating: + # if not enumerating, repeat the prompt, replace each variant randomly + prompts = [] + for _ in range(repeat_count): + current = prompt + for found, replacer in zip(founds, replacers): + current = current.replace(found.group(0), replacer()[0], 1) + prompts.append(current) + else: + # if enumerating, iterate all combinations for previous prompts + prompts = [prompt] + + for found, replacer in zip(founds, replacers): + if found.group(2) is not None: + # make all combinations for existing prompts + new_prompts = [] + for current in prompts: + replecements = replacer() + for replecement in replecements: + new_prompts.append(current.replace(found.group(0), replecement, 1)) + prompts = new_prompts + + for found, replacer in zip(founds, replacers): + # make random selection for existing prompts + if found.group(2) is None: + for i in range(len(prompts)): + prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) + + return prompts + + +# endregion + + +# def load_clip_l14_336(dtype): +# print(f"loading CLIP: {CLIP_ID_L14_336}") +# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) +# return text_encoder + + +class BatchDataBase(NamedTuple): + # バッチ分割が必要ないデータ + step: int + prompt: str + negative_prompt: str + seed: int + init_image: Any + mask_image: Any + clip_prompt: str + guide_image: Any + + +class BatchDataExt(NamedTuple): + # バッチ分割が必要なデータ + width: int + height: int + original_width: int + original_height: int + original_width_negative: int + original_height_negative: int + crop_left: int + crop_top: int + steps: int + scale: float + negative_scale: float + strength: float + network_muls: Tuple[float] + num_sub_prompts: int + + +class BatchData(NamedTuple): + return_latents: bool + base: BatchDataBase + ext: BatchDataExt + + +def main(args): + if args.fp16: + dtype = torch.float16 + elif args.bf16: + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + highres_fix = args.highres_fix_scale is not None + # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" + + # モデルを読み込む + if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う + files = glob.glob(args.ckpt) + if len(files) == 1: + args.ckpt = files[0] + + (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype + ) + + # xformers、Hypernetwork対応 + if not args.diffusers_xformers: + mem_eff = not (args.xformers or args.sdpa) + replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) + replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) + + # tokenizerを読み込む + print("loading tokenizer") + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # schedulerを用意する + sched_init_args = {} + has_steps_offset = True + has_clip_sample = True + scheduler_num_noises_per_step = 1 + + if args.sampler == "ddim": + scheduler_cls = DDIMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddim + elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある + scheduler_cls = DDPMScheduler + scheduler_module = diffusers.schedulers.scheduling_ddpm + elif args.sampler == "pndm": + scheduler_cls = PNDMScheduler + scheduler_module = diffusers.schedulers.scheduling_pndm + has_clip_sample = False + elif args.sampler == "lms" or args.sampler == "k_lms": + scheduler_cls = LMSDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_lms_discrete + has_clip_sample = False + elif args.sampler == "euler" or args.sampler == "k_euler": + scheduler_cls = EulerDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_discrete + has_clip_sample = False + elif args.sampler == "euler_a" or args.sampler == "k_euler_a": + scheduler_cls = EulerAncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete + has_clip_sample = False + elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": + scheduler_cls = DPMSolverMultistepScheduler + sched_init_args["algorithm_type"] = args.sampler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep + has_clip_sample = False + elif args.sampler == "dpmsingle": + scheduler_cls = DPMSolverSinglestepScheduler + scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep + has_clip_sample = False + has_steps_offset = False + elif args.sampler == "heun": + scheduler_cls = HeunDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_heun_discrete + has_clip_sample = False + elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": + scheduler_cls = KDPM2DiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete + has_clip_sample = False + elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": + scheduler_cls = KDPM2AncestralDiscreteScheduler + scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete + scheduler_num_noises_per_step = 2 + has_clip_sample = False + + # 警告を出さないようにする + if has_steps_offset: + sched_init_args["steps_offset"] = 1 + if has_clip_sample: + sched_init_args["clip_sample"] = False + + # samplerの乱数をあらかじめ指定するための処理 + + # replace randn + class NoiseManager: + def __init__(self): + self.sampler_noises = None + self.sampler_noise_index = 0 + + def reset_sampler_noises(self, noises): + self.sampler_noise_index = 0 + self.sampler_noises = noises + + def randn(self, shape, device=None, dtype=None, layout=None, generator=None): + # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) + if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): + noise = self.sampler_noises[self.sampler_noise_index] + if shape != noise.shape: + noise = None + else: + noise = None + + if noise == None: + print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") + noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) + + self.sampler_noise_index += 1 + return noise + + class TorchRandReplacer: + def __init__(self, noise_manager): + self.noise_manager = noise_manager + + def __getattr__(self, item): + if item == "randn": + return self.noise_manager.randn + if hasattr(torch, item): + return getattr(torch, item) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) + + noise_manager = NoiseManager() + if scheduler_module is not None: + scheduler_module.torch = TorchRandReplacer(noise_manager) + + scheduler = scheduler_cls( + num_train_timesteps=SCHEDULER_TIMESTEPS, + beta_start=SCHEDULER_LINEAR_START, + beta_end=SCHEDULER_LINEAR_END, + beta_schedule=SCHEDLER_SCHEDULE, + **sched_init_args, + ) + + # ↓以下は結局PipeでFalseに設定されるので意味がなかった + # # clip_sample=Trueにする + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: + # print("set clip_sample to True") + # scheduler.config.clip_sample = True + + # deviceを決定する + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + + # custom pipelineをコピったやつを生成する + if args.vae_slices: + from library.slicing_vae import SlicingAutoencoderKL + + sli_vae = SlicingAutoencoderKL( + act_fn="silu", + block_out_channels=(128, 256, 512, 512), + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + in_channels=3, + latent_channels=4, + layers_per_block=2, + norm_num_groups=32, + out_channels=3, + sample_size=512, + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + num_slices=args.vae_slices, + ) + sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする + vae = sli_vae + del sli_vae + + vae_dtype = dtype + if args.no_half_vae: + print("set vae_dtype to float32") + vae_dtype = torch.float32 + vae.to(vae_dtype).to(device) + + text_encoder1.to(dtype).to(device) + text_encoder2.to(dtype).to(device) + unet.to(dtype).to(device) + + # networkを組み込む + if args.network_module: + networks = [] + network_default_muls = [] + network_pre_calc = args.network_pre_calc + + for i, network_module in enumerate(args.network_module): + print("import network module:", network_module) + imported_module = importlib.import_module(network_module) + + network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] + network_default_muls.append(network_mul) + + net_kwargs = {} + if args.network_args and i < len(args.network_args): + network_args = args.network_args[i] + # TODO escape special chars + network_args = network_args.split(";") + for net_arg in network_args: + key, value = net_arg.split("=") + net_kwargs[key] = value + + if args.network_weights and i < len(args.network_weights): + network_weight = args.network_weights[i] + print("load network weights from:", network_weight) + + if model_util.is_safetensors(network_weight) and args.network_show_meta: + from safetensors.torch import safe_open + + with safe_open(network_weight, framework="pt") as f: + metadata = f.metadata() + if metadata is not None: + print(f"metadata for: {network_weight}: {metadata}") + + network, weights_sd = imported_module.create_network_from_weights( + network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs + ) + else: + raise ValueError("No weight. Weight is required.") + if network is None: + return + + mergeable = network.is_mergeable() + if args.network_merge and not mergeable: + print("network is not mergiable. ignore merge option.") + + if not args.network_merge or not mergeable: + network.apply_to([text_encoder1, text_encoder2], unet) + info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい + print(f"weights are loaded: {info}") + + if args.opt_channels_last: + network.to(memory_format=torch.channels_last) + network.to(dtype).to(device) + + if network_pre_calc: + print("backup original weights") + network.backup_weights() + + networks.append(network) + else: + network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) + + else: + networks = [] + + # upscalerの指定があれば取得する + upscaler = None + if args.highres_fix_upscaler: + print("import upscaler module:", args.highres_fix_upscaler) + imported_module = importlib.import_module(args.highres_fix_upscaler) + + us_kwargs = {} + if args.highres_fix_upscaler_args: + for net_arg in args.highres_fix_upscaler_args.split(";"): + key, value = net_arg.split("=") + us_kwargs[key] = value + + print("create upscaler") + upscaler = imported_module.create_upscaler(**us_kwargs) + upscaler.to(dtype).to(device) + + # ControlNetの処理 + control_nets: List[LoRAControlNet] = [] + if args.control_net_models: + for i, model_file in enumerate(args.control_net_models): + print(f"loading control net: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + + control_net = LoRAControlNet(unet, 128, 32, 1) # TODO load from weights + control_net.apply_to() + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_net.set_batch_cond_uncond_enabled(True) + control_nets.append(control_net) + + if args.opt_channels_last: + print(f"set optimizing: channels last") + text_encoder1.to(memory_format=torch.channels_last) + text_encoder2.to(memory_format=torch.channels_last) + vae.to(memory_format=torch.channels_last) + unet.to(memory_format=torch.channels_last) + if networks: + for network in networks: + network.to(memory_format=torch.channels_last) + + for cn in control_nets: + cn.to(memory_format=torch.channels_last) + + pipe = PipelineLike( + device, + vae, + [text_encoder1, text_encoder2], + [tokenizer1, tokenizer2], + unet, + scheduler, + args.clip_skip, + ) + pipe.set_control_nets(control_nets) + print("pipeline is ready.") + + if args.diffusers_xformers: + pipe.enable_xformers_memory_efficient_attention() + + # Textual Inversionを処理する + if args.textual_inversion_embeddings: + token_ids_embeds1 = [] + token_ids_embeds2 = [] + for embeds_file in args.textual_inversion_embeddings: + if model_util.is_safetensors(embeds_file): + from safetensors.torch import load_file + + data = load_file(embeds_file) + else: + data = torch.load(embeds_file, map_location="cpu") + + if "string_to_param" in data: + data = data["string_to_param"] + + embeds1 = data["clip_l"] # text encoder 1 + embeds2 = data["clip_g"] # text encoder 2 + + num_vectors_per_token = embeds1.size()[0] + token_string = os.path.splitext(os.path.basename(embeds_file))[0] + + token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] + + # add new word to tokenizer, count is num_vectors_per_token + num_added_tokens1 = tokenizer1.add_tokens(token_strings) + num_added_tokens2 = tokenizer2.add_tokens(token_strings) + assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( + f"tokenizer has same word to token string (filename): {embeds_file}" + + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" + ) + + token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) + token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) + print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") + assert ( + min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 + ), f"token ids1 is not ordered" + assert ( + min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 + ), f"token ids2 is not ordered" + assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" + assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" + + if num_vectors_per_token > 1: + pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... + pipe.add_token_replacement(1, token_ids2[0], token_ids2) + + token_ids_embeds1.append((token_ids1, embeds1)) + token_ids_embeds2.append((token_ids2, embeds2)) + + text_encoder1.resize_token_embeddings(len(tokenizer1)) + text_encoder2.resize_token_embeddings(len(tokenizer2)) + token_embeds1 = text_encoder1.get_input_embeddings().weight.data + token_embeds2 = text_encoder2.get_input_embeddings().weight.data + for token_ids, embeds in token_ids_embeds1: + for token_id, embed in zip(token_ids, embeds): + token_embeds1[token_id] = embed + for token_ids, embeds in token_ids_embeds2: + for token_id, embed in zip(token_ids, embeds): + token_embeds2[token_id] = embed + + # promptを取得する + if args.from_file is not None: + print(f"reading prompts from {args.from_file}") + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_list = f.read().splitlines() + prompt_list = [d for d in prompt_list if len(d.strip()) > 0] + elif args.prompt is not None: + prompt_list = [args.prompt] + else: + prompt_list = [] + + if args.interactive: + args.n_iter = 1 + + # img2imgの前処理、画像の読み込みなど + def load_images(path): + if os.path.isfile(path): + paths = [path] + else: + paths = ( + glob.glob(os.path.join(path, "*.png")) + + glob.glob(os.path.join(path, "*.jpg")) + + glob.glob(os.path.join(path, "*.jpeg")) + + glob.glob(os.path.join(path, "*.webp")) + ) + paths.sort() + + images = [] + for p in paths: + image = Image.open(p) + if image.mode != "RGB": + print(f"convert image to RGB from {image.mode}: {p}") + image = image.convert("RGB") + images.append(image) + + return images + + def resize_images(imgs, size): + resized = [] + for img in imgs: + r_img = img.resize(size, Image.Resampling.LANCZOS) + if hasattr(img, "filename"): # filename属性がない場合があるらしい + r_img.filename = img.filename + resized.append(r_img) + return resized + + if args.image_path is not None: + print(f"load image for img2img: {args.image_path}") + init_images = load_images(args.image_path) + assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" + print(f"loaded {len(init_images)} images for img2img") + else: + init_images = None + + if args.mask_path is not None: + print(f"load mask for inpainting: {args.mask_path}") + mask_images = load_images(args.mask_path) + assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" + print(f"loaded {len(mask_images)} mask images for inpainting") + else: + mask_images = None + + # promptがないとき、画像のPngInfoから取得する + if init_images is not None and len(prompt_list) == 0 and not args.interactive: + print("get prompts from images' meta data") + for img in init_images: + if "prompt" in img.text: + prompt = img.text["prompt"] + if "negative-prompt" in img.text: + prompt += " --n " + img.text["negative-prompt"] + prompt_list.append(prompt) + + # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) + l = [] + for im in init_images: + l.extend([im] * args.images_per_prompt) + init_images = l + + if mask_images is not None: + l = [] + for im in mask_images: + l.extend([im] * args.images_per_prompt) + mask_images = l + + # 画像サイズにオプション指定があるときはリサイズする + if args.W is not None and args.H is not None: + # highres fix を考慮に入れる + w, h = args.W, args.H + if highres_fix: + w = int(w * args.highres_fix_scale + 0.5) + h = int(h * args.highres_fix_scale + 0.5) + + if init_images is not None: + print(f"resize img2img source images to {w}*{h}") + init_images = resize_images(init_images, (w, h)) + if mask_images is not None: + print(f"resize img2img mask images to {w}*{h}") + mask_images = resize_images(mask_images, (w, h)) + + regional_network = False + if networks and mask_images: + # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 + regional_network = True + print("use mask as region") + + size = None + for i, network in enumerate(networks): + if i < 3: + np_mask = np.array(mask_images[0]) + np_mask = np_mask[:, :, i] + size = np_mask.shape + else: + np_mask = np.full(size, 255, dtype=np.uint8) + mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) + network.set_region(i, i == len(networks) - 1, mask) + mask_images = None + + prev_image = None # for VGG16 guided + if args.guide_image_path is not None: + print(f"load image for ControlNet guidance: {args.guide_image_path}") + guide_images = [] + for p in args.guide_image_path: + guide_images.extend(load_images(p)) + + print(f"loaded {len(guide_images)} guide images for guidance") + if len(guide_images) == 0: + print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") + guide_images = None + else: + guide_images = None + + # seed指定時はseedを決めておく + if args.seed is not None: + # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう + random.seed(args.seed) + predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] + if len(predefined_seeds) == 1: + predefined_seeds[0] = args.seed + else: + predefined_seeds = None + + # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) + if args.W is None: + args.W = 1024 + if args.H is None: + args.H = 1024 + + # 画像生成のループ + os.makedirs(args.outdir, exist_ok=True) + max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples + + for gen_iter in range(args.n_iter): + print(f"iteration {gen_iter+1}/{args.n_iter}") + iter_seed = random.randint(0, 0x7FFFFFFF) + + # バッチ処理の関数 + def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): + batch_size = len(batch) + + # highres_fixの処理 + if highres_fix and not highres_1st: + # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す + is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling + + print("process 1st stage") + batch_1st = [] + for _, base, ext in batch: + + def scale_and_round(x): + if x is None: + return None + return int(x * args.highres_fix_scale + 0.5) + + width_1st = scale_and_round(ext.width) + height_1st = scale_and_round(ext.height) + width_1st = width_1st - width_1st % 32 + height_1st = height_1st - height_1st % 32 + + original_width_1st = scale_and_round(ext.original_width) + original_height_1st = scale_and_round(ext.original_height) + original_width_negative_1st = scale_and_round(ext.original_width_negative) + original_height_negative_1st = scale_and_round(ext.original_height_negative) + crop_left_1st = scale_and_round(ext.crop_left) + crop_top_1st = scale_and_round(ext.crop_top) + + strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength + + ext_1st = BatchDataExt( + width_1st, + height_1st, + original_width_1st, + original_height_1st, + original_width_negative_1st, + original_height_negative_1st, + crop_left_1st, + crop_top_1st, + args.highres_fix_steps, + ext.scale, + ext.negative_scale, + strength_1st, + ext.network_muls, + ext.num_sub_prompts, + ) + batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) + + pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする + images_1st = process_batch(batch_1st, True, True) + + # 2nd stageのバッチを作成して以下処理する + print("process 2nd stage") + width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height + + if upscaler: + # upscalerを使って画像を拡大する + lowreso_imgs = None if is_1st_latent else images_1st + lowreso_latents = None if not is_1st_latent else images_1st + + # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents + batch_size = len(images_1st) + vae_batch_size = ( + batch_size + if args.vae_batch_size is None + else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) + ) + vae_batch_size = int(vae_batch_size) + images_1st = upscaler.upscale( + vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size + ) + + elif args.highres_fix_latents_upscaling: + # latentを拡大する + org_dtype = images_1st.dtype + if images_1st.dtype == torch.bfloat16: + images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない + images_1st = torch.nn.functional.interpolate( + images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" + ) # , antialias=True) + images_1st = images_1st.to(org_dtype) + + else: + # 画像をLANCZOSで拡大する + images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] + + batch_2nd = [] + for i, (bd, image) in enumerate(zip(batch, images_1st)): + bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) + batch_2nd.append(bd_2nd) + batch = batch_2nd + + if args.highres_fix_disable_control_net: + pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする + + # このバッチの情報を取り出す + ( + return_latents, + (step_first, _, _, _, init_image, mask_image, _, guide_image), + ( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + network_muls, + num_sub_prompts, + ), + ) = batch[0] + noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) + + prompts = [] + negative_prompts = [] + start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + noises = [ + torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + for _ in range(steps * scheduler_num_noises_per_step) + ] + seeds = [] + clip_prompts = [] + + if init_image is not None: # img2img? + i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) + init_images = [] + + if mask_image is not None: + mask_images = [] + else: + mask_images = None + else: + i2i_noises = None + init_images = None + mask_images = None + + if guide_image is not None: # CLIP image guided? + guide_images = [] + else: + guide_images = None + + # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする + all_images_are_same = True + all_masks_are_same = True + all_guide_images_are_same = True + for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): + prompts.append(prompt) + negative_prompts.append(negative_prompt) + seeds.append(seed) + clip_prompts.append(clip_prompt) + + if init_image is not None: + init_images.append(init_image) + if i > 0 and all_images_are_same: + all_images_are_same = init_images[-2] is init_image + + if mask_image is not None: + mask_images.append(mask_image) + if i > 0 and all_masks_are_same: + all_masks_are_same = mask_images[-2] is mask_image + + if guide_image is not None: + if type(guide_image) is list: + guide_images.extend(guide_image) + all_guide_images_are_same = False + else: + guide_images.append(guide_image) + if i > 0 and all_guide_images_are_same: + all_guide_images_are_same = guide_images[-2] is guide_image + + # make start code + torch.manual_seed(seed) + start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + # make each noises + for j in range(steps * scheduler_num_noises_per_step): + noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) + + if i2i_noises is not None: # img2img noise + i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) + + noise_manager.reset_sampler_noises(noises) + + # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する + if init_images is not None and all_images_are_same: + init_images = init_images[0] + if mask_images is not None and all_masks_are_same: + mask_images = mask_images[0] + if guide_images is not None and all_guide_images_are_same: + guide_images = guide_images[0] + + # ControlNet使用時はguide imageをリサイズする + if control_nets: + # TODO resampleのメソッド + guide_images = guide_images if type(guide_images) == list else [guide_images] + guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] + if len(guide_images) == 1: + guide_images = guide_images[0] + + # generate + if networks: + # 追加ネットワークの処理 + shared = {} + for n, m in zip(networks, network_muls if network_muls else network_default_muls): + n.set_multiplier(m) + if regional_network: + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + + if not regional_network and network_pre_calc: + for n in networks: + n.restore_weights() + for n in networks: + n.pre_calculation() + print("pre-calculation... done") + + images = pipe( + prompts, + negative_prompts, + init_images, + mask_images, + height, + width, + original_height, + original_width, + original_height_negative, + original_width_negative, + crop_top, + crop_left, + steps, + scale, + negative_scale, + strength, + latents=start_code, + output_type="pil", + max_embeddings_multiples=max_embeddings_multiples, + img2img_noise=i2i_noises, + vae_batch_size=args.vae_batch_size, + return_latents=return_latents, + clip_prompts=clip_prompts, + clip_guide_images=guide_images, + ) + if highres_1st and not args.highres_fix_save_1st: # return images or latents + return images + + # save image + highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts) + ): + if highres_fix: + seed -= 1 # record original seed + metadata = PngInfo() + metadata.add_text("prompt", prompt) + metadata.add_text("seed", str(seed)) + metadata.add_text("sampler", args.sampler) + metadata.add_text("steps", str(steps)) + metadata.add_text("scale", str(scale)) + if negative_prompt is not None: + metadata.add_text("negative-prompt", negative_prompt) + if negative_scale is not None: + metadata.add_text("negative-scale", str(negative_scale)) + if clip_prompt is not None: + metadata.add_text("clip-prompt", clip_prompt) + metadata.add_text("original-height", str(original_height)) + metadata.add_text("original-width", str(original_width)) + metadata.add_text("original-height-negative", str(original_height_negative)) + metadata.add_text("original-width-negative", str(original_width_negative)) + metadata.add_text("crop-top", str(crop_top)) + metadata.add_text("crop-left", str(crop_left)) + + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + + if not args.no_preview and not highres_1st and args.interactive: + try: + import cv2 + + for prompt, image in zip(prompts, images): + cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ + cv2.waitKey() + cv2.destroyAllWindows() + except ImportError: + print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") + + return images + + # 画像生成のプロンプトが一周するまでのループ + prompt_index = 0 + global_step = 0 + batch_data = [] + while args.interactive or prompt_index < len(prompt_list): + if len(prompt_list) == 0: + # interactive + valid = False + while not valid: + print("\nType prompt:") + try: + raw_prompt = input() + except EOFError: + break + + valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 + if not valid: # EOF, end app + break + else: + raw_prompt = prompt_list[prompt_index] + + # sd-dynamic-prompts like variants: + # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) + raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) + + # repeat prompt + for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): + raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + + if pi == 0 or len(raw_prompts) > 1: + # parse prompt: if prompt is not changed, skip parsing + width = args.W + height = args.H + original_width = args.original_width + original_height = args.original_height + original_width_negative = args.original_width_negative + original_height_negative = args.original_height_negative + crop_top = args.crop_top + crop_left = args.crop_left + scale = args.scale + negative_scale = args.negative_scale + steps = args.steps + seed = None + seeds = None + strength = 0.8 if args.strength is None else args.strength + negative_prompt = "" + clip_prompt = None + network_muls = None + + prompt_args = raw_prompt.strip().split(" --") + prompt = prompt_args[0] + print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + + for parg in prompt_args[1:]: + try: + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + print(f"width: {width}") + continue + + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + print(f"height: {height}") + continue + + m = re.match(r"ow (\d+)", parg, re.IGNORECASE) + if m: + original_width = int(m.group(1)) + print(f"original width: {original_width}") + continue + + m = re.match(r"oh (\d+)", parg, re.IGNORECASE) + if m: + original_height = int(m.group(1)) + print(f"original height: {original_height}") + continue + + m = re.match(r"nw (\d+)", parg, re.IGNORECASE) + if m: + original_width_negative = int(m.group(1)) + print(f"original width negative: {original_width_negative}") + continue + + m = re.match(r"nh (\d+)", parg, re.IGNORECASE) + if m: + original_height_negative = int(m.group(1)) + print(f"original height negative: {original_height_negative}") + continue + + m = re.match(r"ct (\d+)", parg, re.IGNORECASE) + if m: + crop_top = int(m.group(1)) + print(f"crop top: {crop_top}") + continue + + m = re.match(r"cl (\d+)", parg, re.IGNORECASE) + if m: + crop_left = int(m.group(1)) + print(f"crop left: {crop_left}") + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + steps = max(1, min(1000, int(m.group(1)))) + print(f"steps: {steps}") + continue + + m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) + if m: # seed + seeds = [int(d) for d in m.group(1).split(",")] + print(f"seeds: {seeds}") + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + print(f"scale: {scale}") + continue + + m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) + if m: # negative scale + if m.group(1).lower() == "none": + negative_scale = None + else: + negative_scale = float(m.group(1)) + print(f"negative scale: {negative_scale}") + continue + + m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) + if m: # strength + strength = float(m.group(1)) + print(f"strength: {strength}") + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + print(f"negative prompt: {negative_prompt}") + continue + + m = re.match(r"c (.+)", parg, re.IGNORECASE) + if m: # clip prompt + clip_prompt = m.group(1) + print(f"clip prompt: {clip_prompt}") + continue + + m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) + if m: # network multiplies + network_muls = [float(v) for v in m.group(1).split(",")] + while len(network_muls) < len(networks): + network_muls.append(network_muls[-1]) + print(f"network mul: {network_muls}") + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + # prepare seed + if seeds is not None: # given in prompt + # 数が足りないなら前のをそのまま使う + if len(seeds) > 0: + seed = seeds.pop(0) + else: + if predefined_seeds is not None: + if len(predefined_seeds) > 0: + seed = predefined_seeds.pop(0) + else: + print("predefined seeds are exhausted") + seed = None + elif args.iter_same_seed: + seeds = iter_seed + else: + seed = None # 前のを消す + + if seed is None: + seed = random.randint(0, 0x7FFFFFFF) + if args.interactive: + print(f"seed: {seed}") + + # prepare init image, guide image and mask + init_image = mask_image = guide_image = None + + # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する + if init_images is not None: + init_image = init_images[global_step % len(init_images)] + + # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する + # 32単位に丸めたやつにresizeされるので踏襲する + if not highres_fix: + width, height = init_image.size + width = width - width % 32 + height = height - height % 32 + if width != init_image.size[0] or height != init_image.size[1]: + print( + f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" + ) + + if mask_images is not None: + mask_image = mask_images[global_step % len(mask_images)] + + if guide_images is not None: + if control_nets: # 複数件の場合あり + c = len(control_nets) + p = global_step % (len(guide_images) // c) + guide_image = guide_images[p * c : p * c + c] + else: + guide_image = guide_images[global_step % len(guide_images)] + + if regional_network: + num_sub_prompts = len(prompt.split(" AND ")) + assert ( + len(networks) <= num_sub_prompts + ), "Number of networks must be less than or equal to number of sub prompts." + else: + num_sub_prompts = None + + b1 = BatchData( + False, + BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), + BatchDataExt( + width, + height, + original_width, + original_height, + original_width_negative, + original_height_negative, + crop_left, + crop_top, + steps, + scale, + negative_scale, + strength, + tuple(network_muls) if network_muls else None, + num_sub_prompts, + ), + ) + if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? + process_batch(batch_data, highres_fix) + batch_data.clear() + + batch_data.append(b1) + if len(batch_data) == args.batch_size: + prev_image = process_batch(batch_data, highres_fix)[0] + batch_data.clear() + + global_step += 1 + + prompt_index += 1 + + if len(batch_data) > 0: + process_batch(batch_data, highres_fix) + batch_data.clear() + + print("done!") + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") + parser.add_argument( + "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" + ) + parser.add_argument( + "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" + ) + parser.add_argument( + "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" + ) + parser.add_argument( + "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" + ) + parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") + parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") + parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") + parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") + parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") + parser.add_argument( + "--use_original_file_name", + action="store_true", + help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", + ) + # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) + parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") + parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") + parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") + parser.add_argument( + "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" + ) + parser.add_argument( + "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" + ) + parser.add_argument( + "--original_height_negative", + type=int, + default=None, + help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", + ) + parser.add_argument( + "--original_width_negative", + type=int, + default=None, + help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", + ) + parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") + parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") + parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") + parser.add_argument( + "--vae_batch_size", + type=float, + default=None, + help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", + ) + parser.add_argument( + "--vae_slices", + type=int, + default=None, + help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", + ) + parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") + parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") + parser.add_argument( + "--sampler", + type=str, + default="ddim", + choices=[ + "ddim", + "pndm", + "lms", + "euler", + "euler_a", + "heun", + "dpm_2", + "dpm_2_a", + "dpmsolver", + "dpmsolver++", + "dpmsingle", + "k_lms", + "k_euler", + "k_euler_a", + "k_dpm_2", + "k_dpm_2_a", + ], + help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", + ) + parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") + parser.add_argument( + "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" + ) + parser.add_argument( + "--tokenizer_cache_dir", + type=str, + default=None, + help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", + ) + # parser.add_argument("--replace_clip_l14_336", action='store_true', + # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") + parser.add_argument( + "--seed", + type=int, + default=None, + help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", + ) + parser.add_argument( + "--iter_same_seed", + action="store_true", + help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", + ) + parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") + parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") + parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") + parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") + parser.add_argument( + "--diffusers_xformers", + action="store_true", + help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", + ) + parser.add_argument( + "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" + ) + parser.add_argument( + "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" + ) + parser.add_argument( + "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" + ) + parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") + parser.add_argument( + "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" + ) + parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") + parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") + parser.add_argument( + "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" + ) + parser.add_argument( + "--textual_inversion_embeddings", + type=str, + default=None, + nargs="*", + help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", + ) + parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") + parser.add_argument( + "--max_embeddings_multiples", + type=int, + default=None, + help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", + ) + parser.add_argument( + "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" + ) + parser.add_argument( + "--highres_fix_scale", + type=float, + default=None, + help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", + ) + parser.add_argument( + "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" + ) + parser.add_argument( + "--highres_fix_strength", + type=float, + default=None, + help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", + ) + parser.add_argument( + "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" + ) + parser.add_argument( + "--highres_fix_latents_upscaling", + action="store_true", + help="use latents upscaling for highres fix / highres fixでlatentで拡大する", + ) + parser.add_argument( + "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" + ) + parser.add_argument( + "--highres_fix_upscaler_args", + type=str, + default=None, + help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", + ) + parser.add_argument( + "--highres_fix_disable_control_net", + action="store_true", + help="disable ControlNet for highres fix / highres fixでControlNetを使わない", + ) + + parser.add_argument( + "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" + ) + + parser.add_argument( + "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + ) + parser.add_argument( + "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + ) + parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") + parser.add_argument( + "--control_net_ratios", + type=float, + default=None, + nargs="*", + help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + ) + # parser.add_argument( + # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" + # ) + + return parser + + +if __name__ == "__main__": + parser = setup_parser() + + args = parser.parse_args() + main(args) diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_lora_control_net.py new file mode 100644 index 00000000..489c7936 --- /dev/null +++ b/sdxl_train_lora_control_net.py @@ -0,0 +1,823 @@ +import argparse +import gc +import json +import math +import os +import random +import time +from multiprocessing import Value +from types import SimpleNamespace +import toml + +from tqdm import tqdm +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from accelerate.utils import set_seed +from diffusers import DDPMScheduler, ControlNetModel +from safetensors.torch import load_file +from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util + +import library.model_util as model_util +import library.train_util as train_util +import library.config_util as config_util +from library.config_util import ( + ConfigSanitizer, + BlueprintGenerator, +) +import library.huggingface_util as huggingface_util +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import ( + add_v_prediction_like_loss, + apply_snr_weight, + prepare_scheduler_for_custom_training, + pyramid_noise_like, + apply_noise_offset, + scale_v_prediction_loss_like_noise_prediction, +) +import networks.lora_control_net as lora_control_net + + +# TODO 他のスクリプトと共通化する +def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): + logs = { + "loss/current": current_loss, + "loss/average": avr_loss, + "lr": lr_scheduler.get_last_lr()[0], + } + + if args.optimizer_type.lower().startswith("DAdapt".lower()): + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + + return logs + + +def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + sdxl_train_util.verify_sdxl_training_args(args) + + cache_latents = args.cache_latents + use_user_config = args.dataset_config is not None + + if args.seed is None: + args.seed = random.randint(0, 2**32) + set_seed(args.seed) + + tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) + + # データセットを準備する + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True)) + if use_user_config: + print(f"Load dataset config from {args.dataset_config}") + user_config = config_util.load_user_config(args.dataset_config) + ignored = ["train_data_dir", "conditioning_data_dir"] + if any(getattr(args, attr) is not None for attr in ignored): + print( + "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format( + ", ".join(ignored) + ) + ) + else: + user_config = { + "datasets": [ + { + "subsets": config_util.generate_controlnet_subsets_config_by_subdirs( + args.train_data_dir, + args.conditioning_data_dir, + args.caption_extension, + ) + } + ] + } + + blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2]) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) + + train_dataset_group.verify_bucket_reso_steps(32) + + if args.debug_dataset: + train_util.debug_dataset(train_dataset_group) + return + if len(train_dataset_group) == 0: + print( + "No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)" + ) + return + + if cache_latents: + assert ( + train_dataset_group.is_latent_cacheable() + ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + + if args.cache_text_encoder_outputs: + assert ( + train_dataset_group.is_text_encoder_output_cacheable() + ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません" + + # acceleratorを準備する + print("prepare accelerator") + accelerator = train_util.prepare_accelerator(args) + is_main_process = accelerator.is_main_process + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + vae_dtype = torch.float32 if args.no_half_vae else weight_dtype + + # モデルを読み込む + ( + load_stable_diffusion_format, + text_encoder1, + text_encoder2, + vae, + unet, + logit_scale, + ckpt_info, + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) + + # モデルに xformers とか memory efficient attention を組み込む + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=vae_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset_group.cache_latents( + vae, + args.vae_batch_size, + args.cache_latents_to_disk, + accelerator.is_main_process, + ) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + accelerator.wait_for_everyone() + + # TextEncoderの出力をキャッシュする + if args.cache_text_encoder_outputs: + # Text Encodes are eval and no grad + with torch.no_grad(): + train_dataset_group.cache_text_encoder_outputs( + (tokenizer1, tokenizer2), + (text_encoder1, text_encoder2), + accelerator.device, + None, + args.cache_text_encoder_outputs_to_disk, + accelerator.is_main_process, + ) + accelerator.wait_for_everyone() + + # prepare ControlNet + network = lora_control_net.LoRAControlNet(unet, args.cond_emb_dim, args.network_dim, 1, args.network_dropout) + network.apply_to() + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() # may have no effect + + # 学習に必要なクラスを準備する + accelerator.print("prepare optimizer, data loader etc.") + + trainable_params = list(network.prepare_optimizer_params()) + print(f"trainable params count: {len(trainable_params)}") + print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}") + + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + + train_dataloader = torch.utils.data.DataLoader( + train_dataset_group, + batch_size=1, + shuffle=True, + collate_fn=collater, + num_workers=n_workers, + persistent_workers=args.persistent_data_loader_workers, + ) + + # 学習ステップ数を計算する + if args.max_train_epochs is not None: + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + + # lr schedulerを用意する + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする + if args.full_fp16: + assert ( + args.mixed_precision == "fp16" + ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + accelerator.print("enable full fp16 training.") + unet.to(weight_dtype) + network.to(weight_dtype) + elif args.full_bf16: + assert ( + args.mixed_precision == "bf16" + ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。" + accelerator.print("enable full bf16 training.") + unet.to(weight_dtype) + network.to(weight_dtype) + + # acceleratorがなんかよろしくやってくれるらしい + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) + network: lora_control_net.LoRAControlNet + + # transform DDP after prepare (train_network here only) + unet, network = train_util.transform_models_if_DDP([unet, network]) + + if args.gradient_checkpointing: + unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる + else: + unet.eval() + + network.prepare_grad_etc() + + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=vae_dtype) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 + + # 学習する + # TODO: find a way to handle total batch size when there are multiple datasets + accelerator.print("running training / 学習開始") + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") + accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") + accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + accelerator.print(f" num epochs / epoch数: {num_train_epochs}") + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") + # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") + + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "lora_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/lora-control-net" + + unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata) + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + network.on_epoch_start() # train() + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(network): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample() + + # NaNが含まれていれば警告を表示し0に置き換える + if torch.any(torch.isnan(latents)): + accelerator.print("NaN found in latents, replacing with zeros") + latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents) + latents = latents * sdxl_model_util.VAE_SCALE_FACTOR + + if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: + input_ids1 = batch["input_ids"] + input_ids2 = batch["input_ids2"] + with torch.no_grad(): + # Get the text embedding for conditioning + input_ids1 = input_ids1.to(accelerator.device) + input_ids2 = input_ids2.to(accelerator.device) + encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( + args.max_token_length, + input_ids1, + input_ids2, + tokenizer1, + tokenizer2, + text_encoder1, + text_encoder2, + None if not args.full_fp16 else weight_dtype, + ) + else: + encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) + encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype) + pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype) + + # get size embeddings + orig_size = batch["original_sizes_hw"] + crop_size = batch["crop_top_lefts"] + target_size = batch["target_sizes_hw"] + embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype) + + # concat embeddings + vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) + text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + + # Sample noise, sample a random timestep for each image, and add noise to the latents, + # with noise offset and/or multires noise if specified + noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + cond_embs_4d, cond_embs_3d = network(controlnet_image) + network.set_cond_embs(cond_embs_4d, cond_embs_3d) + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = network.get_trainable_params() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + # sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + # self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + + # end of epoch + + if is_main_process: + network = accelerator.unwrap_model(network) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True) + + print("model saved.") + + r""" + progress_bar = tqdm( + range(args.max_train_steps), + smoothing=0, + disable=not accelerator.is_local_main_process, + desc="steps", + ) + global_step = 0 + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + if accelerator.is_main_process: + init_kwargs = {} + if args.log_tracker_config is not None: + init_kwargs = toml.load(args.log_tracker_config) + accelerator.init_trackers( + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + ) + + loss_list = [] + loss_total = 0.0 + del train_dataset_group + + # function for saving/removing + def save_model(ckpt_name, model, force_sync_upload=False): + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + + state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) + + if save_dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + if os.path.splitext(ckpt_file)[1] == ".safetensors": + from safetensors.torch import save_file + + save_file(state_dict, ckpt_file) + else: + torch.save(state_dict, ckpt_file) + + if args.huggingface_repo_id is not None: + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) + + def remove_model(old_ckpt_name): + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + accelerator.print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + # training loop + for epoch in range(num_train_epochs): + if is_main_process: + accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") + current_epoch.value = epoch + 1 + + for step, batch in enumerate(train_dataloader): + current_step.value = global_step + with accelerator.accumulate(controlnet): + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + b_size = latents.shape[0] + + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + if args.noise_offset: + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) + elif args.multires_noise_iterations: + noise = pyramid_noise_like( + noise, + latents.device, + args.multires_noise_iterations, + args.multires_noise_discount, + ) + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (b_size,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) + + with accelerator.autocast(): + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + noise_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states, + down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + loss = loss.mean([1, 2, 3]) + + loss_weights = batch["loss_weights"] # 各sampleごとのweight + loss = loss * loss_weights + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + + loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし + + accelerator.backward(loss) + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + train_util.sample_images( + accelerator, + args, + None, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # 指定ステップごとにモデルを保存 + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) + save_model( + ckpt_name, + accelerator.unwrap_model(controlnet), + ) + + if args.save_state: + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) + + remove_step_no = train_util.get_remove_step_no(args, global_step) + if remove_step_no is not None: + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) + remove_model(remove_ckpt_name) + + current_loss = loss.detach().item() + if epoch == 0: + loss_list.append(current_loss) + else: + loss_total -= loss_list[step] + loss_list[step] = current_loss + loss_total += current_loss + avr_loss = loss_total / len(loss_list) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if args.logging_dir is not None: + logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"loss/epoch": loss_total / len(loss_list)} + accelerator.log(logs, step=epoch + 1) + + accelerator.wait_for_everyone() + + # 指定エポックごとにモデルを保存 + if args.save_every_n_epochs is not None: + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs + if is_main_process and saving: + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(controlnet)) + + remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) + if remove_epoch_no is not None: + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) + remove_model(remove_ckpt_name) + + if args.save_state: + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) + + train_util.sample_images( + accelerator, + args, + epoch + 1, + global_step, + accelerator.device, + vae, + tokenizer, + text_encoder, + unet, + controlnet=controlnet, + ) + + # end of epoch + if is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + + accelerator.end_training() + + if is_main_process and args.save_state: + train_util.save_state_on_train_end(args, accelerator) + + # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく + + if is_main_process: + ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) + save_model(ckpt_name, controlnet, force_sync_upload=True) + + print("model saved.") + """ + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True, True) + train_util.add_training_arguments(parser, False) + train_util.add_optimizer_arguments(parser) + config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) + sdxl_train_util.add_sdxl_training_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数") + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") + parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数") + parser.add_argument( + "--network_dropout", + type=float, + default=None, + help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)", + ) + parser.add_argument( + "--conditioning_data_dir", + type=str, + default=None, + help="conditioning data directory / 条件付けデータのディレクトリ", + ) + parser.add_argument( + "--no_half_vae", + action="store_true", + help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", + ) + return parser + + +if __name__ == "__main__": + # sdxl_original_unet.USE_REENTRANT = False + + parser = setup_parser() + + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + + train(args) From 306ee24c90a9f49b4b535a6e735ebe4ef930ca3a Mon Sep 17 00:00:00 2001 From: ykume Date: Thu, 17 Aug 2023 10:19:14 +0900 Subject: [PATCH 03/19] change to use_reentrant=False --- networks/lora_control_net.py | 4 ++-- sdxl_train_lora_control_net.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 0189c632..a65cc1fd 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,7 +5,7 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet -SKIP_OUTPUT_BLOCKS = False +SKIP_OUTPUT_BLOCKS = True SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored ATTN1_ETC_ONLY = True @@ -286,7 +286,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 128, 32, 1) + control_net = LoRAControlNet(unet, 256, 64, 1) control_net.apply_to() control_net.to("cuda") diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_lora_control_net.py index 489c7936..92469add 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_lora_control_net.py @@ -813,7 +813,7 @@ def setup_parser() -> argparse.ArgumentParser: if __name__ == "__main__": - # sdxl_original_unet.USE_REENTRANT = False + sdxl_original_unet.USE_REENTRANT = False parser = setup_parser() From afc03af3ca0ff67cbcd7991654329ee1c311b301 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 17 Aug 2023 12:10:52 +0900 Subject: [PATCH 04/19] read dim/rank from weights --- sdxl_gen_img_lora_ctrl_test.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py index 4820aa3f..e8b22ee1 100644 --- a/sdxl_gen_img_lora_ctrl_test.py +++ b/sdxl_gen_img_lora_ctrl_test.py @@ -1556,8 +1556,18 @@ def main(args): from safetensors.torch import load_file state_dict = load_file(model_file) + lora_rank = None + emb_dim = None + for key, value in state_dict.items(): + if lora_rank is None and "lora_down.weight" in key: + lora_rank = value.shape[0] + elif emb_dim is None and "conditioning1.0" in key: + emb_dim = value.shape[0] + if lora_rank is not None and emb_dim is not None: + break + assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}" - control_net = LoRAControlNet(unet, 128, 32, 1) # TODO load from weights + control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1) control_net.apply_to() control_net.load_state_dict(state_dict) control_net.to(dtype).to(device) From 6111151f509269fec0a6ba9baf2f5c7f7f63f6a1 Mon Sep 17 00:00:00 2001 From: ykume Date: Thu, 17 Aug 2023 13:17:43 +0900 Subject: [PATCH 05/19] add skip input blocks to lora control net --- networks/lora_control_net.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index a65cc1fd..7a026eba 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,7 +5,8 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet -SKIP_OUTPUT_BLOCKS = True +SKIP_INPUT_BLOCKS = True +SKIP_OUTPUT_BLOCKS = False SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored ATTN1_ETC_ONLY = True @@ -123,6 +124,8 @@ class LoRAControlNet(torch.nn.Module): block_name, index1, index2 = (name + "." + child_name).split(".")[:3] index1 = int(index1) if block_name == "input_blocks": + if SKIP_INPUT_BLOCKS: + continue depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3) elif block_name == "middle_block": depth = 3 From 5fa473d5f3974fd4b1ad80f586e07b5c97634e94 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 17 Aug 2023 16:25:23 +0900 Subject: [PATCH 06/19] add cond/uncond, update config --- networks/lora_control_net.py | 35 ++++++++++++++++++++++++---------- sdxl_gen_img_lora_ctrl_test.py | 2 +- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 7a026eba..9259389f 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,11 +5,12 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet -SKIP_INPUT_BLOCKS = True -SKIP_OUTPUT_BLOCKS = False +SKIP_INPUT_BLOCKS = False +SKIP_OUTPUT_BLOCKS = True SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored ATTN1_ETC_ONLY = True +TRANSFORMER_MAX_BLOCK_INDEX = 3 # None # 2 # None for all blocks class LoRAModuleControlNet(LoRAModule): @@ -48,15 +49,17 @@ class LoRAModuleControlNet(LoRAModule): self.depth = depth self.cond_emb = None - self.batch_cond_uncond_enabled = False + self.batch_cond_only = False + self.use_zeros_for_batch_uncond = False def set_cond_embs(self, cond_embs_4d, cond_embs_3d): cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d cond_emb = cond_embs[self.depth - 1] self.cond_emb = self.conditioning1(cond_emb) - def set_batch_cond_uncond_enabled(self, enabled): - self.batch_cond_uncond_enabled = enabled + def set_batch_cond_only(self, cond_only, zeros): + self.batch_cond_only = cond_only + self.use_zeros_for_batch_uncond = zeros def forward(self, x): if self.cond_emb is None: @@ -64,7 +67,7 @@ class LoRAModuleControlNet(LoRAModule): # LoRA lx = x - if self.batch_cond_uncond_enabled: + if self.batch_cond_only: lx = lx[1::2] # cond only lx = self.lora_down(lx) @@ -75,6 +78,10 @@ class LoRAModuleControlNet(LoRAModule): # conditioning image cx = self.cond_emb # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + if not self.batch_cond_only and cx.shape[0] // 2 == lx.shape[0]: # inference only + cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) + if self.use_zeros_for_batch_uncond: + cx[0::2] = 0.0 # uncond is zero cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) cx = self.conditioning2(cx) @@ -84,7 +91,7 @@ class LoRAModuleControlNet(LoRAModule): x = self.org_forward(x) - if self.batch_cond_uncond_enabled: + if self.batch_cond_only: x[1::2] += lx * self.multiplier * self.scale else: x += lx * self.multiplier * self.scale @@ -141,6 +148,13 @@ class LoRAControlNet(torch.nn.Module): lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") + if TRANSFORMER_MAX_BLOCK_INDEX is not None: + p = lora_name.find("transformer_blocks") + if p >= 0: + tf_index = int(lora_name[p:].split("_")[2]) + if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: + continue + # skip time emb or clip emb if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): continue @@ -215,9 +229,9 @@ class LoRAControlNet(torch.nn.Module): for lora in self.unet_loras: lora.set_cond_embs(cond_embs_4d, cond_embs_3d) - def set_batch_cond_uncond_enabled(self, enabled): + def set_batch_cond_only(self, cond_only, zeros): for lora in self.unet_loras: - lora.set_batch_cond_uncond_enabled(enabled) + lora.set_batch_cond_only(cond_only, zeros) def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": @@ -294,11 +308,12 @@ if __name__ == "__main__": control_net.to("cuda") print(control_net) - input() # print number of parameters print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad)) + input() + unet.set_use_memory_efficient_attention(True, False) unet.set_gradient_checkpointing(True) unet.train() # for gradient checkpointing diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py index e8b22ee1..b1cab576 100644 --- a/sdxl_gen_img_lora_ctrl_test.py +++ b/sdxl_gen_img_lora_ctrl_test.py @@ -1571,7 +1571,7 @@ def main(args): control_net.apply_to() control_net.load_state_dict(state_dict) control_net.to(dtype).to(device) - control_net.set_batch_cond_uncond_enabled(True) + control_net.set_batch_cond_only(False, False) control_nets.append(control_net) if args.opt_channels_last: From 809fca0be9783dc9887595f72a116da6bd2ca601 Mon Sep 17 00:00:00 2001 From: ykume Date: Thu, 17 Aug 2023 18:31:29 +0900 Subject: [PATCH 07/19] fix error in generation --- networks/lora_control_net.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 9259389f..11b4db90 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -9,8 +9,8 @@ SKIP_INPUT_BLOCKS = False SKIP_OUTPUT_BLOCKS = True SKIP_CONV2D = False TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored -ATTN1_ETC_ONLY = True -TRANSFORMER_MAX_BLOCK_INDEX = 3 # None # 2 # None for all blocks +ATTN1_ETC_ONLY = False # True +TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks class LoRAModuleControlNet(LoRAModule): @@ -77,11 +77,11 @@ class LoRAModuleControlNet(LoRAModule): # conditioning image cx = self.cond_emb - # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") - if not self.batch_cond_only and cx.shape[0] // 2 == lx.shape[0]: # inference only + if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero + # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) cx = self.conditioning2(cx) @@ -303,7 +303,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 256, 64, 1) + control_net = LoRAControlNet(unet, 128, 64, 1) control_net.apply_to() control_net.to("cuda") From 1e52fe6e09ea1aec95716b9aaa2e6837eaa08213 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 17 Aug 2023 20:49:39 +0900 Subject: [PATCH 08/19] add comments --- networks/lora_control_net.py | 95 ++++++++++-- sdxl_train_lora_control_net.py | 261 +-------------------------------- 2 files changed, 85 insertions(+), 271 deletions(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 11b4db90..0dd2a0a1 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,12 +5,25 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet +# input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False + +# output_blocksに適用するかどうか / if True, output_blocks are not applied SKIP_OUTPUT_BLOCKS = True + +# conv2dに適用するかどうか / if True, conv2d are not applied SKIP_CONV2D = False -TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored -ATTN1_ETC_ONLY = False # True -TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks + +# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない +# if True, only transformer_blocks are applied, and ResBlocks are not applied +TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks + +# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 +ATTN1_ETC_ONLY = False # True + +# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 +# max index of transformer_blocks. if None, apply to all transformer_blocks +TRANSFORMER_MAX_BLOCK_INDEX = None class LoRAModuleControlNet(LoRAModule): @@ -19,6 +32,16 @@ class LoRAModuleControlNet(LoRAModule): self.is_conv2d = org_module.__class__.__name__ == "Conv2d" self.cond_emb_dim = cond_emb_dim + # conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない + # それぞれのモジュールで異なる表現を学習することを期待している + # conditioning1 learns conditioning image embedding in each LoRA-like module. this is not called for each timestep + # we expect to learn different representations in each module + + # conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる + # conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う + # conditioning2 combines the output of conditioning1 and the output of LoRA, and adds it to the output of LoRA. this is called for each timestep + # by learning the output of conditioning image embedding and U-Net together, we adjust U-Net according to conditioning + if self.is_conv2d: self.conditioning1 = torch.nn.Sequential( torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), @@ -45,16 +68,26 @@ class LoRAModuleControlNet(LoRAModule): torch.nn.Linear(cond_emb_dim, lora_dim), torch.nn.ReLU(inplace=True), ) + + # Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv # torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv - self.depth = depth + self.depth = depth # 1~3 self.cond_emb = None - self.batch_cond_only = False - self.use_zeros_for_batch_uncond = False + self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference + self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0 def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ + # conv2dとlinearでshapeが違うので必要な方を選択 / select the required one because the shape is different for conv2d and linear cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d + cond_emb = cond_embs[self.depth - 1] + + # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance self.cond_emb = self.conditioning1(cond_emb) def set_batch_cond_only(self, cond_only, zeros): @@ -65,32 +98,39 @@ class LoRAModuleControlNet(LoRAModule): if self.cond_emb is None: return self.org_forward(x) - # LoRA + # LoRA-Down lx = x if self.batch_cond_only: - lx = lx[1::2] # cond only + lx = lx[1::2] # cond only in inference lx = self.lora_down(lx) if self.dropout is not None and self.training: lx = torch.nn.functional.dropout(lx, p=self.dropout) - # conditioning image + # conditioning image embeddingを結合 / combine conditioning image embedding cx = self.cond_emb + if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している + # we expect that it will mix well by combining in the channel direction instead of adding cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) cx = self.conditioning2(cx) - lx = lx + cx + lx = lx + cx # lxはresidual的に加算される / lx is added residually + + # LoRA-Up lx = self.lora_up(lx) + # call original module x = self.org_forward(x) + # add LoRA if self.batch_cond_only: x[1::2] += lx * self.multiplier * self.scale else: @@ -127,6 +167,7 @@ class LoRAControlNet(torch.nn.Module): is_conv2d = child_module.__class__.__name__ == "Conv2d" if is_linear or (is_conv2d and not SKIP_CONV2D): + # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う # block index to depth: depth is using to calculate conditioning size and channels block_name, index1, index2 = (name + "." + child_name).split(".")[:3] index1 = int(index1) @@ -155,7 +196,10 @@ class LoRAControlNet(torch.nn.Module): if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: continue - # skip time emb or clip emb + # time embは適用外とする + # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない + # time emb is not applied + # attn2 conditioning (input from CLIP) cannot be applied because the shape is different if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): continue @@ -191,8 +235,22 @@ class LoRAControlNet(torch.nn.Module): print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.") # conditioning image embedding + + # control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、 + # 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ + # ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず + # また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する + # depthに応じて3つのサイズを用意する + + # conditioning image embedding is converted to an appropriate latent space + # because the size and channels of the input to the LoRA-like module are difficult to handle + # we call it conditioning image embedding + # however, the control image itself does not have much information, so the conditioning image embedding should be small + # conditioning image embedding is also learned individually in each LoRA-like module + # prepare three sizes according to depth + self.cond_block0 = torch.nn.Sequential( - torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent size + torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent (from VAE) size torch.nn.ReLU(inplace=True), torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1), torch.nn.ReLU(inplace=True), @@ -216,7 +274,7 @@ class LoRAControlNet(torch.nn.Module): x = self.cond_block2(x) x2 = x - x_3d = [] + x_3d = [] # for Linear for x0 in [x0, x1, x2]: # b,c,h,w -> b,h*w,c n, c, h, w = x0.shape @@ -226,6 +284,10 @@ class LoRAControlNet(torch.nn.Module): return [x0, x1, x2], x_3d def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ for lora in self.unet_loras: lora.set_cond_embs(cond_embs_4d, cond_embs_3d) @@ -295,6 +357,9 @@ class LoRAControlNet(torch.nn.Module): if __name__ == "__main__": + # デバッグ用 / for debug + + # これを指定しないとエラーが出てcond_blockが学習できない / if not specified, an error occurs and cond_block cannot be learned sdxl_original_unet.USE_REENTRANT = False # test shape etc @@ -303,7 +368,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 128, 64, 1) + control_net = LoRAControlNet(unet, 64, 16, 1) control_net.apply_to() control_net.to("cuda") @@ -329,7 +394,7 @@ if __name__ == "__main__": # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) # print("render") # image.format = "svg" # "png" - # image.render("NeuralNet") + # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() import bitsandbytes diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_lora_control_net.py index 92469add..e0ec3a6a 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_lora_control_net.py @@ -401,8 +401,13 @@ def train(args): controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) with accelerator.autocast(): + # conditioning image embeddingを計算する / calculate conditioning image embedding cond_embs_4d, cond_embs_3d = network(controlnet_image) + + # 個別のLoRA的モジュールでさらにembeddingを計算する / calculate embedding in each LoRA-like module network.set_cond_embs(cond_embs_4d, cond_embs_3d) + + # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) if args.v_parameterization: @@ -514,262 +519,6 @@ def train(args): print("model saved.") - r""" - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - if accelerator.is_main_process: - init_kwargs = {} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_list = [] - loss_total = 0.0 - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value = global_step - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like( - noise, - latents.device, - args.multires_noise_iterations, - args.multires_noise_discount, - ) - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (b_size,), - device=latents.device, - ) - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states, - down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model( - ckpt_name, - accelerator.unwrap_model(controlnet), - ) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # end of epoch - if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model(ckpt_name, controlnet, force_sync_upload=True) - - print("model saved.") - """ - def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() From 3e1591661e2a7180eb33e5bba07c2e6c7e6442a2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 17 Aug 2023 22:02:07 +0900 Subject: [PATCH 09/19] add readme about controlnet-lora --- docs/train_lll_README-ja.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 docs/train_lll_README-ja.md diff --git a/docs/train_lll_README-ja.md b/docs/train_lll_README-ja.md new file mode 100644 index 00000000..33ee0ca8 --- /dev/null +++ b/docs/train_lll_README-ja.md @@ -0,0 +1,29 @@ +# ConrtolNet-LLLite について + +## 概要 +ConrtolNet-LLLite は、[ConrtolNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAに似た構造の軽量なControlNetです。現在はSDXLにのみ対応しています。 + +## モデル構造 +制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、U-Netの各モジュールに付与されるLoRA的な構造を持つモジュールを組み合わせたモデルです。詳しくはソースコードを参照してください。 + +## モデルの学習 + +### データセットの準備 +通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。 + +```toml +[[datasets.subsets]] +image_dir = "path/to/image/dir" +caption_extension = ".txt" +conditioning_data_dir = "path/to/conditioning/image/dir" +``` + +### 学習 +`sdxl_train_lora_control_net.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じます。 + +### 推論 +`sdxl_gen_img_lora_ctrl_test.py` を実行してください。`--control_net_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 + +`--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 + +その他のオプションは`sdxl_gen_img.py`に準じます。 \ No newline at end of file From b5db90c8a848203f028e2b0d2c50a4d3f4dfd882 Mon Sep 17 00:00:00 2001 From: ykume Date: Fri, 18 Aug 2023 09:00:22 +0900 Subject: [PATCH 10/19] modify to attn1/attn2 only --- networks/lora_control_net.py | 10 +++++++++- sdxl_train_lora_control_net.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 0dd2a0a1..120ab0ac 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -18,7 +18,11 @@ SKIP_CONV2D = False # if True, only transformer_blocks are applied, and ResBlocks are not applied TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks +# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. +ATTN1_2_ONLY = True + # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 +# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY ATTN1_ETC_ONLY = False # True # transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 @@ -203,6 +207,10 @@ class LoRAControlNet(torch.nn.Module): if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): continue + if ATTN1_2_ONLY: + if not ("attn1" in lora_name or "attn2" in lora_name): + continue + if ATTN1_ETC_ONLY: if "proj_out" in lora_name: pass @@ -368,7 +376,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 64, 16, 1) + control_net = LoRAControlNet(unet, 64, 32, 1) control_net.apply_to() control_net.to("cuda") diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_lora_control_net.py index e0ec3a6a..b6fb1dec 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_lora_control_net.py @@ -325,7 +325,7 @@ def train(args): accelerator.print(f"\nsaving checkpoint: {ckpt_file}") sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/lora-control-net" + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-llite" unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata) if args.huggingface_repo_id is not None: From fef7eb73ad77566759fc74276ffb9fc8d1cafeec Mon Sep 17 00:00:00 2001 From: ykume Date: Sat, 19 Aug 2023 18:44:40 +0900 Subject: [PATCH 11/19] rename and update --- ...a_control_net.py => control_net_lllite.py} | 302 +- sdxl_gen_img.py | 145 +- sdxl_gen_img_lora_ctrl_test.py | 2612 ----------------- ...net.py => sdxl_train_control_net_lllite.py | 18 +- 4 files changed, 253 insertions(+), 2824 deletions(-) rename networks/{lora_control_net.py => control_net_lllite.py} (54%) delete mode 100644 sdxl_gen_img_lora_ctrl_test.py rename sdxl_train_lora_control_net.py => sdxl_train_control_net_lllite.py (97%) diff --git a/networks/lora_control_net.py b/networks/control_net_lllite.py similarity index 54% rename from networks/lora_control_net.py rename to networks/control_net_lllite.py index 120ab0ac..e0157080 100644 --- a/networks/lora_control_net.py +++ b/networks/control_net_lllite.py @@ -1,7 +1,6 @@ import os from typing import Optional, List, Type import torch -from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet @@ -21,6 +20,9 @@ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. ATTN1_2_ONLY = True +# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified +ATTN_QKV_ONLY = True + # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY ATTN1_ETC_ONLY = False # True @@ -30,126 +32,159 @@ ATTN1_ETC_ONLY = False # True TRANSFORMER_MAX_BLOCK_INDEX = None -class LoRAModuleControlNet(LoRAModule): - def __init__(self, depth, cond_emb_dim, name, org_module, multiplier, lora_dim, alpha, dropout=None): - super().__init__(name, org_module, multiplier, lora_dim, alpha, dropout=dropout) +class LLLiteModule(torch.nn.Module): + def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None): + super().__init__() + self.is_conv2d = org_module.__class__.__name__ == "Conv2d" + self.lllite_name = name self.cond_emb_dim = cond_emb_dim - - # conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない - # それぞれのモジュールで異なる表現を学習することを期待している - # conditioning1 learns conditioning image embedding in each LoRA-like module. this is not called for each timestep - # we expect to learn different representations in each module - - # conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる - # conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う - # conditioning2 combines the output of conditioning1 and the output of LoRA, and adds it to the output of LoRA. this is called for each timestep - # by learning the output of conditioning image embedding and U-Net together, we adjust U-Net according to conditioning + self.org_module = [org_module] + self.dropout = dropout if self.is_conv2d: - self.conditioning1 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), + in_dim = org_module.in_channels + else: + in_dim = org_module.in_features + + # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない + # conditioning1 embeds conditioning image. it is not called for each timestep + modules = [] + modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size + if depth == 1: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + elif depth == 2: + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0)) + elif depth == 3: + # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4 + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) + modules.append(torch.nn.ReLU(inplace=True)) + modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0)) + + self.conditioning1 = torch.nn.Sequential(*modules) + + # downで入力の次元数を削減する。LoRAにヒントを得ていることにする + # midでconditioning image embeddingと入力を結合する + # upで元の次元数に戻す + # これらはtimestepごとに呼ばれる + # reduce the number of input dimensions with down. inspired by LoRA + # combine conditioning image embedding and input with mid + # restore to the original dimension with up + # these are called for each timestep + + if self.is_conv2d: + self.down = torch.nn.Sequential( + torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0), torch.nn.ReLU(inplace=True), ) - self.conditioning2 = torch.nn.Sequential( - torch.nn.Conv2d(lora_dim + cond_emb_dim, cond_emb_dim, kernel_size=1, stride=1, padding=0), - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim, lora_dim, kernel_size=1, stride=1, padding=0), + self.mid = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0), torch.nn.ReLU(inplace=True), ) + self.up = torch.nn.Sequential( + torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0), + ) else: - self.conditioning1 = torch.nn.Sequential( - torch.nn.Linear(cond_emb_dim, cond_emb_dim), - torch.nn.ReLU(inplace=True), - torch.nn.Linear(cond_emb_dim, cond_emb_dim), + # midの前にconditioningをreshapeすること / reshape conditioning before mid + self.down = torch.nn.Sequential( + torch.nn.Linear(in_dim, mlp_dim), torch.nn.ReLU(inplace=True), ) - self.conditioning2 = torch.nn.Sequential( - torch.nn.Linear(lora_dim + cond_emb_dim, cond_emb_dim), - torch.nn.ReLU(inplace=True), - torch.nn.Linear(cond_emb_dim, lora_dim), + self.mid = torch.nn.Sequential( + torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim), torch.nn.ReLU(inplace=True), ) + self.up = torch.nn.Sequential( + torch.nn.Linear(mlp_dim, in_dim), + ) - # Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv - # torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv + # Zero-Convにする / set to Zero-Conv + torch.nn.init.zeros_(self.up[0].weight) # zero conv self.depth = depth # 1~3 self.cond_emb = None self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0 - def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない + # Controlの種類によっては使えるかも + # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice + # it may be available depending on the type of Control + + def set_cond_image(self, cond_image): r""" 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む / call the model inside, so if necessary, surround it with torch.no_grad() """ - # conv2dとlinearでshapeが違うので必要な方を選択 / select the required one because the shape is different for conv2d and linear - cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d - - cond_emb = cond_embs[self.depth - 1] - # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance - self.cond_emb = self.conditioning1(cond_emb) + cx = self.conditioning1(cond_image) + if not self.is_conv2d: + # reshape / b,c,h,w -> b,h*w,c + n, c, h, w = cx.shape + cx = cx.view(n, c, h * w).permute(0, 2, 1) + self.cond_emb = cx def set_batch_cond_only(self, cond_only, zeros): self.batch_cond_only = cond_only self.use_zeros_for_batch_uncond = zeros + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward + def forward(self, x): - if self.cond_emb is None: - return self.org_forward(x) - - # LoRA-Down - lx = x - if self.batch_cond_only: - lx = lx[1::2] # cond only in inference - - lx = self.lora_down(lx) - - if self.dropout is not None and self.training: - lx = torch.nn.functional.dropout(lx, p=self.dropout) - - # conditioning image embeddingを結合 / combine conditioning image embedding + r""" + 学習用の便利forward。元のモジュールのforwardを呼び出す + / convenient forward for training. call the forward of the original module + """ cx = self.cond_emb - if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only + if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + # print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している + # down reduces the number of input dimensions and combines it with conditioning image embedding # we expect that it will mix well by combining in the channel direction instead of adding - cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) - cx = self.conditioning2(cx) - lx = lx + cx # lxはresidual的に加算される / lx is added residually + cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2) + cx = self.mid(cx) - # LoRA-Up - lx = self.lora_up(lx) + if self.dropout is not None and self.training: + cx = torch.nn.functional.dropout(cx, p=self.dropout) - # call original module - x = self.org_forward(x) + cx = self.up(cx) - # add LoRA + # residualを加算する / add residual if self.batch_cond_only: - x[1::2] += lx * self.multiplier * self.scale + x[1::2] += cx else: - x += lx * self.multiplier * self.scale + # to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone + # RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace. + # This view was created inside a custom Function ... + # x = x.clone() + x += cx + + x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here return x -class LoRAControlNet(torch.nn.Module): +class ControlNetLLLite(torch.nn.Module): + UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + def __init__( self, unet: sdxl_original_unet.SdxlUNet2DConditionModel, cond_emb_dim: int = 16, - lora_dim: int = 16, - alpha: float = 1, + mlp_dim: int = 16, dropout: Optional[float] = None, varbose: Optional[bool] = False, ) -> None: @@ -161,9 +196,9 @@ class LoRAControlNet(torch.nn.Module): target_replace_modules: List[torch.nn.Module], module_class: Type[object], ) -> List[torch.nn.Module]: - prefix = LoRANetwork.LORA_PREFIX_UNET + prefix = "lllite_unet" - loras = [] + modules = [] for name, module in root_module.named_modules(): if module.__class__.__name__ in target_replace_modules: for child_name, child_module in module.named_modules(): @@ -190,13 +225,13 @@ class LoRAControlNet(torch.nn.Module): else: raise NotImplementedError() - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") + lllite_name = prefix + "." + name + "." + child_name + lllite_name = lllite_name.replace(".", "_") if TRANSFORMER_MAX_BLOCK_INDEX is not None: - p = lora_name.find("transformer_blocks") + p = lllite_name.find("transformer_blocks") if p >= 0: - tf_index = int(lora_name[p:].split("_")[2]) + tf_index = int(lllite_name[p:].split("_")[2]) if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: continue @@ -204,104 +239,63 @@ class LoRAControlNet(torch.nn.Module): # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない # time emb is not applied # attn2 conditioning (input from CLIP) cannot be applied because the shape is different - if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): + if "emb_layers" in lllite_name or ( + "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name) + ): continue if ATTN1_2_ONLY: - if not ("attn1" in lora_name or "attn2" in lora_name): + if not ("attn1" in lllite_name or "attn2" in lllite_name): continue + if ATTN_QKV_ONLY: + if "to_out" in lllite_name: + continue if ATTN1_ETC_ONLY: - if "proj_out" in lora_name: + if "proj_out" in lllite_name: pass - elif "attn1" in lora_name and ("to_k" in lora_name or "to_v" in lora_name or "to_out" in lora_name): + elif "attn1" in lllite_name and ( + "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name + ): pass - elif "ff_net_2" in lora_name: + elif "ff_net_2" in lllite_name: pass else: continue - lora = module_class( + module = module_class( depth, cond_emb_dim, - lora_name, + lllite_name, child_module, - 1.0, - lora_dim, - alpha, + mlp_dim, dropout=dropout, ) - loras.append(lora) - return loras + modules.append(module) + return modules - target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE if not TRANSFORMER_ONLY: - target_modules = target_modules + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 + target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 # create module instances - self.unet_loras: List[LoRAModuleControlNet] = create_modules(unet, target_modules, LoRAModuleControlNet) - print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.") - - # conditioning image embedding - - # control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、 - # 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ - # ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず - # また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する - # depthに応じて3つのサイズを用意する - - # conditioning image embedding is converted to an appropriate latent space - # because the size and channels of the input to the LoRA-like module are difficult to handle - # we call it conditioning image embedding - # however, the control image itself does not have much information, so the conditioning image embedding should be small - # conditioning image embedding is also learned individually in each LoRA-like module - # prepare three sizes according to depth - - self.cond_block0 = torch.nn.Sequential( - torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent (from VAE) size - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1), - torch.nn.ReLU(inplace=True), - ) - self.cond_block1 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=1), - torch.nn.ReLU(inplace=True), - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), - torch.nn.ReLU(inplace=True), - ) - self.cond_block2 = torch.nn.Sequential( - torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=2, padding=1), - torch.nn.ReLU(inplace=True), - ) + self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule) + print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.") def forward(self, x): - x = self.cond_block0(x) - x0 = x - x = self.cond_block1(x) - x1 = x - x = self.cond_block2(x) - x2 = x + return x # dummy - x_3d = [] # for Linear - for x0 in [x0, x1, x2]: - # b,c,h,w -> b,h*w,c - n, c, h, w = x0.shape - x0 = x0.view(n, c, h * w).permute(0, 2, 1) - x_3d.append(x0) - - return [x0, x1, x2], x_3d - - def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + def set_cond_image(self, cond_image): r""" 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む / call the model inside, so if necessary, surround it with torch.no_grad() """ - for lora in self.unet_loras: - lora.set_cond_embs(cond_embs_4d, cond_embs_3d) + for module in self.unet_modules: + module.set_cond_image(cond_image) def set_batch_cond_only(self, cond_only, zeros): - for lora in self.unet_loras: - lora.set_batch_cond_only(cond_only, zeros) + for module in self.unet_modules: + module.set_batch_cond_only(cond_only, zeros) def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": @@ -315,10 +309,10 @@ class LoRAControlNet(torch.nn.Module): return info def apply_to(self): - print("applying LoRA for U-Net...") - for lora in self.unet_loras: - lora.apply_to() - self.add_module(lora.lora_name, lora) + print("applying LLLite for U-Net...") + for module in self.unet_modules: + module.apply_to() + self.add_module(module.lllite_name, module) # マージできるかどうかを返す def is_mergeable(self): @@ -367,16 +361,15 @@ class LoRAControlNet(torch.nn.Module): if __name__ == "__main__": # デバッグ用 / for debug - # これを指定しないとエラーが出てcond_blockが学習できない / if not specified, an error occurs and cond_block cannot be learned - sdxl_original_unet.USE_REENTRANT = False + # sdxl_original_unet.USE_REENTRANT = False # test shape etc print("create unet") unet = sdxl_original_unet.SdxlUNet2DConditionModel() unet.to("cuda").to(torch.float16) - print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 64, 32, 1) + print("create ControlNet-LLLite") + control_net = ControlNetLLLite(unet, 32, 64) control_net.apply_to() control_net.to("cuda") @@ -414,6 +407,7 @@ if __name__ == "__main__": print("start training") steps = 10 + sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0] for step in range(steps): print(f"step {step}") @@ -425,8 +419,7 @@ if __name__ == "__main__": y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda() with torch.cuda.amp.autocast(enabled=True): - cond_embs_4d, cond_embs_3d = control_net(conditioning_image) - control_net.set_cond_embs(cond_embs_4d, cond_embs_3d) + control_net.set_cond_image(conditioning_image) output = unet(x, t, ctx, y) target = torch.randn_like(output) @@ -436,3 +429,8 @@ if __name__ == "__main__": scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) + print(sample_param) + + # from safetensors.torch import save_file + + # save_file(control_net.state_dict(), "logs/control_net.safetensors") diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 69c0bd1d..cb16a781 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -47,10 +47,9 @@ import library.train_util as train_util import library.sdxl_model_util as sdxl_model_util import library.sdxl_train_util as sdxl_train_util from networks.lora import LoRANetwork -import tools.original_control_net as original_control_net -from tools.original_control_net import ControlNetInfo from library.sdxl_original_unet import SdxlUNet2DConditionModel from library.original_unet import FlashAttentionFunction +from networks.control_net_lllite import ControlNetLLLite # scheduler: SCHEDULER_LINEAR_START = 0.00085 @@ -327,7 +326,7 @@ class PipelineLike: self.token_replacements_list.append({}) # ControlNet # not supported yet - self.control_nets: List[ControlNetInfo] = [] + self.control_nets: List[ControlNetLLLite] = [] self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない # Textual Inversion @@ -392,6 +391,7 @@ class PipelineLike: is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: Optional[int] = 1, img2img_noise=None, + clip_guide_images=None, **kwargs, ): # TODO support secondary prompt @@ -496,11 +496,16 @@ class PipelineLike: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) if self.control_nets: + # ControlNetのhintにguide imageを流用する if isinstance(clip_guide_images, PIL.Image.Image): clip_guide_images = [clip_guide_images] + if isinstance(clip_guide_images[0], PIL.Image.Image): + clip_guide_images = [preprocess_image(im) for im in clip_guide_images] + clip_guide_images = torch.cat(clip_guide_images) + if isinstance(clip_guide_images, list): + clip_guide_images = torch.stack(clip_guide_images) - # ControlNetのhintにguide imageを流用する - # 前処理はControlNet側で行う + clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) # create size embs if original_height is None: @@ -654,35 +659,47 @@ class PipelineLike: num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 if self.control_nets: - guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images) + if self.control_net_enabled: + for control_net in self.control_nets: + with torch.no_grad(): + control_net.set_cond_image(clip_guide_images) + else: + for control_net in self.control_nets: + control_net.set_cond_image(None) for i, t in enumerate(tqdm(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # predict the noise residual - if self.control_nets and self.control_net_enabled: - if reginonal_network: - num_sub_and_neg_prompts = len(text_embeddings) // batch_size - text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt - else: - text_emb_last = text_embeddings + # # disable control net if ratio is set + # if self.control_nets and self.control_net_enabled: + # pass # TODO - # not working yet - noise_pred = original_control_net.call_unet_and_control_net( - i, - num_latent_input, - self.unet, - self.control_nets, - guided_hints, - i / len(timesteps), - latent_model_input, - t, - text_emb_last, - ).sample - else: - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) + # predict the noise residual + # TODO Diffusers' ControlNet + # if self.control_nets and self.control_net_enabled: + # if reginonal_network: + # num_sub_and_neg_prompts = len(text_embeddings) // batch_size + # text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt + # else: + # text_emb_last = text_embeddings + + # # not working yet + # noise_pred = original_control_net.call_unet_and_control_net( + # i, + # num_latent_input, + # self.unet, + # self.control_nets, + # guided_hints, + # i / len(timesteps), + # latent_model_input, + # t, + # text_emb_last, + # ).sample + # else: + noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) # perform guidance if do_classifier_free_guidance: @@ -1550,16 +1567,40 @@ def main(args): upscaler.to(dtype).to(device) # ControlNetの処理 - control_nets: List[ControlNetInfo] = [] - if args.control_net_models: - for i, model in enumerate(args.control_net_models): - prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] - weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] - ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] + control_nets: List[ControlNetLLLite] = [] + # if args.control_net_models: + # for i, model in enumerate(args.control_net_models): + # prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i] + # weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i] + # ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i] - ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) - prep = original_control_net.load_preprocess(prep_type) - control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + # ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model) + # prep = original_control_net.load_preprocess(prep_type) + # control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio)) + if args.control_net_lllite_models: + for i, model_file in enumerate(args.control_net_lllite_models): + print(f"loading ControlNet-LLLite: {model_file}") + + from safetensors.torch import load_file + + state_dict = load_file(model_file) + mlp_dim = None + cond_emb_dim = None + for key, value in state_dict.items(): + if mlp_dim is None and "down.0.weight" in key: + mlp_dim = value.shape[0] + elif cond_emb_dim is None and "conditioning1.0" in key: + cond_emb_dim = value.shape[0] * 2 + if mlp_dim is not None and cond_emb_dim is not None: + break + assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}" + + control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim) + control_net.apply_to() + control_net.load_state_dict(state_dict) + control_net.to(dtype).to(device) + control_net.set_batch_cond_only(False, False) + control_nets.append(control_net) if args.opt_channels_last: print(f"set optimizing: channels last") @@ -1572,8 +1613,9 @@ def main(args): network.to(memory_format=torch.channels_last) for cn in control_nets: - cn.unet.to(memory_format=torch.channels_last) - cn.net.to(memory_format=torch.channels_last) + cn.to(memory_format=torch.channels_last) + # cn.unet.to(memory_format=torch.channels_last) + # cn.net.to(memory_format=torch.channels_last) pipe = PipelineLike( device, @@ -2573,20 +2615,23 @@ def setup_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - ) - parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" - ) - parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" ) # parser.add_argument( + # "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" + # ) + # parser.add_argument( + # "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" + # ) + # parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率") + # parser.add_argument( + # "--control_net_ratios", + # type=float, + # default=None, + # nargs="*", + # help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", + # ) + # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # ) diff --git a/sdxl_gen_img_lora_ctrl_test.py b/sdxl_gen_img_lora_ctrl_test.py deleted file mode 100644 index b1cab576..00000000 --- a/sdxl_gen_img_lora_ctrl_test.py +++ /dev/null @@ -1,2612 +0,0 @@ -# temporary test code for LoRA control net - -import itertools -import json -from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable -import glob -import importlib -import inspect -import time -import zipfile -from diffusers.utils import deprecate -from diffusers.configuration_utils import FrozenDict -import argparse -import math -import os -import random -import re - -import diffusers -import numpy as np -import torch -import torchvision -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DPMSolverSinglestepScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDIMScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - # UNet2DConditionModel, - StableDiffusionPipeline, -) -from einops import rearrange -from tqdm import tqdm -from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig -import PIL -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import library.model_util as model_util -import library.train_util as train_util -import library.sdxl_model_util as sdxl_model_util -import library.sdxl_train_util as sdxl_train_util -from networks.lora import LoRANetwork -from library.sdxl_original_unet import SdxlUNet2DConditionModel -from library.original_unet import FlashAttentionFunction -from networks.lora_control_net import LoRAControlNet - -# scheduler: -SCHEDULER_LINEAR_START = 0.00085 -SCHEDULER_LINEAR_END = 0.0120 -SCHEDULER_TIMESTEPS = 1000 -SCHEDLER_SCHEDULE = "scaled_linear" - -# その他の設定 -LATENT_CHANNELS = 4 -DOWNSAMPLING_FACTOR = 8 - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - print("Enable memory efficient attention for U-Net") - - # これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い - unet.set_use_memory_efficient_attention(False, True) - elif xformers: - print("Enable xformers for U-Net") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - unet.set_use_memory_efficient_attention(True, False) - elif sdpa: - print("Enable SDPA for U-Net") - unet.set_use_memory_efficient_attention(False, False) - unet.set_use_sdpa(True) - - -# TODO common train_util.py -def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers, sdpa): - if mem_eff_attn: - replace_vae_attn_to_memory_efficient() - elif xformers: - # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? - vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う - elif sdpa: - replace_vae_attn_to_sdpa() - - -def replace_vae_attn_to_memory_efficient(): - print("VAE Attention.forward has been replaced to FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, hidden_states, **kwargs): - q_bucket_size = 512 - k_bucket_size = 1024 - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_flash_attn_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_flash_attn(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_flash_attn_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_flash_attn - - -def replace_vae_attn_to_xformers(): - print("VAE: Attention.forward has been replaced to xformers") - import xformers.ops - - def forward_xformers(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - query_proj = query_proj.contiguous() - key_proj = key_proj.contiguous() - value_proj = value_proj.contiguous() - out = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) - - out = rearrange(out, "b h n d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_xformers_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_xformers(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_xformers_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_xformers - - -def replace_vae_attn_to_sdpa(): - print("VAE: Attention.forward has been replaced to sdpa") - - def forward_sdpa(self, hidden_states, **kwargs): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.to_q(hidden_states) - key_proj = self.to_k(hidden_states) - value_proj = self.to_v(hidden_states) - - query_proj, key_proj, value_proj = map( - lambda t: rearrange(t, "b n (h d) -> b n h d", h=self.heads), (query_proj, key_proj, value_proj) - ) - - out = torch.nn.functional.scaled_dot_product_attention( - query_proj, key_proj, value_proj, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - out = rearrange(out, "b n h d -> b n (h d)") - - # compute next hidden_states - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - def forward_sdpa_0_14(self, hidden_states, **kwargs): - if not hasattr(self, "to_q"): - self.to_q = self.query - self.to_k = self.key - self.to_v = self.value - self.to_out = [self.proj_attn, torch.nn.Identity()] - self.heads = self.num_heads - return forward_sdpa(self, hidden_states, **kwargs) - - if diffusers.__version__ < "0.15.0": - diffusers.models.attention.AttentionBlock.forward = forward_sdpa_0_14 - else: - diffusers.models.attention_processor.Attention.forward = forward_sdpa - - -# endregion - -# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正 -# https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py -# Pipelineだけ独立して使えないのと機能追加するのとでコピーして修正 - - -class PipelineLike: - def __init__( - self, - device, - vae: AutoencoderKL, - text_encoders: List[CLIPTextModel], - tokenizers: List[CLIPTokenizer], - unet: SdxlUNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - clip_skip: int, - ): - super().__init__() - self.device = device - self.clip_skip = clip_skip - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - self.vae = vae - self.text_encoders = text_encoders - self.tokenizers = tokenizers - self.unet: SdxlUNet2DConditionModel = unet - self.scheduler = scheduler - self.safety_checker = None - - # Textual Inversion - self.token_replacements_list = [] - for _ in range(len(self.text_encoders)): - self.token_replacements_list.append({}) - - # ControlNet # not supported yet - self.control_nets: List[LoRAControlNet] = [] - self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない - - # Textual Inversion - def add_token_replacement(self, text_encoder_index, target_token_id, rep_token_ids): - self.token_replacements_list[text_encoder_index][target_token_id] = rep_token_ids - - def set_enable_control_net(self, en: bool): - self.control_net_enabled = en - - def get_token_replacer(self, tokenizer): - tokenizer_index = self.tokenizers.index(tokenizer) - token_replacements = self.token_replacements_list[tokenizer_index] - - def replace_tokens(tokens): - # print("replace_tokens", tokens, "=>", token_replacements) - if isinstance(tokens, torch.Tensor): - tokens = tokens.tolist() - - new_tokens = [] - for token in tokens: - if token in token_replacements: - replacement = token_replacements[token] - new_tokens.extend(replacement) - else: - new_tokens.append(token) - return new_tokens - - return replace_tokens - - def set_control_nets(self, ctrl_nets): - self.control_nets = ctrl_nets - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - init_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - mask_image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, - height: int = 1024, - width: int = 1024, - original_height: int = None, - original_width: int = None, - original_height_negative: int = None, - original_width_negative: int = None, - crop_top: int = 0, - crop_left: int = 0, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_scale: float = None, - strength: float = 0.8, - # num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[torch.Generator] = None, - latents: Optional[torch.FloatTensor] = None, - max_embeddings_multiples: Optional[int] = 3, - output_type: Optional[str] = "pil", - vae_batch_size: float = None, - return_latents: bool = False, - # return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - is_cancelled_callback: Optional[Callable[[], bool]] = None, - callback_steps: Optional[int] = 1, - img2img_noise=None, - clip_guide_images=None, - **kwargs, - ): - # TODO support secondary prompt - num_images_per_prompt = 1 # fixed because already prompt is repeated - - if isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif isinstance(prompt, list): - batch_size = len(prompt) - else: - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - reginonal_network = " AND " in prompt[0] - - vae_batch_size = ( - batch_size - if vae_batch_size is None - else (int(vae_batch_size) if vae_batch_size >= 1 else max(1, int(batch_size * vae_batch_size))) - ) - - if strength < 0 or strength > 1: - raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." - ) - - # get prompt text embeddings - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if not do_classifier_free_guidance and negative_scale is not None: - print(f"negative_scale is ignored if guidance scalle <= 1.0") - negative_scale = None - - # get unconditional embeddings for classifier free guidance - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - - tes_text_embs = [] - tes_uncond_embs = [] - tes_real_uncond_embs = [] - # use last pool - for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): - token_replacer = self.get_token_replacer(tokenizer) - - text_embeddings, text_pool, uncond_embeddings, uncond_pool, _ = get_weighted_text_embeddings( - tokenizer, - text_encoder, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_text_embs.append(text_embeddings) - tes_uncond_embs.append(uncond_embeddings) - - if negative_scale is not None: - _, real_uncond_embeddings, _ = get_weighted_text_embeddings( - token_replacer, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""] * batch_size, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - token_replacer=token_replacer, - device=self.device, - **kwargs, - ) - tes_real_uncond_embs.append(real_uncond_embeddings) - - # concat text encoder outputs - text_embeddings = tes_text_embs[0] - uncond_embeddings = tes_uncond_embs[0] - for i in range(1, len(tes_text_embs)): - text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048 - uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048 - - if do_classifier_free_guidance: - if negative_scale is None: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - else: - text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - - if self.control_nets: - # ControlNetのhintにguide imageを流用する - if isinstance(clip_guide_images, PIL.Image.Image): - clip_guide_images = [clip_guide_images] - if isinstance(clip_guide_images[0], PIL.Image.Image): - clip_guide_images = [preprocess_image(im) for im in clip_guide_images] - clip_guide_images = torch.cat(clip_guide_images) - if isinstance(clip_guide_images, list): - clip_guide_images = torch.stack(clip_guide_images) - - clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype) - - # create size embs - if original_height is None: - original_height = height - if original_width is None: - original_width = width - if original_height_negative is None: - original_height_negative = original_height - if original_width_negative is None: - original_width_negative = original_width - if crop_top is None: - crop_top = 0 - if crop_left is None: - crop_left = 0 - emb1 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256) - uc_emb1 = sdxl_train_util.get_timestep_embedding( - torch.FloatTensor([original_height_negative, original_width_negative]).unsqueeze(0), 256 - ) - emb2 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256) - emb3 = sdxl_train_util.get_timestep_embedding(torch.FloatTensor([height, width]).unsqueeze(0), 256) - c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - uc_vector = torch.cat([uc_emb1, emb2, emb3], dim=1).to(self.device, dtype=text_embeddings.dtype).repeat(batch_size, 1) - - c_vector = torch.cat([text_pool, c_vector], dim=1) - uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) - - vector_embeddings = torch.cat([uc_vector, c_vector]) - - # set timesteps - self.scheduler.set_timesteps(num_inference_steps, self.device) - - latents_dtype = text_embeddings.dtype - init_latents_orig = None - mask = None - - if init_image is None: - # get the initial random noise unless the user supplied it - - # Unlike in other pipelines, latents need to be generated in the target device - # for 1-to-1 results reproducibility with the CompVis implementation. - # However this currently doesn't work in `mps`. - latents_shape = ( - batch_size * num_images_per_prompt, - self.unet.in_channels, - height // 8, - width // 8, - ) - - if latents is None: - if self.device.type == "mps": - # randn does not exist on mps - latents = torch.randn( - latents_shape, - generator=generator, - device="cpu", - dtype=latents_dtype, - ).to(self.device) - else: - latents = torch.randn( - latents_shape, - generator=generator, - device=self.device, - dtype=latents_dtype, - ) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) - - timesteps = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - else: - # image to tensor - if isinstance(init_image, PIL.Image.Image): - init_image = [init_image] - if isinstance(init_image[0], PIL.Image.Image): - init_image = [preprocess_image(im) for im in init_image] - init_image = torch.cat(init_image) - if isinstance(init_image, list): - init_image = torch.stack(init_image) - - # mask image to tensor - if mask_image is not None: - if isinstance(mask_image, PIL.Image.Image): - mask_image = [mask_image] - if isinstance(mask_image[0], PIL.Image.Image): - mask_image = torch.cat([preprocess_mask(im) for im in mask_image]) # H*W, 0 for repaint - - # encode the init image into latents and scale the latents - init_image = init_image.to(device=self.device, dtype=latents_dtype) - if init_image.size()[-2:] == (height // 8, width // 8): - init_latents = init_image - else: - if vae_batch_size >= batch_size: - init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - init_latents = [] - for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): - init_latent_dist = self.vae.encode( - (init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)).to( - self.vae.dtype - ) - ).latent_dist - init_latents.append(init_latent_dist.sample(generator=generator)) - init_latents = torch.cat(init_latents) - - init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents - - if len(init_latents) == 1: - init_latents = init_latents.repeat((batch_size, 1, 1, 1)) - init_latents_orig = init_latents - - # preprocess mask - if mask_image is not None: - mask = mask_image.to(device=self.device, dtype=latents_dtype) - if len(mask) == 1: - mask = mask.repeat((batch_size, 1, 1, 1)) - - # check sizes - if not mask.shape == init_latents.shape: - raise ValueError("The mask and init_image should be the same size!") - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) - - # add noise to latents using the timesteps - latents = self.scheduler.add_noise(init_latents, img2img_noise, timesteps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(self.device) - - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1 - - if self.control_nets: - if self.control_net_enabled: - for control_net in self.control_nets: - with torch.no_grad(): - cond_embs_4d, cond_embs_3d = control_net(clip_guide_images) - control_net.set_cond_embs(cond_embs_4d, cond_embs_3d) - else: - for control_net in self.control_nets: - control_net.set_cond_image(None) - - for i, t in enumerate(tqdm(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = latents.repeat((num_latent_input, 1, 1, 1)) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # # disable control net if ratio is set - # if self.control_nets and self.control_net_enabled: - # pass # TODO - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings) - - # perform guidance - if do_classifier_free_guidance: - if negative_scale is None: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - else: - noise_pred_negative, noise_pred_text, noise_pred_uncond = noise_pred.chunk( - num_latent_input - ) # uncond is real uncond - noise_pred = ( - noise_pred_uncond - + guidance_scale * (noise_pred_text - noise_pred_uncond) - - negative_scale * (noise_pred_negative - noise_pred_uncond) - ) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, img2img_noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - if return_latents: - return (latents, False) - - latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - if vae_batch_size >= batch_size: - image = self.vae.decode(latents.to(self.vae.dtype)).sample - else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - images = [] - for i in tqdm(range(0, batch_size, vae_batch_size)): - images.append( - self.vae.decode( - (latents[i : i + vae_batch_size] if vae_batch_size > 1 else latents[i].unsqueeze(0)).to(self.vae.dtype) - ).sample - ) - image = torch.cat(images) - - image = (image / 2 + 0.5).clamp(0, 1) - - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - - if output_type == "pil": - # image = self.numpy_to_pil(image) - image = (image * 255).round().astype("uint8") - image = [Image.fromarray(im) for im in image] - - return image - - # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) - - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - # keep break as separate token - text = text.replace("BREAK", "\\BREAK\\") - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK": - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: List[str], max_length: int): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - if word.strip() == "BREAK": - # pad until next multiple of tokenizer's max token length - pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length) - print(f"BREAK pad_len: {pad_len}") - for i in range(pad_len): - # v2のときEOSをつけるべきかどうかわからないぜ - # if i == 0: - # text_token.append(tokenizer.eos_token_id) - # else: - text_token.append(tokenizer.pad_token_id) - text_weight.append(1.0) - continue - - # tokenize and discard the starting and the ending token - token = tokenizer(word).input_ids[1:-1] - - token = token_replacer(token) # for Textual Inversion - - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") - return tokens, weights - - -def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length - for i in range(len(tokens)): - tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i])) - if no_boseos_middle: - weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - text_encoder: CLIPTextModel, - text_input: torch.Tensor, - chunk_length: int, - clip_skip: int, - eos: int, - pad: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - pool = None - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - if pad == eos: # v1 - text_input_chunk[:, -1] = text_input[0, -1] - else: # v2 - for j in range(len(text_input_chunk)): - if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある - text_input_chunk[j, -1] = eos - if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD - text_input_chunk[j, 1] = eos - - # -2 is same for Text Encoder 1 and 2 - enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) - text_embedding = enc_out["hidden_states"][-2] - if pool is None: - pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - text_embeddings = torch.concat(text_embeddings, axis=1) - else: - enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) - text_embeddings = enc_out["hidden_states"][-2] - pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this - if pool is not None: - pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) - return text_embeddings, pool - - -def get_weighted_text_embeddings( - tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 1, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, - clip_skip=None, - token_replacer=None, - device=None, - **kwargs, -): - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - # split the prompts with "AND". each prompt must have the same number of splits - new_prompts = [] - for p in prompt: - new_prompts.extend(p.split(" AND ")) - prompt = new_prompts - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, token_replacer, prompt, max_length - 2) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(tokenizer, token_replacer, uncond_prompt, max_length - 2) - else: - prompt_tokens = [token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [token[1:-1] for token in tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max(max_length, max([len(token) for token in uncond_tokens])) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (tokenizer.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = tokenizer.bos_token_id - eos = tokenizer.eos_token_id - pad = tokenizer.pad_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device) - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - pad, - no_boseos_middle=no_boseos_middle, - chunk_length=tokenizer.model_max_length, - ) - uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device) - - # get the embeddings - text_embeddings, text_pool = get_unweighted_text_embeddings( - text_encoder, - prompt_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device) - if uncond_prompt is not None: - uncond_embeddings, uncond_pool = get_unweighted_text_embeddings( - text_encoder, - uncond_tokens, - tokenizer.model_max_length, - clip_skip, - eos, - pad, - no_boseos_middle=no_boseos_middle, - ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=device) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - # →全体でいいんじゃないかな - if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - - if uncond_prompt is not None: - return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens - return text_embeddings, text_pool, None, None, prompt_tokens - - -def preprocess_image(image): - w, h = image.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - image = image.resize((w, h), resample=PIL.Image.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.0 * image - 1.0 - - -def preprocess_mask(mask): - mask = mask.convert("L") - w, h = mask.size - w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // 8, h // 8), resample=PIL.Image.BILINEAR) # LANCZOS) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask - - -# regular expression for dynamic prompt: -# starts and ends with "{" and "}" -# contains at least one variant divided by "|" -# optional framgments divided by "$$" at start -# if the first fragment is "E" or "e", enumerate all variants -# if the second fragment is a number or two numbers, repeat the variants in the range -# if the third fragment is a string, use it as a separator - -RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}") - - -def handle_dynamic_prompt_variants(prompt, repeat_count): - founds = list(RE_DYNAMIC_PROMPT.finditer(prompt)) - if not founds: - return [prompt] - - # make each replacement for each variant - enumerating = False - replacers = [] - for found in founds: - # if "e$$" is found, enumerate all variants - found_enumerating = found.group(2) is not None - enumerating = enumerating or found_enumerating - - separator = ", " if found.group(6) is None else found.group(6) - variants = found.group(7).split("|") - - # parse count range - count_range = found.group(4) - if count_range is None: - count_range = [1, 1] - else: - count_range = count_range.split("-") - if len(count_range) == 1: - count_range = [int(count_range[0]), int(count_range[0])] - elif len(count_range) == 2: - count_range = [int(count_range[0]), int(count_range[1])] - else: - print(f"invalid count range: {count_range}") - count_range = [1, 1] - if count_range[0] > count_range[1]: - count_range = [count_range[1], count_range[0]] - if count_range[0] < 0: - count_range[0] = 0 - if count_range[1] > len(variants): - count_range[1] = len(variants) - - if found_enumerating: - # make function to enumerate all combinations - def make_replacer_enum(vari, cr, sep): - def replacer(): - values = [] - for count in range(cr[0], cr[1] + 1): - for comb in itertools.combinations(vari, count): - values.append(sep.join(comb)) - return values - - return replacer - - replacers.append(make_replacer_enum(variants, count_range, separator)) - else: - # make function to choose random combinations - def make_replacer_single(vari, cr, sep): - def replacer(): - count = random.randint(cr[0], cr[1]) - comb = random.sample(vari, count) - return [sep.join(comb)] - - return replacer - - replacers.append(make_replacer_single(variants, count_range, separator)) - - # make each prompt - if not enumerating: - # if not enumerating, repeat the prompt, replace each variant randomly - prompts = [] - for _ in range(repeat_count): - current = prompt - for found, replacer in zip(founds, replacers): - current = current.replace(found.group(0), replacer()[0], 1) - prompts.append(current) - else: - # if enumerating, iterate all combinations for previous prompts - prompts = [prompt] - - for found, replacer in zip(founds, replacers): - if found.group(2) is not None: - # make all combinations for existing prompts - new_prompts = [] - for current in prompts: - replecements = replacer() - for replecement in replecements: - new_prompts.append(current.replace(found.group(0), replecement, 1)) - prompts = new_prompts - - for found, replacer in zip(founds, replacers): - # make random selection for existing prompts - if found.group(2) is None: - for i in range(len(prompts)): - prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1) - - return prompts - - -# endregion - - -# def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") -# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) -# return text_encoder - - -class BatchDataBase(NamedTuple): - # バッチ分割が必要ないデータ - step: int - prompt: str - negative_prompt: str - seed: int - init_image: Any - mask_image: Any - clip_prompt: str - guide_image: Any - - -class BatchDataExt(NamedTuple): - # バッチ分割が必要なデータ - width: int - height: int - original_width: int - original_height: int - original_width_negative: int - original_height_negative: int - crop_left: int - crop_top: int - steps: int - scale: float - negative_scale: float - strength: float - network_muls: Tuple[float] - num_sub_prompts: int - - -class BatchData(NamedTuple): - return_latents: bool - base: BatchDataBase - ext: BatchDataExt - - -def main(args): - if args.fp16: - dtype = torch.float16 - elif args.bf16: - dtype = torch.bfloat16 - else: - dtype = torch.float32 - - highres_fix = args.highres_fix_scale is not None - # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - - # モデルを読み込む - if not os.path.isfile(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う - files = glob.glob(args.ckpt) - if len(files) == 1: - args.ckpt = files[0] - - (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype - ) - - # xformers、Hypernetwork対応 - if not args.diffusers_xformers: - mem_eff = not (args.xformers or args.sdpa) - replace_unet_modules(unet, mem_eff, args.xformers, args.sdpa) - replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa) - - # tokenizerを読み込む - print("loading tokenizer") - tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args) - - # schedulerを用意する - sched_init_args = {} - has_steps_offset = True - has_clip_sample = True - scheduler_num_noises_per_step = 1 - - if args.sampler == "ddim": - scheduler_cls = DDIMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddim - elif args.sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある - scheduler_cls = DDPMScheduler - scheduler_module = diffusers.schedulers.scheduling_ddpm - elif args.sampler == "pndm": - scheduler_cls = PNDMScheduler - scheduler_module = diffusers.schedulers.scheduling_pndm - has_clip_sample = False - elif args.sampler == "lms" or args.sampler == "k_lms": - scheduler_cls = LMSDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_lms_discrete - has_clip_sample = False - elif args.sampler == "euler" or args.sampler == "k_euler": - scheduler_cls = EulerDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_discrete - has_clip_sample = False - elif args.sampler == "euler_a" or args.sampler == "k_euler_a": - scheduler_cls = EulerAncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete - has_clip_sample = False - elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++": - scheduler_cls = DPMSolverMultistepScheduler - sched_init_args["algorithm_type"] = args.sampler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_multistep - has_clip_sample = False - elif args.sampler == "dpmsingle": - scheduler_cls = DPMSolverSinglestepScheduler - scheduler_module = diffusers.schedulers.scheduling_dpmsolver_singlestep - has_clip_sample = False - has_steps_offset = False - elif args.sampler == "heun": - scheduler_cls = HeunDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_heun_discrete - has_clip_sample = False - elif args.sampler == "dpm_2" or args.sampler == "k_dpm_2": - scheduler_cls = KDPM2DiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_discrete - has_clip_sample = False - elif args.sampler == "dpm_2_a" or args.sampler == "k_dpm_2_a": - scheduler_cls = KDPM2AncestralDiscreteScheduler - scheduler_module = diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete - scheduler_num_noises_per_step = 2 - has_clip_sample = False - - # 警告を出さないようにする - if has_steps_offset: - sched_init_args["steps_offset"] = 1 - if has_clip_sample: - sched_init_args["clip_sample"] = False - - # samplerの乱数をあらかじめ指定するための処理 - - # replace randn - class NoiseManager: - def __init__(self): - self.sampler_noises = None - self.sampler_noise_index = 0 - - def reset_sampler_noises(self, noises): - self.sampler_noise_index = 0 - self.sampler_noises = noises - - def randn(self, shape, device=None, dtype=None, layout=None, generator=None): - # print("replacing", shape, len(self.sampler_noises), self.sampler_noise_index) - if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises): - noise = self.sampler_noises[self.sampler_noise_index] - if shape != noise.shape: - noise = None - else: - noise = None - - if noise == None: - print(f"unexpected noise request: {self.sampler_noise_index}, {shape}") - noise = torch.randn(shape, dtype=dtype, device=device, generator=generator) - - self.sampler_noise_index += 1 - return noise - - class TorchRandReplacer: - def __init__(self, noise_manager): - self.noise_manager = noise_manager - - def __getattr__(self, item): - if item == "randn": - return self.noise_manager.randn - if hasattr(torch, item): - return getattr(torch, item) - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) - - noise_manager = NoiseManager() - if scheduler_module is not None: - scheduler_module.torch = TorchRandReplacer(noise_manager) - - scheduler = scheduler_cls( - num_train_timesteps=SCHEDULER_TIMESTEPS, - beta_start=SCHEDULER_LINEAR_START, - beta_end=SCHEDULER_LINEAR_END, - beta_schedule=SCHEDLER_SCHEDULE, - **sched_init_args, - ) - - # ↓以下は結局PipeでFalseに設定されるので意味がなかった - # # clip_sample=Trueにする - # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: - # print("set clip_sample to True") - # scheduler.config.clip_sample = True - - # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない - - # custom pipelineをコピったやつを生成する - if args.vae_slices: - from library.slicing_vae import SlicingAutoencoderKL - - sli_vae = SlicingAutoencoderKL( - act_fn="silu", - block_out_channels=(128, 256, 512, 512), - down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], - in_channels=3, - latent_channels=4, - layers_per_block=2, - norm_num_groups=32, - out_channels=3, - sample_size=512, - up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], - num_slices=args.vae_slices, - ) - sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする - vae = sli_vae - del sli_vae - - vae_dtype = dtype - if args.no_half_vae: - print("set vae_dtype to float32") - vae_dtype = torch.float32 - vae.to(vae_dtype).to(device) - - text_encoder1.to(dtype).to(device) - text_encoder2.to(dtype).to(device) - unet.to(dtype).to(device) - - # networkを組み込む - if args.network_module: - networks = [] - network_default_muls = [] - network_pre_calc = args.network_pre_calc - - for i, network_module in enumerate(args.network_module): - print("import network module:", network_module) - imported_module = importlib.import_module(network_module) - - network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] - network_default_muls.append(network_mul) - - net_kwargs = {} - if args.network_args and i < len(args.network_args): - network_args = args.network_args[i] - # TODO escape special chars - network_args = network_args.split(";") - for net_arg in network_args: - key, value = net_arg.split("=") - net_kwargs[key] = value - - if args.network_weights and i < len(args.network_weights): - network_weight = args.network_weights[i] - print("load network weights from:", network_weight) - - if model_util.is_safetensors(network_weight) and args.network_show_meta: - from safetensors.torch import safe_open - - with safe_open(network_weight, framework="pt") as f: - metadata = f.metadata() - if metadata is not None: - print(f"metadata for: {network_weight}: {metadata}") - - network, weights_sd = imported_module.create_network_from_weights( - network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs - ) - else: - raise ValueError("No weight. Weight is required.") - if network is None: - return - - mergeable = network.is_mergeable() - if args.network_merge and not mergeable: - print("network is not mergiable. ignore merge option.") - - if not args.network_merge or not mergeable: - network.apply_to([text_encoder1, text_encoder2], unet) - info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい - print(f"weights are loaded: {info}") - - if args.opt_channels_last: - network.to(memory_format=torch.channels_last) - network.to(dtype).to(device) - - if network_pre_calc: - print("backup original weights") - network.backup_weights() - - networks.append(network) - else: - network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device) - - else: - networks = [] - - # upscalerの指定があれば取得する - upscaler = None - if args.highres_fix_upscaler: - print("import upscaler module:", args.highres_fix_upscaler) - imported_module = importlib.import_module(args.highres_fix_upscaler) - - us_kwargs = {} - if args.highres_fix_upscaler_args: - for net_arg in args.highres_fix_upscaler_args.split(";"): - key, value = net_arg.split("=") - us_kwargs[key] = value - - print("create upscaler") - upscaler = imported_module.create_upscaler(**us_kwargs) - upscaler.to(dtype).to(device) - - # ControlNetの処理 - control_nets: List[LoRAControlNet] = [] - if args.control_net_models: - for i, model_file in enumerate(args.control_net_models): - print(f"loading control net: {model_file}") - - from safetensors.torch import load_file - - state_dict = load_file(model_file) - lora_rank = None - emb_dim = None - for key, value in state_dict.items(): - if lora_rank is None and "lora_down.weight" in key: - lora_rank = value.shape[0] - elif emb_dim is None and "conditioning1.0" in key: - emb_dim = value.shape[0] - if lora_rank is not None and emb_dim is not None: - break - assert lora_rank is not None and emb_dim is not None, f"invalid control net: {model_file}" - - control_net = LoRAControlNet(unet, emb_dim, lora_rank, 1) - control_net.apply_to() - control_net.load_state_dict(state_dict) - control_net.to(dtype).to(device) - control_net.set_batch_cond_only(False, False) - control_nets.append(control_net) - - if args.opt_channels_last: - print(f"set optimizing: channels last") - text_encoder1.to(memory_format=torch.channels_last) - text_encoder2.to(memory_format=torch.channels_last) - vae.to(memory_format=torch.channels_last) - unet.to(memory_format=torch.channels_last) - if networks: - for network in networks: - network.to(memory_format=torch.channels_last) - - for cn in control_nets: - cn.to(memory_format=torch.channels_last) - - pipe = PipelineLike( - device, - vae, - [text_encoder1, text_encoder2], - [tokenizer1, tokenizer2], - unet, - scheduler, - args.clip_skip, - ) - pipe.set_control_nets(control_nets) - print("pipeline is ready.") - - if args.diffusers_xformers: - pipe.enable_xformers_memory_efficient_attention() - - # Textual Inversionを処理する - if args.textual_inversion_embeddings: - token_ids_embeds1 = [] - token_ids_embeds2 = [] - for embeds_file in args.textual_inversion_embeddings: - if model_util.is_safetensors(embeds_file): - from safetensors.torch import load_file - - data = load_file(embeds_file) - else: - data = torch.load(embeds_file, map_location="cpu") - - if "string_to_param" in data: - data = data["string_to_param"] - - embeds1 = data["clip_l"] # text encoder 1 - embeds2 = data["clip_g"] # text encoder 2 - - num_vectors_per_token = embeds1.size()[0] - token_string = os.path.splitext(os.path.basename(embeds_file))[0] - - token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - - # add new word to tokenizer, count is num_vectors_per_token - num_added_tokens1 = tokenizer1.add_tokens(token_strings) - num_added_tokens2 = tokenizer2.add_tokens(token_strings) - assert num_added_tokens1 == num_vectors_per_token and num_added_tokens2 == num_vectors_per_token, ( - f"tokenizer has same word to token string (filename): {embeds_file}" - + f" / 指定した名前(ファイル名)のトークンが既に存在します: {embeds_file}" - ) - - token_ids1 = tokenizer1.convert_tokens_to_ids(token_strings) - token_ids2 = tokenizer2.convert_tokens_to_ids(token_strings) - print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}") - assert ( - min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1 - ), f"token ids1 is not ordered" - assert ( - min(token_ids2) == token_ids2[0] and token_ids2[-1] == token_ids2[0] + len(token_ids2) - 1 - ), f"token ids2 is not ordered" - assert len(tokenizer1) - 1 == token_ids1[-1], f"token ids 1 is not end of tokenize: {len(tokenizer1)}" - assert len(tokenizer2) - 1 == token_ids2[-1], f"token ids 2 is not end of tokenize: {len(tokenizer2)}" - - if num_vectors_per_token > 1: - pipe.add_token_replacement(0, token_ids1[0], token_ids1) # hoge -> hoge, hogea, hogeb, ... - pipe.add_token_replacement(1, token_ids2[0], token_ids2) - - token_ids_embeds1.append((token_ids1, embeds1)) - token_ids_embeds2.append((token_ids2, embeds2)) - - text_encoder1.resize_token_embeddings(len(tokenizer1)) - text_encoder2.resize_token_embeddings(len(tokenizer2)) - token_embeds1 = text_encoder1.get_input_embeddings().weight.data - token_embeds2 = text_encoder2.get_input_embeddings().weight.data - for token_ids, embeds in token_ids_embeds1: - for token_id, embed in zip(token_ids, embeds): - token_embeds1[token_id] = embed - for token_ids, embeds in token_ids_embeds2: - for token_id, embed in zip(token_ids, embeds): - token_embeds2[token_id] = embed - - # promptを取得する - if args.from_file is not None: - print(f"reading prompts from {args.from_file}") - with open(args.from_file, "r", encoding="utf-8") as f: - prompt_list = f.read().splitlines() - prompt_list = [d for d in prompt_list if len(d.strip()) > 0] - elif args.prompt is not None: - prompt_list = [args.prompt] - else: - prompt_list = [] - - if args.interactive: - args.n_iter = 1 - - # img2imgの前処理、画像の読み込みなど - def load_images(path): - if os.path.isfile(path): - paths = [path] - else: - paths = ( - glob.glob(os.path.join(path, "*.png")) - + glob.glob(os.path.join(path, "*.jpg")) - + glob.glob(os.path.join(path, "*.jpeg")) - + glob.glob(os.path.join(path, "*.webp")) - ) - paths.sort() - - images = [] - for p in paths: - image = Image.open(p) - if image.mode != "RGB": - print(f"convert image to RGB from {image.mode}: {p}") - image = image.convert("RGB") - images.append(image) - - return images - - def resize_images(imgs, size): - resized = [] - for img in imgs: - r_img = img.resize(size, Image.Resampling.LANCZOS) - if hasattr(img, "filename"): # filename属性がない場合があるらしい - r_img.filename = img.filename - resized.append(r_img) - return resized - - if args.image_path is not None: - print(f"load image for img2img: {args.image_path}") - init_images = load_images(args.image_path) - assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" - print(f"loaded {len(init_images)} images for img2img") - else: - init_images = None - - if args.mask_path is not None: - print(f"load mask for inpainting: {args.mask_path}") - mask_images = load_images(args.mask_path) - assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}" - print(f"loaded {len(mask_images)} mask images for inpainting") - else: - mask_images = None - - # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: - print("get prompts from images' meta data") - for img in init_images: - if "prompt" in img.text: - prompt = img.text["prompt"] - if "negative-prompt" in img.text: - prompt += " --n " + img.text["negative-prompt"] - prompt_list.append(prompt) - - # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) - l = [] - for im in init_images: - l.extend([im] * args.images_per_prompt) - init_images = l - - if mask_images is not None: - l = [] - for im in mask_images: - l.extend([im] * args.images_per_prompt) - mask_images = l - - # 画像サイズにオプション指定があるときはリサイズする - if args.W is not None and args.H is not None: - # highres fix を考慮に入れる - w, h = args.W, args.H - if highres_fix: - w = int(w * args.highres_fix_scale + 0.5) - h = int(h * args.highres_fix_scale + 0.5) - - if init_images is not None: - print(f"resize img2img source images to {w}*{h}") - init_images = resize_images(init_images, (w, h)) - if mask_images is not None: - print(f"resize img2img mask images to {w}*{h}") - mask_images = resize_images(mask_images, (w, h)) - - regional_network = False - if networks and mask_images: - # mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応 - regional_network = True - print("use mask as region") - - size = None - for i, network in enumerate(networks): - if i < 3: - np_mask = np.array(mask_images[0]) - np_mask = np_mask[:, :, i] - size = np_mask.shape - else: - np_mask = np.full(size, 255, dtype=np.uint8) - mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0) - network.set_region(i, i == len(networks) - 1, mask) - mask_images = None - - prev_image = None # for VGG16 guided - if args.guide_image_path is not None: - print(f"load image for ControlNet guidance: {args.guide_image_path}") - guide_images = [] - for p in args.guide_image_path: - guide_images.extend(load_images(p)) - - print(f"loaded {len(guide_images)} guide images for guidance") - if len(guide_images) == 0: - print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}") - guide_images = None - else: - guide_images = None - - # seed指定時はseedを決めておく - if args.seed is not None: - # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed - else: - predefined_seeds = None - - # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) - if args.W is None: - args.W = 1024 - if args.H is None: - args.H = 1024 - - # 画像生成のループ - os.makedirs(args.outdir, exist_ok=True) - max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples - - for gen_iter in range(args.n_iter): - print(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7FFFFFFF) - - # バッチ処理の関数 - def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): - batch_size = len(batch) - - # highres_fixの処理 - if highres_fix and not highres_1st: - # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す - is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling - - print("process 1st stage") - batch_1st = [] - for _, base, ext in batch: - - def scale_and_round(x): - if x is None: - return None - return int(x * args.highres_fix_scale + 0.5) - - width_1st = scale_and_round(ext.width) - height_1st = scale_and_round(ext.height) - width_1st = width_1st - width_1st % 32 - height_1st = height_1st - height_1st % 32 - - original_width_1st = scale_and_round(ext.original_width) - original_height_1st = scale_and_round(ext.original_height) - original_width_negative_1st = scale_and_round(ext.original_width_negative) - original_height_negative_1st = scale_and_round(ext.original_height_negative) - crop_left_1st = scale_and_round(ext.crop_left) - crop_top_1st = scale_and_round(ext.crop_top) - - strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength - - ext_1st = BatchDataExt( - width_1st, - height_1st, - original_width_1st, - original_height_1st, - original_width_negative_1st, - original_height_negative_1st, - crop_left_1st, - crop_top_1st, - args.highres_fix_steps, - ext.scale, - ext.negative_scale, - strength_1st, - ext.network_muls, - ext.num_sub_prompts, - ) - batch_1st.append(BatchData(is_1st_latent, base, ext_1st)) - - pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする - images_1st = process_batch(batch_1st, True, True) - - # 2nd stageのバッチを作成して以下処理する - print("process 2nd stage") - width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height - - if upscaler: - # upscalerを使って画像を拡大する - lowreso_imgs = None if is_1st_latent else images_1st - lowreso_latents = None if not is_1st_latent else images_1st - - # 戻り値はPIL.Image.Imageかtorch.Tensorのlatents - batch_size = len(images_1st) - vae_batch_size = ( - batch_size - if args.vae_batch_size is None - else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size) - ) - vae_batch_size = int(vae_batch_size) - images_1st = upscaler.upscale( - vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size - ) - - elif args.highres_fix_latents_upscaling: - # latentを拡大する - org_dtype = images_1st.dtype - if images_1st.dtype == torch.bfloat16: - images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない - images_1st = torch.nn.functional.interpolate( - images_1st, (batch[0].ext.height // 8, batch[0].ext.width // 8), mode="bilinear" - ) # , antialias=True) - images_1st = images_1st.to(org_dtype) - - else: - # 画像をLANCZOSで拡大する - images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st] - - batch_2nd = [] - for i, (bd, image) in enumerate(zip(batch, images_1st)): - bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext) - batch_2nd.append(bd_2nd) - batch = batch_2nd - - if args.highres_fix_disable_control_net: - pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする - - # このバッチの情報を取り出す - ( - return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image), - ( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - network_muls, - num_sub_prompts, - ), - ) = batch[0] - noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) - - prompts = [] - negative_prompts = [] - start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - noises = [ - torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - for _ in range(steps * scheduler_num_noises_per_step) - ] - seeds = [] - clip_prompts = [] - - if init_image is not None: # img2img? - i2i_noises = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) - init_images = [] - - if mask_image is not None: - mask_images = [] - else: - mask_images = None - else: - i2i_noises = None - init_images = None - mask_images = None - - if guide_image is not None: # CLIP image guided? - guide_images = [] - else: - guide_images = None - - # バッチ内の位置に関わらず同じ乱数を使うためにここで乱数を生成しておく。あわせてimage/maskがbatch内で同一かチェックする - all_images_are_same = True - all_masks_are_same = True - all_guide_images_are_same = True - for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): - prompts.append(prompt) - negative_prompts.append(negative_prompt) - seeds.append(seed) - clip_prompts.append(clip_prompt) - - if init_image is not None: - init_images.append(init_image) - if i > 0 and all_images_are_same: - all_images_are_same = init_images[-2] is init_image - - if mask_image is not None: - mask_images.append(mask_image) - if i > 0 and all_masks_are_same: - all_masks_are_same = mask_images[-2] is mask_image - - if guide_image is not None: - if type(guide_image) is list: - guide_images.extend(guide_image) - all_guide_images_are_same = False - else: - guide_images.append(guide_image) - if i > 0 and all_guide_images_are_same: - all_guide_images_are_same = guide_images[-2] is guide_image - - # make start code - torch.manual_seed(seed) - start_code[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - # make each noises - for j in range(steps * scheduler_num_noises_per_step): - noises[j][i] = torch.randn(noise_shape, device=device, dtype=dtype) - - if i2i_noises is not None: # img2img noise - i2i_noises[i] = torch.randn(noise_shape, device=device, dtype=dtype) - - noise_manager.reset_sampler_noises(noises) - - # すべての画像が同じなら1枚だけpipeに渡すことでpipe側で処理を高速化する - if init_images is not None and all_images_are_same: - init_images = init_images[0] - if mask_images is not None and all_masks_are_same: - mask_images = mask_images[0] - if guide_images is not None and all_guide_images_are_same: - guide_images = guide_images[0] - - # ControlNet使用時はguide imageをリサイズする - if control_nets: - # TODO resampleのメソッド - guide_images = guide_images if type(guide_images) == list else [guide_images] - guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images] - if len(guide_images) == 1: - guide_images = guide_images[0] - - # generate - if networks: - # 追加ネットワークの処理 - shared = {} - for n, m in zip(networks, network_muls if network_muls else network_default_muls): - n.set_multiplier(m) - if regional_network: - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) - - if not regional_network and network_pre_calc: - for n in networks: - n.restore_weights() - for n in networks: - n.pre_calculation() - print("pre-calculation... done") - - images = pipe( - prompts, - negative_prompts, - init_images, - mask_images, - height, - width, - original_height, - original_width, - original_height_negative, - original_width_negative, - crop_top, - crop_left, - steps, - scale, - negative_scale, - strength, - latents=start_code, - output_type="pil", - max_embeddings_multiples=max_embeddings_multiples, - img2img_noise=i2i_noises, - vae_batch_size=args.vae_batch_size, - return_latents=return_latents, - clip_prompts=clip_prompts, - clip_guide_images=guide_images, - ) - if highres_1st and not args.highres_fix_save_1st: # return images or latents - return images - - # save image - highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts) - ): - if highres_fix: - seed -= 1 # record original seed - metadata = PngInfo() - metadata.add_text("prompt", prompt) - metadata.add_text("seed", str(seed)) - metadata.add_text("sampler", args.sampler) - metadata.add_text("steps", str(steps)) - metadata.add_text("scale", str(scale)) - if negative_prompt is not None: - metadata.add_text("negative-prompt", negative_prompt) - if negative_scale is not None: - metadata.add_text("negative-scale", str(negative_scale)) - if clip_prompt is not None: - metadata.add_text("clip-prompt", clip_prompt) - metadata.add_text("original-height", str(original_height)) - metadata.add_text("original-width", str(original_width)) - metadata.add_text("original-height-negative", str(original_height_negative)) - metadata.add_text("original-width-negative", str(original_width_negative)) - metadata.add_text("crop-top", str(crop_top)) - metadata.add_text("crop-left", str(crop_left)) - - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" - else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) - - if not args.no_preview and not highres_1st and args.interactive: - try: - import cv2 - - for prompt, image in zip(prompts, images): - cv2.imshow(prompt[:128], np.array(image)[:, :, ::-1]) # プロンプトが長いと死ぬ - cv2.waitKey() - cv2.destroyAllWindows() - except ImportError: - print("opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません") - - return images - - # 画像生成のプロンプトが一周するまでのループ - prompt_index = 0 - global_step = 0 - batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: - # interactive - valid = False - while not valid: - print("\nType prompt:") - try: - raw_prompt = input() - except EOFError: - break - - valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0 - if not valid: # EOF, end app - break - else: - raw_prompt = prompt_list[prompt_index] - - # sd-dynamic-prompts like variants: - # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) - raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt) - - # repeat prompt - for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): - raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] - - if pi == 0 or len(raw_prompts) > 1: - # parse prompt: if prompt is not changed, skip parsing - width = args.W - height = args.H - original_width = args.original_width - original_height = args.original_height - original_width_negative = args.original_width_negative - original_height_negative = args.original_height_negative - crop_top = args.crop_top - crop_left = args.crop_left - scale = args.scale - negative_scale = args.negative_scale - steps = args.steps - seed = None - seeds = None - strength = 0.8 if args.strength is None else args.strength - negative_prompt = "" - clip_prompt = None - network_muls = None - - prompt_args = raw_prompt.strip().split(" --") - prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") - - for parg in prompt_args[1:]: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - print(f"width: {width}") - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - print(f"height: {height}") - continue - - m = re.match(r"ow (\d+)", parg, re.IGNORECASE) - if m: - original_width = int(m.group(1)) - print(f"original width: {original_width}") - continue - - m = re.match(r"oh (\d+)", parg, re.IGNORECASE) - if m: - original_height = int(m.group(1)) - print(f"original height: {original_height}") - continue - - m = re.match(r"nw (\d+)", parg, re.IGNORECASE) - if m: - original_width_negative = int(m.group(1)) - print(f"original width negative: {original_width_negative}") - continue - - m = re.match(r"nh (\d+)", parg, re.IGNORECASE) - if m: - original_height_negative = int(m.group(1)) - print(f"original height negative: {original_height_negative}") - continue - - m = re.match(r"ct (\d+)", parg, re.IGNORECASE) - if m: - crop_top = int(m.group(1)) - print(f"crop top: {crop_top}") - continue - - m = re.match(r"cl (\d+)", parg, re.IGNORECASE) - if m: - crop_left = int(m.group(1)) - print(f"crop left: {crop_left}") - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - steps = max(1, min(1000, int(m.group(1)))) - print(f"steps: {steps}") - continue - - m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE) - if m: # seed - seeds = [int(d) for d in m.group(1).split(",")] - print(f"seeds: {seeds}") - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - print(f"scale: {scale}") - continue - - m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE) - if m: # negative scale - if m.group(1).lower() == "none": - negative_scale = None - else: - negative_scale = float(m.group(1)) - print(f"negative scale: {negative_scale}") - continue - - m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE) - if m: # strength - strength = float(m.group(1)) - print(f"strength: {strength}") - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - print(f"negative prompt: {negative_prompt}") - continue - - m = re.match(r"c (.+)", parg, re.IGNORECASE) - if m: # clip prompt - clip_prompt = m.group(1) - print(f"clip prompt: {clip_prompt}") - continue - - m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE) - if m: # network multiplies - network_muls = [float(v) for v in m.group(1).split(",")] - while len(network_muls) < len(networks): - network_muls.append(network_muls[-1]) - print(f"network mul: {network_muls}") - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - # prepare seed - if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う - if len(seeds) > 0: - seed = seeds.pop(0) - else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - print("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed - else: - seed = None # 前のを消す - - if seed is None: - seed = random.randint(0, 0x7FFFFFFF) - if args.interactive: - print(f"seed: {seed}") - - # prepare init image, guide image and mask - init_image = mask_image = guide_image = None - - # 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する - if init_images is not None: - init_image = init_images[global_step % len(init_images)] - - # img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する - # 32単位に丸めたやつにresizeされるので踏襲する - if not highres_fix: - width, height = init_image.size - width = width - width % 32 - height = height - height % 32 - if width != init_image.size[0] or height != init_image.size[1]: - print( - f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます" - ) - - if mask_images is not None: - mask_image = mask_images[global_step % len(mask_images)] - - if guide_images is not None: - if control_nets: # 複数件の場合あり - c = len(control_nets) - p = global_step % (len(guide_images) // c) - guide_image = guide_images[p * c : p * c + c] - else: - guide_image = guide_images[global_step % len(guide_images)] - - if regional_network: - num_sub_prompts = len(prompt.split(" AND ")) - assert ( - len(networks) <= num_sub_prompts - ), "Number of networks must be less than or equal to number of sub prompts." - else: - num_sub_prompts = None - - b1 = BatchData( - False, - BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), - BatchDataExt( - width, - height, - original_width, - original_height, - original_width_negative, - original_height_negative, - crop_left, - crop_top, - steps, - scale, - negative_scale, - strength, - tuple(network_muls) if network_muls else None, - num_sub_prompts, - ), - ) - if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要? - process_batch(batch_data, highres_fix) - batch_data.clear() - - batch_data.append(b1) - if len(batch_data) == args.batch_size: - prev_image = process_batch(batch_data, highres_fix)[0] - batch_data.clear() - - global_step += 1 - - prompt_index += 1 - - if len(batch_data) > 0: - process_batch(batch_data, highres_fix) - batch_data.clear() - - print("done!") - - -def setup_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser() - - parser.add_argument("--prompt", type=str, default=None, help="prompt / プロンプト") - parser.add_argument( - "--from_file", type=str, default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む" - ) - parser.add_argument( - "--interactive", action="store_true", help="interactive mode (generates one image) / 対話モード(生成される画像は1枚になります)" - ) - parser.add_argument( - "--no_preview", action="store_true", help="do not show generated image in interactive mode / 対話モードで画像を表示しない" - ) - parser.add_argument( - "--image_path", type=str, default=None, help="image to inpaint or to generate from / img2imgまたはinpaintを行う元画像" - ) - parser.add_argument("--mask_path", type=str, default=None, help="mask in inpainting / inpaint時のマスク") - parser.add_argument("--strength", type=float, default=None, help="img2img strength / img2img時のstrength") - parser.add_argument("--images_per_prompt", type=int, default=1, help="number of images per prompt / プロンプトあたりの出力枚数") - parser.add_argument("--outdir", type=str, default="outputs", help="dir to write results to / 生成画像の出力先") - parser.add_argument("--sequential_file_name", action="store_true", help="sequential output file name / 生成画像のファイル名を連番にする") - parser.add_argument( - "--use_original_file_name", - action="store_true", - help="prepend original file name in img2img / img2imgで元画像のファイル名を生成画像のファイル名の先頭に付ける", - ) - # parser.add_argument("--ddim_eta", type=float, default=0.0, help="ddim eta (eta=0.0 corresponds to deterministic sampling", ) - parser.add_argument("--n_iter", type=int, default=1, help="sample this often / 繰り返し回数") - parser.add_argument("--H", type=int, default=None, help="image height, in pixel space / 生成画像高さ") - parser.add_argument("--W", type=int, default=None, help="image width, in pixel space / 生成画像幅") - parser.add_argument( - "--original_height", type=int, default=None, help="original height for SDXL conditioning / SDXLの条件付けに用いるoriginal heightの値" - ) - parser.add_argument( - "--original_width", type=int, default=None, help="original width for SDXL conditioning / SDXLの条件付けに用いるoriginal widthの値" - ) - parser.add_argument( - "--original_height_negative", - type=int, - default=None, - help="original height for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal heightの値", - ) - parser.add_argument( - "--original_width_negative", - type=int, - default=None, - help="original width for SDXL unconditioning / SDXLのネガティブ条件付けに用いるoriginal widthの値", - ) - parser.add_argument("--crop_top", type=int, default=None, help="crop top for SDXL conditioning / SDXLの条件付けに用いるcrop topの値") - parser.add_argument("--crop_left", type=int, default=None, help="crop left for SDXL conditioning / SDXLの条件付けに用いるcrop leftの値") - parser.add_argument("--batch_size", type=int, default=1, help="batch size / バッチサイズ") - parser.add_argument( - "--vae_batch_size", - type=float, - default=None, - help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率", - ) - parser.add_argument( - "--vae_slices", - type=int, - default=None, - help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨", - ) - parser.add_argument("--no_half_vae", action="store_true", help="do not use fp16/bf16 precision for VAE / VAE処理時にfp16/bf16を使わない") - parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数") - parser.add_argument( - "--sampler", - type=str, - default="ddim", - choices=[ - "ddim", - "pndm", - "lms", - "euler", - "euler_a", - "heun", - "dpm_2", - "dpm_2_a", - "dpmsolver", - "dpmsolver++", - "dpmsingle", - "k_lms", - "k_euler", - "k_euler_a", - "k_dpm_2", - "k_dpm_2_a", - ], - help=f"sampler (scheduler) type / サンプラー(スケジューラ)の種類", - ) - parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty)) / guidance scale", - ) - parser.add_argument("--ckpt", type=str, default=None, help="path to checkpoint of model / モデルのcheckpointファイルまたはディレクトリ") - parser.add_argument( - "--vae", type=str, default=None, help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ" - ) - parser.add_argument( - "--tokenizer_cache_dir", - type=str, - default=None, - help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)", - ) - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument( - "--seed", - type=int, - default=None, - help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed", - ) - parser.add_argument( - "--iter_same_seed", - action="store_true", - help="use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)", - ) - parser.add_argument("--fp16", action="store_true", help="use fp16 / fp16を指定し省メモリ化する") - parser.add_argument("--bf16", action="store_true", help="use bfloat16 / bfloat16を指定し省メモリ化する") - parser.add_argument("--xformers", action="store_true", help="use xformers / xformersを使用し高速化する") - parser.add_argument("--sdpa", action="store_true", help="use sdpa in PyTorch 2 / sdpa") - parser.add_argument( - "--diffusers_xformers", - action="store_true", - help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", - ) - parser.add_argument( - "--opt_channels_last", action="store_true", help="set channels last option to model / モデルにchannels lastを指定し最適化する" - ) - parser.add_argument( - "--network_module", type=str, default=None, nargs="*", help="additional network module to use / 追加ネットワークを使う時そのモジュール名" - ) - parser.add_argument( - "--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み" - ) - parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率") - parser.add_argument( - "--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数" - ) - parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する") - parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする") - parser.add_argument( - "--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する" - ) - parser.add_argument( - "--textual_inversion_embeddings", - type=str, - default=None, - nargs="*", - help="Embeddings files of Textual Inversion / Textual Inversionのembeddings", - ) - parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") - parser.add_argument( - "--max_embeddings_multiples", - type=int, - default=None, - help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", - ) - parser.add_argument( - "--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像" - ) - parser.add_argument( - "--highres_fix_scale", - type=float, - default=None, - help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする", - ) - parser.add_argument( - "--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数" - ) - parser.add_argument( - "--highres_fix_strength", - type=float, - default=None, - help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ", - ) - parser.add_argument( - "--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する" - ) - parser.add_argument( - "--highres_fix_latents_upscaling", - action="store_true", - help="use latents upscaling for highres fix / highres fixでlatentで拡大する", - ) - parser.add_argument( - "--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名" - ) - parser.add_argument( - "--highres_fix_upscaler_args", - type=str, - default=None, - help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数", - ) - parser.add_argument( - "--highres_fix_disable_control_net", - action="store_true", - help="disable ControlNet for highres fix / highres fixでControlNetを使わない", - ) - - parser.add_argument( - "--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する" - ) - - parser.add_argument( - "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名" - ) - parser.add_argument( - "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名" - ) - parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み") - parser.add_argument( - "--control_net_ratios", - type=float, - default=None, - nargs="*", - help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", - ) - # parser.add_argument( - # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" - # ) - - return parser - - -if __name__ == "__main__": - parser = setup_parser() - - args = parser.parse_args() - main(args) diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_control_net_lllite.py similarity index 97% rename from sdxl_train_lora_control_net.py rename to sdxl_train_control_net_lllite.py index b6fb1dec..6e49e449 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_control_net_lllite.py @@ -34,7 +34,7 @@ from library.custom_train_functions import ( apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, ) -import networks.lora_control_net as lora_control_net +import networks.control_net_lllite as control_net_lllite # TODO 他のスクリプトと共通化する @@ -176,7 +176,7 @@ def train(args): accelerator.wait_for_everyone() # prepare ControlNet - network = lora_control_net.LoRAControlNet(unet, args.cond_emb_dim, args.network_dim, 1, args.network_dropout) + network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout) network.apply_to() if args.network_weights is not None: @@ -242,7 +242,7 @@ def train(args): unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, network, optimizer, train_dataloader, lr_scheduler ) - network: lora_control_net.LoRAControlNet + network: control_net_lllite.ControlNetLLLite # transform DDP after prepare (train_network here only) unet, network = train_util.transform_models_if_DDP([unet, network]) @@ -311,7 +311,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lora_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs ) loss_list = [] @@ -401,11 +401,9 @@ def train(args): controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) with accelerator.autocast(): - # conditioning image embeddingを計算する / calculate conditioning image embedding - cond_embs_4d, cond_embs_3d = network(controlnet_image) - - # 個別のLoRA的モジュールでさらにembeddingを計算する / calculate embedding in each LoRA-like module - network.set_cond_embs(cond_embs_4d, cond_embs_3d) + # conditioning imageをControlNetに渡す / pass conditioning image to ControlNet + # 内部でcond_embに変換される / it will be converted to cond_emb inside + network.set_cond_image(controlnet_image) # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) @@ -562,7 +560,7 @@ def setup_parser() -> argparse.ArgumentParser: if __name__ == "__main__": - sdxl_original_unet.USE_REENTRANT = False + # sdxl_original_unet.USE_REENTRANT = False parser = setup_parser() From 5a86bbc0a05eddaa34e289808057cdb3f0f7f7fa Mon Sep 17 00:00:00 2001 From: ykume Date: Sat, 19 Aug 2023 18:54:31 +0900 Subject: [PATCH 12/19] fix typos, update readme --- docs/train_lll_README-ja.md | 9 ++++----- sdxl_train_control_net_lllite.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/train_lll_README-ja.md b/docs/train_lll_README-ja.md index 33ee0ca8..1859fc9e 100644 --- a/docs/train_lll_README-ja.md +++ b/docs/train_lll_README-ja.md @@ -4,7 +4,7 @@ ConrtolNet-LLLite は、[ConrtolNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAに似た構造の軽量なControlNetです。現在はSDXLにのみ対応しています。 ## モデル構造 -制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、U-Netの各モジュールに付与されるLoRA的な構造を持つモジュールを組み合わせたモデルです。詳しくはソースコードを参照してください。 +制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、U-Netの各モジュールに付与されるLoRAにちょっと似た構造を持つモジュールを組み合わせたモデルです。詳しくはソースコードを参照してください。 ## モデルの学習 @@ -19,11 +19,10 @@ conditioning_data_dir = "path/to/conditioning/image/dir" ``` ### 学習 -`sdxl_train_lora_control_net.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じます。 +`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 + ### 推論 -`sdxl_gen_img_lora_ctrl_test.py` を実行してください。`--control_net_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 +`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 `--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 - -その他のオプションは`sdxl_gen_img.py`に準じます。 \ No newline at end of file diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 6e49e449..6e5c4232 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -325,7 +325,7 @@ def train(args): accelerator.print(f"\nsaving checkpoint: {ckpt_file}") sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-llite" + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite" unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata) if args.huggingface_repo_id is not None: From 782b11b8448f3d30a67d8956e131d06082b4e32d Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 19 Aug 2023 21:41:54 +0900 Subject: [PATCH 13/19] Update train_lll_README-ja.md add sample images --- docs/train_lll_README-ja.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/train_lll_README-ja.md b/docs/train_lll_README-ja.md index 1859fc9e..cd1ded68 100644 --- a/docs/train_lll_README-ja.md +++ b/docs/train_lll_README-ja.md @@ -26,3 +26,14 @@ conditioning_data_dir = "path/to/conditioning/image/dir" `sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 `--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 + +### サンプル +Canny +![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/7e883352-0fea-4f5a-b820-94e17ec3f3f2) + +![im_20230819212806_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c28196f9-b2c3-40ad-b000-21a77e657968) + +![im_20230819212815_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b8506354-feb8-4d58-86a8-738a9ba03911) + +![im_20230819212822_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/1612c221-8df5-420c-b907-75758d89aca7) + From 0646112010acaea0d2f3f3235472e05f013e9b18 Mon Sep 17 00:00:00 2001 From: ykume Date: Sun, 20 Aug 2023 00:09:09 +0900 Subject: [PATCH 14/19] fix a bug x is updated inplace --- networks/control_net_lllite.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index e0157080..36e36071 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -153,7 +153,7 @@ class LLLiteModule(torch.nn.Module): # down reduces the number of input dimensions and combines it with conditioning image embedding # we expect that it will mix well by combining in the channel direction instead of adding - cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2) + cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2) cx = self.mid(cx) if self.dropout is not None and self.training: @@ -161,18 +161,11 @@ class LLLiteModule(torch.nn.Module): cx = self.up(cx) - # residualを加算する / add residual + # residua (x) lを加算して元のforwardを呼び出す / add residual (x) and call the original forward if self.batch_cond_only: - x[1::2] += cx - else: - # to_outを対象とすると、cloneがないと次のエラーが出る / if to_out is the target, the following error will occur without clone - # RuntimeError: Output 0 of ReshapeAliasBackward0 is a view and is being modified inplace. - # This view was created inside a custom Function ... - # x = x.clone() + cx = torch.zeros_like(x)[1::2] + cx - x += cx - - x = self.org_forward(x) # ここで元のモジュールを呼び出す / call the original module here + x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here return x From e19189282419cb57379c94aeab91ee2d70594ed8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Aug 2023 12:24:40 +0900 Subject: [PATCH 15/19] fix bucketing doesn't work in controlnet training --- library/train_util.py | 8 ++++++-- networks/control_net_lllite.py | 3 ++- sdxl_train_control_net_lllite.py | 2 ++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 631f1cb7..ff1b4a33 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1770,12 +1770,16 @@ class ControlNetDataset(BaseDataset): cond_img = load_image(image_info.cond_img_path) if self.dreambooth_dataset_delegate.enable_bucket: - cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ assert ( cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1] ), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}" - ct, cl = crop_top_left + cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + + # TODO support random crop + # 現在サポートしているcropはrandomではなく中央のみ h, w = target_size_hw + ct = (cond_img.shape[0] - h) // 2 + cl = (cond_img.shape[1] - w) // 2 cond_img = cond_img[ct : ct + h, cl : cl + w] else: # assert ( diff --git a/networks/control_net_lllite.py b/networks/control_net_lllite.py index 36e36071..3140919c 100644 --- a/networks/control_net_lllite.py +++ b/networks/control_net_lllite.py @@ -120,6 +120,7 @@ class LLLiteModule(torch.nn.Module): / call the model inside, so if necessary, surround it with torch.no_grad() """ # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance + # print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}") cx = self.conditioning1(cond_image) if not self.is_conv2d: # reshape / b,c,h,w -> b,h*w,c @@ -146,7 +147,7 @@ class LLLiteModule(torch.nn.Module): cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero - # print(f"C {self.lllite_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + # print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}") # downで入力の次元数を削減し、conditioning image embeddingと結合する # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 6e5c4232..09cf1643 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -113,6 +113,8 @@ def train(args): assert ( train_dataset_group.is_latent_cacheable() ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません" + else: + print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません") if args.cache_text_encoder_outputs: assert ( From bee5c3f1b82128fbb0cb3ab03777b5f8be14f092 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Aug 2023 12:45:56 +0900 Subject: [PATCH 16/19] update lllite doc --- docs/train_lll_README-ja.md | 39 ---------------------------- docs/train_lllite_README-ja.md | 40 +++++++++++++++++++++++++++++ docs/train_lllite_README.md | 46 ++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 39 deletions(-) delete mode 100644 docs/train_lll_README-ja.md create mode 100644 docs/train_lllite_README-ja.md create mode 100644 docs/train_lllite_README.md diff --git a/docs/train_lll_README-ja.md b/docs/train_lll_README-ja.md deleted file mode 100644 index cd1ded68..00000000 --- a/docs/train_lll_README-ja.md +++ /dev/null @@ -1,39 +0,0 @@ -# ConrtolNet-LLLite について - -## 概要 -ConrtolNet-LLLite は、[ConrtolNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAに似た構造の軽量なControlNetです。現在はSDXLにのみ対応しています。 - -## モデル構造 -制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、U-Netの各モジュールに付与されるLoRAにちょっと似た構造を持つモジュールを組み合わせたモデルです。詳しくはソースコードを参照してください。 - -## モデルの学習 - -### データセットの準備 -通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。 - -```toml -[[datasets.subsets]] -image_dir = "path/to/image/dir" -caption_extension = ".txt" -conditioning_data_dir = "path/to/conditioning/image/dir" -``` - -### 学習 -`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 - - -### 推論 -`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 - -`--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 - -### サンプル -Canny -![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/7e883352-0fea-4f5a-b820-94e17ec3f3f2) - -![im_20230819212806_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c28196f9-b2c3-40ad-b000-21a77e657968) - -![im_20230819212815_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b8506354-feb8-4d58-86a8-738a9ba03911) - -![im_20230819212822_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/1612c221-8df5-420c-b907-75758d89aca7) - diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md new file mode 100644 index 00000000..9df4284e --- /dev/null +++ b/docs/train_lllite_README-ja.md @@ -0,0 +1,40 @@ +# ConrtolNet-LLLite について + +## 概要 +ConrtolNet-LLLite は、[ConrtolNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。 + +## モデル構造 +ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。 + +推論環境の制限で、現在はCrossAttentionのみ(attn1のq/k/v、attn2のq)に追加されます。 + +## モデルの学習 + +### データセットの準備 +通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。 + +```toml +[[datasets.subsets]] +image_dir = "path/to/image/dir" +caption_extension = ".txt" +conditioning_data_dir = "path/to/conditioning/image/dir" +``` + +現時点の制約として、random_cropは使用できません。 + +### 学習 +`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 + +conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。 + +(サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。) + +### 推論 +ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + +スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 + +`--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください(背景黒に白線)。`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。 + +## サンプル +Canny diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md new file mode 100644 index 00000000..ab8fbd62 --- /dev/null +++ b/docs/train_lllite_README.md @@ -0,0 +1,46 @@ +# About ConrtolNet-LLLite + +## Overview + +ConrtolNet-LLLite is a lightweight version of [ConrtolNet](https://github.com/lllyasviel/ControlNet). It is a "LoRA Like Lite" that is inspired by LoRA and has a lightweight structure. Currently, only SDXL is supported. + +## Model structure + +A single LLLite module consists of a conditioning image embedding that maps a conditioning image to a latent space and a small network with a structure similar to LoRA. The LLLite module is added to U-Net's Linear and Conv in the same way as LoRA. Please refer to the source code for details. + +Due to the limitations of the inference environment, only CrossAttention (attn1 q/k/v, attn2 q) is currently added. + +## Model training + +### Preparing the dataset + +In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. + +```toml +[[datasets.subsets]] +image_dir = "path/to/image/dir" +caption_extension = ".txt" +conditioning_data_dir = "path/to/conditioning/image/dir" +``` + +At the moment, random_crop cannot be used. + +### Training + +Run `sdxl_train_control_net_lllite.py`. You can specify the dimension of the conditioning image embedding with `--cond_emb_dim`. You can specify the rank of the LoRA-like module with `--network_dim`. Other options are the same as `sdxl_train_network.py`, but `--network_module` is not required. + +For the sample Canny, the dimension of the conditioning image embedding is 32. The rank of the LoRA-like module is also 64. Adjust according to the features of the conditioning image you are targeting. + +(The sample Canny is probably quite difficult. It may be better to reduce it to about half for depth, etc.) + +### Inference + +A custom node for ComfyUI is available: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + +If you want to generate images with a script, run `sdxl_gen_img.py`. You can specify the LLLite model file with `--control_net_lllite_models`. The dimension is automatically obtained from the model file. + +Specify the conditioning image to be used for inference with `--guide_image_path`. Since preprocess is not performed, if it is Canny, specify an image processed with Canny (white line on black background). `--control_net_preps`, `--control_net_weights`, and `--control_net_ratios` are not supported. + +## Sample + +Canny From b74dfba2159812304b377df292fb8cb917eddaeb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Aug 2023 12:50:37 +0900 Subject: [PATCH 17/19] update lllite doc --- docs/train_lllite_README-ja.md | 10 ++++++++-- docs/train_lllite_README.md | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md index 9df4284e..1b39a537 100644 --- a/docs/train_lllite_README-ja.md +++ b/docs/train_lllite_README-ja.md @@ -3,6 +3,12 @@ ## 概要 ConrtolNet-LLLite は、[ConrtolNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。 +## サンプルの重みファイルと推論 + +こちらにあります: https://huggingface.co/kohya-ss/controlnet-lllite + +ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + ## モデル構造 ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。 @@ -23,14 +29,14 @@ conditioning_data_dir = "path/to/conditioning/image/dir" 現時点の制約として、random_cropは使用できません。 ### 学習 -`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 +スクリプトで生成する場合は、`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。 conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。 (サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。) ### 推論 -ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。 diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index ab8fbd62..8ef41809 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -4,6 +4,14 @@ ConrtolNet-LLLite is a lightweight version of [ConrtolNet](https://github.com/lllyasviel/ControlNet). It is a "LoRA Like Lite" that is inspired by LoRA and has a lightweight structure. Currently, only SDXL is supported. +## Sample weight file and inference + +Sample weight file is available here: https://huggingface.co/kohya-ss/controlnet-lllite + +A custom node for ComfyUI is available: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI + +Sample images are at the end of this page. + ## Model structure A single LLLite module consists of a conditioning image embedding that maps a conditioning image to a latent space and a small network with a structure similar to LoRA. The LLLite module is added to U-Net's Linear and Conv in the same way as LoRA. Please refer to the source code for details. @@ -35,8 +43,6 @@ For the sample Canny, the dimension of the conditioning image embedding is 32. T ### Inference -A custom node for ComfyUI is available: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI - If you want to generate images with a script, run `sdxl_gen_img.py`. You can specify the LLLite model file with `--control_net_lllite_models`. The dimension is automatically obtained from the model file. Specify the conditioning image to be used for inference with `--guide_image_path`. Since preprocess is not performed, if it is Canny, specify an image processed with Canny (white line on black background). `--control_net_preps`, `--control_net_weights`, and `--control_net_ratios` are not supported. From 98f8785a4f40ab31889471fce21bd7ff0dd3bf78 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 20 Aug 2023 12:55:24 +0900 Subject: [PATCH 18/19] Update train_lllite_README.md --- docs/train_lllite_README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/train_lllite_README.md b/docs/train_lllite_README.md index ab8fbd62..6d090a2b 100644 --- a/docs/train_lllite_README.md +++ b/docs/train_lllite_README.md @@ -44,3 +44,10 @@ Specify the conditioning image to be used for inference with `--guide_image_path ## Sample Canny +![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) + +![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) + +![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) + +![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1) From 15b463d18d5021f339c48593cc2d6e70398ddf8a Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Aug 2023 12:56:44 +0900 Subject: [PATCH 19/19] update lllite doc --- docs/train_lllite_README-ja.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/train_lllite_README-ja.md b/docs/train_lllite_README-ja.md index 1b39a537..45491b73 100644 --- a/docs/train_lllite_README-ja.md +++ b/docs/train_lllite_README-ja.md @@ -9,6 +9,8 @@ ConrtolNet-LLLite は、[ConrtolNet](https://github.com/lllyasviel/ControlNet) ComfyUIのカスタムノードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI +生成サンプルはこのページの末尾にあります。 + ## モデル構造 ひとつのLLLiteモジュールは、制御用画像(以下conditioning image)を潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。 @@ -44,3 +46,10 @@ conditioning image embeddingの次元数は、サンプルのCannyでは32を指 ## サンプル Canny +![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5) + +![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0) + +![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb) + +![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1)