From 11e8c7d8ffa322ea8af50a4822afcbb4b094cccb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 Jun 2023 09:35:33 +0900 Subject: [PATCH] fix to work controlnet training --- library/original_unet.py | 13 +++ train_controlnet.py | 215 +++++++++++++++++++-------------------- 2 files changed, 118 insertions(+), 110 deletions(-) diff --git a/library/original_unet.py b/library/original_unet.py index 94d11290..c0028ddc 100644 --- a/library/original_unet.py +++ b/library/original_unet.py @@ -1468,6 +1468,8 @@ class UNet2DConditionModel(nn.Module): encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, return_dict: bool = True, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, ) -> Union[Dict, Tuple]: r""" Args: @@ -1533,9 +1535,20 @@ class UNet2DConditionModel(nn.Module): down_block_res_samples += res_samples + # skip connectionにControlNetの出力を追加する + if down_block_additional_residuals is not None: + down_block_res_samples = list(down_block_res_samples) + for i in range(len(down_block_res_samples)): + down_block_res_samples[i] += down_block_additional_residuals[i] + down_block_res_samples = tuple(down_block_res_samples) + # 4. mid sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) + # ControlNetの出力を追加する + if mid_block_additional_residual is not None: + sample += mid_block_additional_residual + # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 diff --git a/train_controlnet.py b/train_controlnet.py index 6e4e5bb8..39ac43e9 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -6,6 +6,7 @@ import os import random import time from multiprocessing import Value +from types import SimpleNamespace from tqdm import tqdm import torch @@ -39,17 +40,14 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche } if args.optimizer_type.lower().startswith("DAdapt".lower()): - logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] - * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) + logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] return logs def train(args): - session_id = random.randint(0, 2**32) - training_started_at = time.time() + # session_id = random.randint(0, 2**32) + # training_started_at = time.time() train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -88,15 +86,11 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - train_dataset_group = config_util.generate_dataset_group_by_blueprint( - blueprint.dataset_group - ) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value("i", 0) current_step = Value("i", 0) - ds_for_collater = ( - train_dataset_group if args.max_data_loader_n_workers == 0 else None - ) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) if args.debug_dataset: @@ -115,7 +109,7 @@ def train(args): # acceleratorを準備する print("prepare accelerator") - accelerator, unwrap_model = train_util.prepare_accelerator(args) + accelerator = train_util.prepare_accelerator(args) is_main_process = accelerator.is_main_process # mixed precisionに対応した型を用意しておき適宜castする @@ -126,6 +120,69 @@ def train(args): args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True ) + # DiffusersのControlNetが使用するデータを準備する + if args.v2: + unet.config = { + "act_fn": "silu", + "attention_head_dim": [5, 10, 20, 20], + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 1024, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "dual_cross_attention": False, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 4, + "sample_size": 96, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "use_linear_projection": True, + "upcast_attention": True, + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": True, + "class_embed_type": None, + "num_class_embeds": None, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + else: + unet.config = { + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [320, 640, 1280, 1280], + "center_input_sample": False, + "cross_attention_dim": 768, + "down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"], + "downsample_padding": 1, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "out_channels": 4, + "sample_size": 64, + "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + "only_cross_attention": False, + "downsample_padding": 1, + "use_linear_projection": False, + "class_embed_type": None, + "num_class_embeds": None, + "upcast_attention": False, + "resnet_time_scale_shift": "default", + "projection_class_embeddings_input_dim": None, + } + unet.config = SimpleNamespace(**unet.config) + controlnet = ControlNetModel.from_unet(unet) if args.controlnet_model_name_or_path: @@ -140,9 +197,8 @@ def train(args): elif os.path.isdir(filename): controlnet = ControlNetModel.from_pretrained(filename) - # モデルに xformers とか memory efficient attention を組み込む - train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) # 学習を準備する if cache_latents: @@ -171,15 +227,11 @@ def train(args): trainable_params = controlnet.parameters() - _, _, optimizer = train_util.get_optimizer( - args, trainable_params - ) + _, _, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min( - args.max_data_loader_n_workers, os.cpu_count() - 1 - ) # cpu_count-1 ただし最大で指定された数まで + n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで train_dataloader = torch.utils.data.DataLoader( train_dataset_group, @@ -193,21 +245,15 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil( - len(train_dataloader) - / accelerator.num_processes - / args.gradient_accumulation_steps - ) - accelerator.print( - f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}" + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps ) + accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix( - args, optimizer, accelerator.num_processes - ) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: @@ -245,31 +291,21 @@ def train(args): train_util.resume_from_local_or_hf_if_specified(accelerator, args) # epoch数を計算する - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0): - args.save_every_n_epochs = ( - math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 - ) + args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1 # 学習する # TODO: find a way to handle total batch size when there are multiple datasets accelerator.print("running training / 学習開始") - accelerator.print( - f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}" - ) + accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") - accelerator.print( - f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}" - ) + accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm( @@ -288,11 +324,7 @@ def train(args): clip_sample=False, ) if accelerator.is_main_process: - accelerator.init_trackers( - "controlnet_train" - if args.log_tracker_name is None - else args.log_tracker_name - ) + accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name) loss_list = [] loss_total = 0.0 @@ -321,9 +353,7 @@ def train(args): torch.save(state_dict, ckpt_file) if args.huggingface_repo_id is not None: - huggingface_util.upload( - args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload - ) + huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload) def remove_model(old_ckpt_name): old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) @@ -345,23 +375,17 @@ def train(args): latents = batch["latents"].to(accelerator.device) else: # latentに変換 - latents = vae.encode( - batch["images"].to(dtype=weight_dtype) - ).latent_dist.sample() + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 b_size = latents.shape[0] input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, weight_dtype - ) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - noise = apply_noise_offset( - latents, noise, args.noise_offset, args.adaptive_noise_scale - ) + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like( noise, @@ -398,13 +422,8 @@ def train(args): noisy_latents, timesteps, encoder_hidden_states, - down_block_additional_residuals=[ - sample.to(dtype=weight_dtype) - for sample in down_block_res_samples - ], - mid_block_additional_residual=mid_block_res_sample.to( - dtype=weight_dtype - ), + down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample if args.v_parameterization: @@ -413,18 +432,14 @@ def train(args): else: target = noise - loss = torch.nn.functional.mse_loss( - noise_pred.float(), target.float(), reduction="none" - ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight( - loss, timesteps, noise_scheduler, args.min_snr_gamma - ) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし @@ -456,31 +471,21 @@ def train(args): ) # 指定ステップごとにモデルを保存 - if ( - args.save_every_n_steps is not None - and global_step % args.save_every_n_steps == 0 - ): + if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: - ckpt_name = train_util.get_step_ckpt_name( - args, "." + args.save_model_as, global_step - ) + ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) save_model( - ckpt_name, unwrap_model(controlnet), + ckpt_name, + accelerator.unwrap_model(controlnet), ) if args.save_state: - train_util.save_and_remove_state_stepwise( - args, accelerator, global_step - ) + train_util.save_and_remove_state_stepwise(args, accelerator, global_step) - remove_step_no = train_util.get_remove_step_no( - args, global_step - ) + remove_step_no = train_util.get_remove_step_no(args, global_step) if remove_step_no is not None: - remove_ckpt_name = train_util.get_step_ckpt_name( - args, "." + args.save_model_as, remove_step_no - ) + remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) current_loss = loss.detach().item() @@ -509,26 +514,18 @@ def train(args): # 指定エポックごとにモデルを保存 if args.save_every_n_epochs is not None: - saving = (epoch + 1) % args.save_every_n_epochs == 0 and ( - epoch + 1 - ) < num_train_epochs + saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs if is_main_process and saving: - ckpt_name = train_util.get_epoch_ckpt_name( - args, "." + args.save_model_as, epoch + 1 - ) - save_model(ckpt_name, unwrap_model(controlnet)) + ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1) + save_model(ckpt_name, accelerator.unwrap_model(controlnet)) remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1) if remove_epoch_no is not None: - remove_ckpt_name = train_util.get_epoch_ckpt_name( - args, "." + args.save_model_as, remove_epoch_no - ) + remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no) remove_model(remove_ckpt_name) if args.save_state: - train_util.save_and_remove_state_on_epoch_end( - args, accelerator, epoch + 1 - ) + train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1) train_util.sample_images( accelerator, @@ -545,20 +542,18 @@ def train(args): # end of epoch if is_main_process: - controlnet = unwrap_model(controlnet) + controlnet = accelerator.unwrap_model(controlnet) accelerator.end_training() if is_main_process and args.save_state: train_util.save_state_on_train_end(args, accelerator) - del accelerator # この後メモリを使うのでこれは消す + # del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく if is_main_process: ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as) - save_model( - ckpt_name, controlnet, force_sync_upload=True - ) + save_model(ckpt_name, controlnet, force_sync_upload=True) print("model saved.")