Added first version of out-of-tolerance latent

std/mean detection code.
This commit is contained in:
araleza
2025-03-27 07:45:05 +00:00
parent 8ebe858f89
commit 98f3afe60e
2 changed files with 158 additions and 0 deletions

View File

@@ -753,6 +753,16 @@ class NetworkTrainer:
persistent_workers=args.persistent_data_loader_workers,
)
# Warn user if any latents have mean values that are further than a theshold level away
# from 0.0, or that have standard deviations outside a threshold scale from 1.0.
if args.latent_threshold_warn_levels is not None:
# (Flux only for now, but this could be updated to support e.g. SDXL or SD3)
from library.flux_train_utils import check_latent_means_and_stds_against_thresholds
check_latent_means_and_stds_against_thresholds(
args.latent_threshold_warn_levels,
args.latent_threshold_visualizer,
train_dataset_group.image_data)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(