From ec2efe52e45caca863005505b95f5f4a575e1246 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sat, 3 Jun 2023 10:52:22 +0900
Subject: [PATCH] scale v-pred loss like noise pred

---
 fine_tune.py                      | 21 +++++++++++++++++----
 library/custom_train_functions.py | 26 ++++++++++++++++++++++++--
 library/train_util.py             |  7 ++++++-
 train_db.py                       |  5 +++++
 train_network.py                  |  9 +++++++--
 train_textual_inversion.py        | 17 +++++++++++++----
 train_textual_inversion_XTI.py    | 11 +++++++----
 7 files changed, 79 insertions(+), 17 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 154d3be7..201d4952 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -21,7 +21,14 @@ from library.config_util import (
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
+from library.custom_train_functions import (
+    apply_snr_weight,
+    get_weighted_text_embeddings,
+    prepare_scheduler_for_custom_training,
+    pyramid_noise_like,
+    apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
+)
 
 
 def train(args):
@@ -261,6 +268,7 @@ def train(args):
     noise_scheduler = DDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
@@ -327,11 +335,16 @@ def train(args):
                 else:
                     target = noise
 
-                if args.min_snr_gamma:
-                    # do not mean over batch dimension for snr weight
+                if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
+                    # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                     loss = loss.mean([1, 2, 3])
+
+                    if args.min_snr_gamma:
+                        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                    if args.scale_v_pred_loss_like_noise_pred:
+                        loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+
                     loss = loss.mean()  # mean over batch dimension
                 else:
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index f32f050e..9d0dc402 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -5,20 +5,37 @@ import re
 from typing import List, Optional, Union
 
 
-def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
+def prepare_scheduler_for_custom_training(noise_scheduler, device):
+    if hasattr(noise_scheduler, "all_snr"):
+        return
+
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
     sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
     alpha = sqrt_alphas_cumprod
     sigma = sqrt_one_minus_alphas_cumprod
     all_snr = (alpha / sigma) ** 2
-    snr = torch.stack([all_snr[t] for t in timesteps])
+
+    noise_scheduler.all_snr = all_snr.to(device)
+
+
+def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
+    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
     snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()  # from paper
     loss = loss * snr_weight
     return loss
 
 
+def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
+    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
+    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
+    scale = snr_t / (snr_t + 1)
+
+    loss = loss * scale
+    return loss
+
+
 # TODO train_utilと分散しているのでどちらかに寄せる
@@ -29,6 +46,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
         default=None,
         help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
     )
+    parser.add_argument(
+        "--scale_v_pred_loss_like_noise_pred",
+        action="store_true",
+        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
+    )
     if support_weighted_captions:
         parser.add_argument(
             "--weighted_captions",
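Why SNR/(SNR+1) is the right factor: under the usual v-parameterization, x_t = alpha_t * x_0 + sigma_t * eps with alpha_t^2 + sigma_t^2 = 1 and v = alpha_t * eps - sigma_t * x_0, so the noise can be recovered as eps = alpha_t * v + sigma_t * x_t. Then eps_pred - eps = alpha_t * (v_pred - v), and the per-sample noise-prediction MSE equals alpha_t^2 = SNR_t / (SNR_t + 1) times the v-prediction MSE. A minimal standalone sanity check of that identity (not part of the patch; the toy schedule, shapes, and seed are placeholders):

import torch

torch.manual_seed(0)
alphas_cumprod = torch.linspace(0.9999, 0.0047, 1000)  # stand-in for a real schedule
t = torch.randint(0, 1000, (4,))
alpha = alphas_cumprod[t].sqrt().unsqueeze(1)
sigma = (1.0 - alphas_cumprod[t]).sqrt().unsqueeze(1)
snr = (alpha / sigma).squeeze(1) ** 2

x0 = torch.randn(4, 8)
eps = torch.randn(4, 8)
x_t = alpha * x0 + sigma * eps
v = alpha * eps - sigma * x0

v_pred = v + 0.1 * torch.randn_like(v)   # imperfect model output
eps_pred = alpha * v_pred + sigma * x_t  # eps recovered from the v prediction

loss_v = ((v_pred - v) ** 2).mean(dim=1)       # per-sample v-pred MSE
loss_eps = ((eps_pred - eps) ** 2).mean(dim=1)  # per-sample noise-pred MSE
assert torch.allclose(loss_v * snr / (snr + 1), loss_eps, atol=1e-5)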
diff --git a/library/train_util.py b/library/train_util.py
index 46c5c3b2..844faca7 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2311,6 +2311,11 @@ def verify_training_args(args: argparse.Namespace):
     if args.adaptive_noise_scale is not None and args.noise_offset is None:
         raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
 
+    if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
+        raise ValueError(
+            "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
+        )
+
 
 def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@@ -3638,4 +3643,4 @@ class collater_class:
         # set epoch and step
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
-        return examples[0]
\ No newline at end of file
+        return examples[0]
diff --git a/train_db.py b/train_db.py
index 7ec06354..c81a092d 100644
--- a/train_db.py
+++ b/train_db.py
@@ -26,8 +26,10 @@ import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
     get_weighted_text_embeddings,
+    prepare_scheduler_for_custom_training,
     pyramid_noise_like,
     apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
 )
 
 # perlin_noise,
@@ -240,6 +242,7 @@ def train(args):
     noise_scheduler = DDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
@@ -327,6 +330,8 @@ def train(args):
 
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                if args.scale_v_pred_loss_like_noise_pred:
+                    loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
diff --git a/train_network.py b/train_network.py
index cd90b0a2..32258e88 100644
--- a/train_network.py
+++ b/train_network.py
@@ -28,9 +28,11 @@ import library.custom_train_functions as custom_train_functions
 from library.custom_train_functions import (
     apply_snr_weight,
     get_weighted_text_embeddings,
+    prepare_scheduler_for_custom_training,
     pyramid_noise_like,
     apply_noise_offset,
     max_norm,
+    scale_v_prediction_loss_like_noise_prediction,
 )
 
 
@@ -316,7 +318,7 @@ def train(args):
 
     network.prepare_grad_etc(text_encoder, unet)
 
-    if not cache_latents:
+    if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
         vae.requires_grad_(False)
         vae.eval()
         vae.to(accelerator.device, dtype=weight_dtype)
@@ -554,6 +556,8 @@ def train(args):
     noise_scheduler = DDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
+
     if accelerator.is_main_process:
         accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
@@ -658,6 +662,8 @@ def train(args):
 
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                if args.scale_v_pred_loss_like_noise_pred:
+                    loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
@@ -840,7 +846,6 @@ def setup_parser() -> argparse.ArgumentParser:
         nargs="*",
         help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
     )
-
     return parser
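As the verify_training_args() check above enforces, the new flag is rejected unless v-prediction is enabled. An illustrative invocation (the model path and all dataset/network/optimizer options are placeholders, elided with "..."):

accelerate launch train_network.py \
  --pretrained_model_name_or_path=/path/to/v2-768-model.safetensors \
  --v2 --v_parameterization \
  --scale_v_pred_loss_like_noise_pred \
  --min_snr_gamma=5 \
  ...

When both options are given, the hunks above apply them in sequence: the per-sample loss is multiplied first by min(SNR_t, gamma) / SNR_t and then by SNR_t / (SNR_t + 1).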
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index b73027de..8be0703d 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -20,7 +20,13 @@ from library.config_util import (
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset
+from library.custom_train_functions import (
+    apply_snr_weight,
+    prepare_scheduler_for_custom_training,
+    pyramid_noise_like,
+    apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
+)
 
 imagenet_templates_small = [
     "a photo of a {}",
@@ -338,6 +344,7 @@ def train(args):
     noise_scheduler = DDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
@@ -412,12 +419,14 @@ def train(args):
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
 
-                if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
-
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
                 loss = loss * loss_weights
 
+                if args.min_snr_gamma:
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                if args.scale_v_pred_loss_like_noise_pred:
+                    loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
                 accelerator.backward(loss)
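The cached noise_scheduler.all_snr table is what keeps these per-step lookups cheap. A minimal sketch of how the two new helpers compose outside a trainer, run from the sd-scripts repository root (the random loss and timesteps stand in for the real per-sample values):

import torch
from diffusers import DDPMScheduler
from library.custom_train_functions import (
    prepare_scheduler_for_custom_training,
    scale_v_prediction_loss_like_noise_prediction,
)

noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, torch.device("cpu"))

timesteps = torch.randint(0, 1000, (4,))
loss = torch.rand(4)  # stands in for mse_loss(...).mean([1, 2, 3])
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
print(loss.mean())  # scalar loss, as in the trainers above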
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 8c8f7e8b..7b734f28 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -20,7 +20,7 @@ from library.config_util import (
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset
+from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
 
 imagenet_templates_small = [
@@ -372,6 +372,7 @@ def train(args):
     noise_scheduler = DDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
@@ -451,11 +452,13 @@ def train(args):
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
 
+                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
+
+                loss = loss * loss_weights
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
-
-                loss_weights = batch["loss_weights"]  # 各sampleごとのweight
-                loss = loss * loss_weights
+                if args.scale_v_pred_loss_like_noise_pred:
+                    loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
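Both textual-inversion trainers now multiply the per-sample loss_weights in before the SNR-based factors, matching the ordering in train_db.py and train_network.py. Since loss_weights, the min-SNR weight, and the v-pred scale are all element-wise per-sample multipliers, the reordering does not change the computed value (up to floating-point rounding); a toy check with stand-in tensors:

import torch

loss = torch.rand(4)      # per-sample MSE
w_sample = torch.rand(4)  # batch["loss_weights"] stand-in
w_snr = torch.rand(4)     # min-SNR weight or v-pred scale stand-in

a = ((loss * w_sample) * w_snr).mean()
b = ((loss * w_snr) * w_sample).mean()
assert torch.allclose(a, b)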