scale v-pred loss like noise pred

This commit is contained in:
Kohya S
2023-06-03 10:52:22 +09:00
parent 0f0158ddaa
commit ec2efe52e4
7 changed files with 79 additions and 17 deletions

View File

@@ -28,9 +28,11 @@ import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
max_norm,
scale_v_prediction_loss_like_noise_prediction,
)
@@ -316,7 +318,7 @@ def train(args):
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents:
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
@@ -554,6 +556,8 @@ def train(args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if accelerator.is_main_process:
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
@@ -658,6 +662,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -840,7 +846,6 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
)
return parser