diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 5b6106fb..677d1bf4 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -18,6 +18,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): noise_scheduler.all_snr = all_snr.to(device) + def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): # fix beta: zero terminal SNR print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891") @@ -55,6 +56,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): noise_scheduler.alphas = alphas noise_scheduler.alphas_cumprod = alphas_cumprod + def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) @@ -64,11 +66,24 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma): def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler): + scale = get_snr_scale(timesteps, noise_scheduler) + loss = loss * scale + return loss + + +def get_snr_scale(timesteps, noise_scheduler): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) + # # show debug info + # print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}") + return scale - loss = loss * scale + +def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss): + scale = get_snr_scale(timesteps, noise_scheduler) + # print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") + loss = loss + loss / scale * v_pred_like_loss return loss @@ -87,6 +102,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted action="store_true", help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする", ) + parser.add_argument( + "--v_pred_like_loss", + type=float, + default=None, + help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", diff --git a/library/train_util.py b/library/train_util.py index 7f9fd75a..1353173e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2868,6 +2868,11 @@ def verify_training_args(args: argparse.Namespace): raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" ) + + if args.v_pred_like_loss and args.v_parameterization: + raise ValueError( + "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません" + ) if args.zero_terminal_snr and not args.v_parameterization: print( diff --git a/sdxl_train.py b/sdxl_train.py index f5084a42..b57e2f5c 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -24,9 +24,8 @@ import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, ) from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -175,7 +174,7 @@ def train(args): # Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある accelerator.print("Disable Diffusers' xformers") train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(args.xformers) # 学習を準備する @@ -338,9 +337,7 @@ def train(args): accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") - accelerator.print( - f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}" - ) + accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}") # accelerator.print( # f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" # ) @@ -459,13 +456,17 @@ def train(args): target = noise - if args.min_snr_gamma: + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) loss = loss.mean() # mean over batch dimension else: diff --git a/train_network.py b/train_network.py index 6f41d199..e296d72b 100644 --- a/train_network.py +++ b/train_network.py @@ -31,9 +31,8 @@ from library.custom_train_functions import ( apply_snr_weight, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, - pyramid_noise_like, - apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, ) @@ -792,6 +791,8 @@ class NetworkTrainer: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index d3ff6456..300afa3e 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -24,6 +24,7 @@ from library.custom_train_functions import ( apply_snr_weight, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, ) imagenet_templates_small = [ @@ -566,6 +567,8 @@ class TextualInversionTrainer: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし