mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Debias Estimation loss (#889)
* update for bnb 0.41.1 * fixed generate_controlnet_subsets_config for training * Revert "update for bnb 0.41.1" This reverts commit 70bd3612d8. * add debiased_estimation_loss * add train_network * Revert "add train_network" This reverts commit 6539363c5c. * Update train_network.py
This commit is contained in:
@@ -32,6 +32,7 @@ from library.custom_train_functions import (
|
|||||||
get_weighted_text_embeddings,
|
get_weighted_text_embeddings,
|
||||||
prepare_scheduler_for_custom_training,
|
prepare_scheduler_for_custom_training,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -339,7 +340,7 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
|
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
||||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
@@ -348,6 +349,8 @@ def train(args):
|
|||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -86,6 +86,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
|
|||||||
loss = loss + loss / scale * v_pred_like_loss
|
loss = loss + loss / scale * v_pred_like_loss
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
    """Apply debiased-estimation weighting (1/sqrt(SNR_t)) to a per-sample loss.

    Implements the debiased estimation loss: each sample's loss is scaled by
    the inverse square root of its timestep's signal-to-noise ratio.

    Args:
        loss: per-sample loss tensor, shape (batch_size,).
        timesteps: integer timestep indices, one per sample, shape (batch_size,).
        noise_scheduler: scheduler exposing ``all_snr``, a tensor of per-timestep
            SNR values (indexed by timestep).

    Returns:
        Weighted loss tensor with the same shape as ``loss``.
    """
    # Single advanced-index gather instead of a Python-level
    # torch.stack([all_snr[t] for t in timesteps]) loop: identical result,
    # runs at C speed, and preserves device/dtype.
    snr_t = noise_scheduler.all_snr[timesteps]  # (batch_size,)
    # At timestep 0 the SNR is infinite; clamp so the weight stays finite.
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
    weight = 1 / torch.sqrt(snr_t)
    return weight * loss
|
||||||
|
|
||||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||||
|
|
||||||
@@ -108,6 +114,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
|||||||
default=None,
|
default=None,
|
||||||
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
|
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debiased_estimation_loss",
|
||||||
|
action="store_true",
|
||||||
|
help="debiased estimation loss / debiased estimation loss",
|
||||||
|
)
|
||||||
if support_weighted_captions:
|
if support_weighted_captions:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--weighted_captions",
|
"--weighted_captions",
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from library.custom_train_functions import (
|
|||||||
prepare_scheduler_for_custom_training,
|
prepare_scheduler_for_custom_training,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
add_v_prediction_like_loss,
|
add_v_prediction_like_loss,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||||
|
|
||||||
@@ -548,7 +549,7 @@ def train(args):
|
|||||||
|
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
|
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss:
|
||||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
@@ -559,6 +560,8 @@ def train(args):
|
|||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # mean over batch dimension
|
loss = loss.mean() # mean over batch dimension
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from library.custom_train_functions import (
|
|||||||
pyramid_noise_like,
|
pyramid_noise_like,
|
||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
||||||
|
|
||||||
@@ -465,6 +466,8 @@ def train(args):
|
|||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ from library.custom_train_functions import (
|
|||||||
pyramid_noise_like,
|
pyramid_noise_like,
|
||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
import networks.control_net_lllite as control_net_lllite
|
import networks.control_net_lllite as control_net_lllite
|
||||||
|
|
||||||
@@ -435,6 +436,8 @@ def train(args):
|
|||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from library.custom_train_functions import (
|
|||||||
pyramid_noise_like,
|
pyramid_noise_like,
|
||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# perlin_noise,
|
# perlin_noise,
|
||||||
@@ -336,6 +337,8 @@ def train(args):
|
|||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from library.custom_train_functions import (
|
|||||||
prepare_scheduler_for_custom_training,
|
prepare_scheduler_for_custom_training,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
add_v_prediction_like_loss,
|
add_v_prediction_like_loss,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -528,6 +529,7 @@ class NetworkTrainer:
|
|||||||
"ss_min_snr_gamma": args.min_snr_gamma,
|
"ss_min_snr_gamma": args.min_snr_gamma,
|
||||||
"ss_scale_weight_norms": args.scale_weight_norms,
|
"ss_scale_weight_norms": args.scale_weight_norms,
|
||||||
"ss_ip_noise_gamma": args.ip_noise_gamma,
|
"ss_ip_noise_gamma": args.ip_noise_gamma,
|
||||||
|
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
|
||||||
}
|
}
|
||||||
|
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
@@ -811,6 +813,8 @@ class NetworkTrainer:
|
|||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from library.custom_train_functions import (
|
|||||||
prepare_scheduler_for_custom_training,
|
prepare_scheduler_for_custom_training,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
add_v_prediction_like_loss,
|
add_v_prediction_like_loss,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
@@ -582,6 +583,8 @@ class TextualInversionTrainer:
|
|||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
if args.v_pred_like_loss:
|
if args.v_pred_like_loss:
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from library.custom_train_functions import (
|
|||||||
pyramid_noise_like,
|
pyramid_noise_like,
|
||||||
apply_noise_offset,
|
apply_noise_offset,
|
||||||
scale_v_prediction_loss_like_noise_prediction,
|
scale_v_prediction_loss_like_noise_prediction,
|
||||||
|
apply_debiased_estimation,
|
||||||
)
|
)
|
||||||
import library.original_unet as original_unet
|
import library.original_unet as original_unet
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
@@ -471,6 +472,8 @@ def train(args):
|
|||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user