mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
Added first version of out-of-tolerance latent
std/mean detection code.
This commit is contained in:
@@ -753,6 +753,16 @@ class NetworkTrainer:
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# Warn user if any latents have mean values that are further than a theshold level away
|
||||
# from 0.0, or that have standard deviations outside a threshold scale from 1.0.
|
||||
if args.latent_threshold_warn_levels is not None:
|
||||
# (Flux only for now, but this could be updated to support e.g. SDXL or SD3)
|
||||
from library.flux_train_utils import check_latent_means_and_stds_against_thresholds
|
||||
check_latent_means_and_stds_against_thresholds(
|
||||
args.latent_threshold_warn_levels,
|
||||
args.latent_threshold_visualizer,
|
||||
train_dataset_group.image_data)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
|
||||
Reference in New Issue
Block a user