From 1e52fe6e09ea1aec95716b9aaa2e6837eaa08213 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 17 Aug 2023 20:49:39 +0900 Subject: [PATCH] add comments --- networks/lora_control_net.py | 95 ++++++++++-- sdxl_train_lora_control_net.py | 261 +-------------------------------- 2 files changed, 85 insertions(+), 271 deletions(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 11b4db90..0dd2a0a1 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -5,12 +5,25 @@ from networks.lora import LoRAModule, LoRANetwork from library import sdxl_original_unet +# input_blocksに適用するかどうか / if True, input_blocks are not applied SKIP_INPUT_BLOCKS = False + +# output_blocksに適用するかどうか / if True, output_blocks are not applied SKIP_OUTPUT_BLOCKS = True + +# conv2dに適用するかどうか / if True, conv2d are not applied SKIP_CONV2D = False -TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored -ATTN1_ETC_ONLY = False # True -TRANSFORMER_MAX_BLOCK_INDEX = None # 3 # None # 2 # None for all blocks + +# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない +# if True, only transformer_blocks are applied, and ResBlocks are not applied +TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks + +# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 +ATTN1_ETC_ONLY = False # True + +# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 +# max index of transformer_blocks. if None, apply to all transformer_blocks +TRANSFORMER_MAX_BLOCK_INDEX = None class LoRAModuleControlNet(LoRAModule): @@ -19,6 +32,16 @@ class LoRAModuleControlNet(LoRAModule): self.is_conv2d = org_module.__class__.__name__ == "Conv2d" self.cond_emb_dim = cond_emb_dim + # conditioning1は、conditioning image embeddingを、各LoRA的モジュールでさらに学習する。ここはtimestepごとに呼ばれない + # それぞれのモジュールで異なる表現を学習することを期待している + # conditioning1 learns conditioning image embedding in each LoRA-like module. 
this is not called for each timestep + # we expect to learn different representations in each module + + # conditioning2は、conditioning1の出力とLoRAの出力を結合し、LoRAの出力に加算する。timestepごとに呼ばれる + # conditioning image embeddingとU-Netの出力を合わせて学ぶことで、conditioningに応じたU-Netの調整を行う + # conditioning2 combines the output of conditioning1 and the output of LoRA, and adds it to the output of LoRA. this is called for each timestep + # by learning the output of conditioning image embedding and U-Net together, we adjust U-Net according to conditioning + if self.is_conv2d: self.conditioning1 = torch.nn.Sequential( torch.nn.Conv2d(cond_emb_dim, cond_emb_dim, kernel_size=3, stride=1, padding=0), @@ -45,16 +68,26 @@ class LoRAModuleControlNet(LoRAModule): torch.nn.Linear(cond_emb_dim, lora_dim), torch.nn.ReLU(inplace=True), ) + + # Zero-Convにするならコメントを外す / uncomment if you want to use Zero-Conv # torch.nn.init.zeros_(self.conditioning2[-2].weight) # zero conv - self.depth = depth + self.depth = depth # 1~3 self.cond_emb = None - self.batch_cond_only = False - self.use_zeros_for_batch_uncond = False + self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference + self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0 def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ + # conv2dとlinearでshapeが違うので必要な方を選択 / select the required one because the shape is different for conv2d and linear cond_embs = cond_embs_4d if self.is_conv2d else cond_embs_3d + cond_emb = cond_embs[self.depth - 1] + + # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance self.cond_emb = self.conditioning1(cond_emb) def set_batch_cond_only(self, cond_only, zeros): @@ -65,32 +98,39 @@ class LoRAModuleControlNet(LoRAModule): if self.cond_emb is None: return 
self.org_forward(x) - # LoRA + # LoRA-Down lx = x if self.batch_cond_only: - lx = lx[1::2] # cond only + lx = lx[1::2] # cond only in inference lx = self.lora_down(lx) if self.dropout is not None and self.training: lx = torch.nn.functional.dropout(lx, p=self.dropout) - # conditioning image + # conditioning image embeddingを結合 / combine conditioning image embedding cx = self.cond_emb + if not self.batch_cond_only and lx.shape[0] // 2 == cx.shape[0]: # inference only cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1) if self.use_zeros_for_batch_uncond: cx[0::2] = 0.0 # uncond is zero # print(f"C {self.lora_name}, lx.shape={lx.shape}, cx.shape={cx.shape}") + # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している + # we expect that it will mix well by combining in the channel direction instead of adding cx = torch.cat([cx, lx], dim=1 if self.is_conv2d else 2) cx = self.conditioning2(cx) - lx = lx + cx + lx = lx + cx # lxはresidual的に加算される / lx is added residually + + # LoRA-Up lx = self.lora_up(lx) + # call original module x = self.org_forward(x) + # add LoRA if self.batch_cond_only: x[1::2] += lx * self.multiplier * self.scale else: @@ -127,6 +167,7 @@ class LoRAControlNet(torch.nn.Module): is_conv2d = child_module.__class__.__name__ == "Conv2d" if is_linear or (is_conv2d and not SKIP_CONV2D): + # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う # block index to depth: depth is using to calculate conditioning size and channels block_name, index1, index2 = (name + "." 
+ child_name).split(".")[:3] index1 = int(index1) @@ -155,7 +196,10 @@ class LoRAControlNet(torch.nn.Module): if tf_index > TRANSFORMER_MAX_BLOCK_INDEX: continue - # skip time emb or clip emb + # time embは適用外とする + # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない + # time emb is not applied + # attn2 conditioning (input from CLIP) cannot be applied because the shape is different if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): continue @@ -191,8 +235,22 @@ class LoRAControlNet(torch.nn.Module): print(f"create ControlNet LoRA for U-Net: {len(self.unet_loras)} modules.") # conditioning image embedding + + # control画像そのままではLoRA的モジュールの入力にはサイズもチャネルも扱いにくいので、 + # 適切な潜在空間に変換する。ここでは、conditioning image embeddingと呼ぶ + # ただcontrol画像自体にはあまり情報量はないので、conditioning image embeddingはわりと小さくてよいはず + # また、conditioning image embeddingは、各LoRA的モジュールでさらに個別に学習する + # depthに応じて3つのサイズを用意する + + # the control image is converted to an appropriate latent space + # because its size and channels are difficult to handle as input to the LoRA-like modules + # we call the result the conditioning image embedding + # however, the control image itself does not have much information, so the conditioning image embedding should be small + # conditioning image embedding is also learned individually in each LoRA-like module + # prepare three sizes according to depth + self.cond_block0 = torch.nn.Sequential( - torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent size + torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0), # to latent (from VAE) size torch.nn.ReLU(inplace=True), torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=3, stride=2, padding=1), torch.nn.ReLU(inplace=True), @@ -216,7 +274,7 @@ class LoRAControlNet(torch.nn.Module): x = self.cond_block2(x) x2 = x - x_3d = [] + x_3d = [] # for Linear for x0 in [x0, x1, x2]: # b,c,h,w -> b,h*w,c n, c, h, w = x0.shape @@ -226,6 +284,10 @@ 
class LoRAControlNet(torch.nn.Module): return [x0, x1, x2], x_3d def set_cond_embs(self, cond_embs_4d, cond_embs_3d): + r""" + 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む + / call the model inside, so if necessary, surround it with torch.no_grad() + """ for lora in self.unet_loras: lora.set_cond_embs(cond_embs_4d, cond_embs_3d) @@ -295,6 +357,9 @@ class LoRAControlNet(torch.nn.Module): if __name__ == "__main__": + # デバッグ用 / for debug + + # これを指定しないとエラーが出てcond_blockが学習できない / if not specified, an error occurs and cond_block cannot be learned sdxl_original_unet.USE_REENTRANT = False # test shape etc @@ -303,7 +368,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 128, 64, 1) + control_net = LoRAControlNet(unet, 64, 16, 1) control_net.apply_to() control_net.to("cuda") @@ -329,7 +394,7 @@ if __name__ == "__main__": # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters())) # print("render") # image.format = "svg" # "png" - # image.render("NeuralNet") + # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time # input() import bitsandbytes diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_lora_control_net.py index 92469add..e0ec3a6a 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_lora_control_net.py @@ -401,8 +401,13 @@ def train(args): controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) with accelerator.autocast(): + # conditioning image embeddingを計算する / calculate conditioning image embedding cond_embs_4d, cond_embs_3d = network(controlnet_image) + + # 個別のLoRA的モジュールでさらにembeddingを計算する / calculate embedding in each LoRA-like module network.set_cond_embs(cond_embs_4d, cond_embs_3d) + + # それらの値を使いつつ、U-Netでノイズを予測する / predict noise with U-Net using those values noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) if args.v_parameterization: @@ -514,262 +519,6 @@ def train(args): 
print("model saved.") - r""" - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc="steps", - ) - global_step = 0 - - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - clip_sample=False, - ) - if accelerator.is_main_process: - init_kwargs = {} - if args.log_tracker_config is not None: - init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs - ) - - loss_list = [] - loss_total = 0.0 - del train_dataset_group - - # function for saving/removing - def save_model(ckpt_name, model, force_sync_upload=False): - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, ckpt_name) - - accelerator.print(f"\nsaving checkpoint: {ckpt_file}") - - state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict()) - - if save_dtype is not None: - for key in list(state_dict.keys()): - v = state_dict[key] - v = v.detach().clone().to("cpu").to(save_dtype) - state_dict[key] = v - - if os.path.splitext(ckpt_file)[1] == ".safetensors": - from safetensors.torch import save_file - - save_file(state_dict, ckpt_file) - else: - torch.save(state_dict, ckpt_file) - - if args.huggingface_repo_id is not None: - huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) - - def remove_model(old_ckpt_name): - old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) - if os.path.exists(old_ckpt_file): - accelerator.print(f"removing old checkpoint: {old_ckpt_file}") - os.remove(old_ckpt_file) - - # training loop - for epoch in range(num_train_epochs): - if is_main_process: - accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch + 1 - - for step, batch in enumerate(train_dataloader): - current_step.value 
= global_step - with accelerator.accumulate(controlnet): - with torch.no_grad(): - if "latents" in batch and batch["latents"] is not None: - latents = batch["latents"].to(accelerator.device) - else: - # latentに変換 - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) - if args.noise_offset: - noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) - elif args.multires_noise_iterations: - noise = pyramid_noise_like( - noise, - latents.device, - args.multires_noise_iterations, - args.multires_noise_discount, - ) - - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (b_size,), - device=latents.device, - ) - timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype) - - with accelerator.autocast(): - down_block_res_samples, mid_block_res_sample = controlnet( - noisy_latents, - timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=controlnet_image, - return_dict=False, - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, - timesteps, - encoder_hidden_states, - down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], - mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), - ).sample - - if args.v_parameterization: - # v-parameterization training - target = 
noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - loss = loss.mean([1, 2, 3]) - - loss_weights = batch["loss_weights"] # 各sampleごとのweight - loss = loss * loss_weights - - if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - - accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = controlnet.parameters() - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - train_util.sample_images( - accelerator, - args, - None, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: - accelerator.wait_for_everyone() - if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) - save_model( - ckpt_name, - accelerator.unwrap_model(controlnet), - ) - - if args.save_state: - train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - - remove_step_no = train_util.get_remove_step_no(args, global_step) - if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name(args, "." 
+ args.save_model_as, remove_step_no) - remove_model(remove_ckpt_name) - - current_loss = loss.detach().item() - if epoch == 0: - loss_list.append(current_loss) - else: - loss_total -= loss_list[step] - loss_list[step] = current_loss - loss_total += current_loss - avr_loss = loss_total / len(loss_list) - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if args.logging_dir is not None: - logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler) - accelerator.log(logs, step=global_step) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {"loss/epoch": loss_total / len(loss_list)} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - # 指定エポックごとにモデルを保存 - if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs - if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) - save_model(ckpt_name, accelerator.unwrap_model(controlnet)) - - remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) - if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) - remove_model(remove_ckpt_name) - - if args.save_state: - train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) - - train_util.sample_images( - accelerator, - args, - epoch + 1, - global_step, - accelerator.device, - vae, - tokenizer, - text_encoder, - unet, - controlnet=controlnet, - ) - - # end of epoch - if is_main_process: - controlnet = accelerator.unwrap_model(controlnet) - - accelerator.end_training() - - if is_main_process and args.save_state: - train_util.save_state_on_train_end(args, accelerator) - - # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく - - if is_main_process: - ckpt_name = train_util.get_last_ckpt_name(args, "." 
+ args.save_model_as) - save_model(ckpt_name, controlnet, force_sync_upload=True) - - print("model saved.") - """ - def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser()