Added first version of out-of-tolerance latent std/mean detection code.
This commit is contained in:
araleza
2025-03-27 07:45:05 +00:00
parent 8ebe858f89
commit 98f3afe60e
2 changed files with 158 additions and 0 deletions

View File

@@ -550,6 +550,137 @@ def save_flux_model_on_epoch_end_or_stepwise(
# endregion
# region Latent analysis
def check_latent_means_and_stds_against_thresholds(thresholds_string, latent_threshold_visualizer, image_data):
    """Warn about cached latents whose mean or std falls outside the given tolerances.

    Args:
        thresholds_string: Either 'mean,std' (e.g. "0.16,1.35") or 'disable' / 'disabled'
            (case-insensitive) to skip the check. None is also treated as disabled.
            The mean is tested against [-mean, +mean]; the std against [1.0 / std, std].
        latent_threshold_visualizer: If truthy, the worst out-of-tolerance latent is shown
            in an OpenCV window (blocks until the window is closed or Escape is pressed).
        image_data: Mapping from image filename to an image-info object providing
            `latents_npz` (path of the cached latents file) and `bucket_reso`.

    Returns:
        None. Problems are reported through `logger`; malformed thresholds are logged as
        errors and the check is skipped rather than raising.
    """
    # Skip mean/std check? (The CLI help historically said "disabled", so accept both spellings.)
    if thresholds_string is None or thresholds_string.strip().lower() in ("disable", "disabled"):
        return

    thresholds = _parse_latent_thresholds(thresholds_string)
    if thresholds is None:
        return
    mean_thresh, std_thresh_min, std_thresh_max = thresholds

    # Nothing to analyse; also avoids indexing into an empty result list below.
    if not image_data:
        return

    import os  # local import: only needed for basename extraction here

    # One entry per latent: (warn score, mean, std, short filename, full filename)
    results = []
    logger.info('Checking latent means/stds:')
    for image_filename in tqdm(image_data):
        image_info = image_data[image_filename]
        try:
            latent = _load_unflipped_latent(image_info)
        except (OSError, KeyError) as err:
            # This is only a diagnostic pass, so an unreadable/stale cache file is reported
            # and skipped instead of aborting the whole run.
            logger.warning(f"Could not load latent for '{image_filename}' for mean/std check: {err}")
            continue

        mean = float(np.mean(latent))
        std = float(np.std(latent))
        score = _latent_warn_score(mean, std, mean_thresh, std_thresh_min, std_thresh_max)
        # basename handles both '/' and the platform's native separator
        results.append((score, mean, std, os.path.basename(image_filename), image_filename))

    # Keep only the latents that failed a check, most notably out of threshold first
    failures = sorted((r for r in results if r[0] > 0.0), key=lambda r: r[0], reverse=True)
    if not failures:
        return

    logger.warning("Images are being trained on that have out-of-tolerance latent mean or std values. "
                   "Training may improve if these images are changed/deleted. Remember to delete the images' _flux.npz "
                   "files by hand if you modify the image, as they will not necessarily be automatically regenerated. "
                   "Here is a list (of up to three) out-of-tolerance images: (Consider using --latent_threshold_visualizer "
                   "to diagnose)")
    for _, mean, std, short_name, _ in failures[:3]:  # Three results maximum
        logger.warning(f'Mean,std = [{mean:.3f}, {std:.3f}]: {short_name}')

    # Show the worst latent test failure visually in a window?
    if latent_threshold_visualizer:
        worst = failures[0]
        _show_latent_visualization(image_data[worst[4]], worst[3])


def _parse_latent_thresholds(thresholds_string):
    """Parse 'mean,std' into (mean_thresh, std_thresh_min, std_thresh_max); log an error and return None if invalid."""
    parts = thresholds_string.split(',')
    if len(parts) != 2:
        logger.error(f"latent_threshold_warn_levels was set to '{thresholds_string}', "
                     "Expected latent threshold warning string to either be in 'mean,std' format, or 'disable'")
        return None
    try:
        mean_thresh = float(parts[0].strip())  # Magnitude
        # This threshold is the max value. The min value of 1.0 / std_thresh is also tested against
        std_thresh_max = float(parts[1].strip())
    except ValueError:
        logger.error(f"latent_threshold_warn_levels was set to '{thresholds_string}', "
                     "but the mean and std thresholds must both be numbers")
        return None
    if mean_thresh < 0.0:
        # A negative magnitude would flag every latent as out of tolerance
        logger.error("Expected mean threshold warning level to be >= 0.0")
        return None
    if std_thresh_max < 1.0:
        logger.error("Expected std threshold warning level to be >= 1.0. (Std values are "
                     "automatically checked against a lower bound of '1.0 / std threshold warning level')")
        return None
    return mean_thresh, 1.0 / std_thresh_max, std_thresh_max


def _load_unflipped_latent(image_info):
    """Load the unflipped latent array for image_info from its cached .npz file."""
    latent_name = f'latents_{image_info.bucket_reso[1] // 8}x{image_info.bucket_reso[0] // 8}.npy'
    with np.load(image_info.latents_npz) as latents:
        return latents[latent_name]  # Only the unflipped latent is used


def _latent_warn_score(mean, std, mean_thresh, std_thresh_min, std_thresh_max):
    """Return how far (mean, std) lie outside the tolerances; 0.0 means both checks passed."""
    # Distance of |mean| beyond the threshold, symmetric for positive and negative means
    mean_excess = max(abs(mean) - mean_thresh, 0.0)

    # log base 2 is not necessarily the ideal function, but hopefully it'll roughly
    # balance an out-of-threshold std against an out-of-threshold mean in terms
    # of magnitude
    if std > std_thresh_max:
        std_excess = math.log2(std / std_thresh_max)  # Out of tolerance (too large)
    elif std < std_thresh_min:
        # A constant latent has std == 0; treat it as maximally out of tolerance
        std_excess = math.log2(std_thresh_min / std) if std > 0.0 else math.inf
    else:
        std_excess = 0.0  # Passed std check
    return mean_excess + std_excess


def _show_latent_visualization(image_info, short_name):
    """Display the channel-averaged latent as a red (+'ve) / blue (-'ve) image until the window is closed."""
    import cv2

    latent = _load_unflipped_latent(image_info)

    # Average over the first axis (the latent channels) and clip to some reasonable range for the Flux AE.
    averaged_clipped = np.clip(np.mean(latent, axis=0), -1, 1)
    rgb = np.zeros((latent.shape[1], latent.shape[2], 3), dtype=np.uint8)

    # For negative values: Blue (fade from black to full blue)
    mask_neg = averaged_clipped < 0
    rgb[mask_neg, 2] = (np.abs(averaged_clipped[mask_neg]) * 255).astype(np.uint8)

    # For positive values: Red (fade from black to full red)
    mask_pos = averaged_clipped > 0
    rgb[mask_pos, 0] = (averaged_clipped[mask_pos] * 255).astype(np.uint8)

    # Scale up 8x both for clarity and to match the original image size
    scale_factor = 8
    scaled_rgb = cv2.resize(
        rgb,
        (rgb.shape[1] * scale_factor, rgb.shape[0] * scale_factor),  # (width, height)
        interpolation=cv2.INTER_NEAREST,
    )

    # Show the latent average image
    window_name = f"{short_name}: blue -'ve, red +'ve."
    cv2.imshow(window_name, cv2.cvtColor(scaled_rgb, cv2.COLOR_RGB2BGR))
    while True:  # Wait until window is closed or escape key is pressed
        key = cv2.waitKey(50) & 0xFF  # 50 ms poll keeps the UI responsive without spinning a core
        if key == 27 or cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
            break
    cv2.destroyAllWindows()
# endregion
def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
@@ -617,3 +748,20 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)
# Latent mean/std analysis tools
# The "skip" keyword in the help text must match what
# check_latent_means_and_stds_against_thresholds compares against ("disable").
parser.add_argument(
    "--latent_threshold_warn_levels",
    type=str,
    default="0.16,1.35",
    help='Flux may train better if the training latents have a mean of 0.0 and an std of 1.0. '
    'Set this parameter to "mean_thresh,std_thresh" to warn if tolerances are exceeded, or "disable" to skip checks. '
    'Mean is tested to be in [-mean_thresh..+mean_thresh] range, std in [1.0/std_thresh..std_thresh] range',
)
# Optional on-screen diagnostic for the worst out-of-threshold latent
parser.add_argument(
    "--latent_threshold_visualizer",
    action="store_true",
    help="If --latent_threshold_warn_levels detects at least one out-of-threshold latent, one of them is "
    "shown on screen with red/blue blocks to show +'ve / -'ve latent values. This can help to identify "
    "why this image has mean and std values that differ significantly from 0.0 and 1.0",
)

View File

@@ -753,6 +753,16 @@ class NetworkTrainer:
persistent_workers=args.persistent_data_loader_workers,
)
# Warn the user if any latents have mean values that are further than a threshold level away
# from 0.0, or that have standard deviations outside a threshold scale from 1.0.
# NOTE(review): the argument's default is the string "0.16,1.35", so this guard is true unless
# the attribute is explicitly set to None; the "disable" opt-out is handled inside the callee.
if args.latent_threshold_warn_levels is not None:
# (Flux only for now, but this could be updated to support e.g. SDXL or SD3)
# NOTE(review): nothing visible here restricts this call to Flux training runs — confirm that
# non-Flux trainers cannot reach this path, since the check reads Flux-style cached latents.
from library.flux_train_utils import check_latent_means_and_stds_against_thresholds
check_latent_means_and_stds_against_thresholds(
args.latent_threshold_warn_levels,
args.latent_threshold_visualizer,
train_dataset_group.image_data)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(