mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
Added first version of out-of-tolerance latent
std/mean detection code.
This commit is contained in:
@@ -550,6 +550,137 @@ def save_flux_model_on_epoch_end_or_stepwise(
|
||||
|
||||
# endregion
|
||||
|
||||
# region Latent analysis
|
||||
|
||||
def _parse_latent_thresholds(thresholds_string):
    """Parse the 'mean,std' warning-level setting.

    Returns a ``(mean_thresh, std_thresh_min, std_thresh_max)`` tuple, or ``None`` when
    the check is disabled or the setting is invalid (an error is logged in that case).
    """
    setting = thresholds_string.strip()

    # The command-line help documents "disabled"; "disable" is kept for backward compatibility.
    if setting.lower() in ("disable", "disabled"):
        return None

    format_error = (
        f"latent_threshold_warn_levels was set to '{thresholds_string}', "
        "Expected latent threshold warning string to either be in 'mean,std' format, or 'disable'"
    )

    # Thresholds should be in 'mean,std' format; split on comma
    parts = setting.split(',')
    if len(parts) != 2:
        logger.error(format_error)
        return None

    try:
        mean_thresh = float(parts[0])  # Magnitude: mean must lie in [-mean_thresh, +mean_thresh]
        std_thresh_max = float(parts[1])  # Upper bound; the lower bound is its reciprocal
    except ValueError:
        # Non-numeric input is a configuration mistake, not a reason to abort the run
        logger.error(format_error)
        return None

    # 'not x >= y' (rather than 'x < y') so that NaN thresholds are rejected as well
    if not mean_thresh >= 0.0:
        logger.error("Expected mean threshold warning level to be >= 0.0")
        return None

    if not std_thresh_max >= 1.0:
        logger.error("Expected std threshold warning level to be >= 1.0. (Std values are "
                     "automatically checked against a lower bound of '1.0 / std threshold warning level')")
        return None

    return mean_thresh, 1.0 / std_thresh_max, std_thresh_max


def _load_unflipped_latent(image_info):
    """Load the unflipped latent array from the image's cached .npz file."""
    # Key is built from bucket_reso[1] // 8 and bucket_reso[0] // 8
    # (presumably latent height x width - verify against the code that writes the cache)
    latent_name = f'latents_{image_info.bucket_reso[1] // 8}x{image_info.bucket_reso[0] // 8}.npy'
    with np.load(image_info.latents_npz) as latents:
        return latents[latent_name]


def _latent_mean_excess(mean, mean_thresh):
    """Return how far |mean| lies beyond mean_thresh (0.0 if within tolerance)."""
    if not math.isfinite(mean):
        return math.inf  # NaN/inf latents are always reported first
    # Symmetric about zero, so negative and positive means are ranked on the same scale
    return max(0.0, float(abs(mean) - mean_thresh))


def _latent_std_excess(std, std_thresh_min, std_thresh_max):
    """Return how far std lies outside [std_thresh_min, std_thresh_max] in log2 units (0.0 if inside)."""
    if not math.isfinite(std):
        return math.inf  # NaN/inf latents are always reported first

    # log base 2 is not necessarily the ideal function, but hopefully it'll roughly
    # balance an out-of-threshold std against an out-of-threshold mean in terms
    # of magnitude
    if std > std_thresh_max:
        return math.log2(std / std_thresh_max)  # Out of tolerance (too large)
    if std < std_thresh_min:
        if std <= 0.0:
            return math.inf  # Constant latent; avoids a division by zero
        return math.log2(std_thresh_min / std)  # Out of tolerance (too small)
    return 0.0  # Passed std check


def _show_latent_visualization(latent, image_label):
    """Show the channel-averaged latent in an OpenCV window (blue = negative, red = positive).

    Blocks until the window is closed or the escape key is pressed.
    """
    import cv2  # Local import: only needed when the visualizer is requested

    # Average over axis 0 (assumed to be the channel axis) and clip to some
    # reasonable range for the Flux AE.
    averaged_clipped = np.clip(np.mean(latent, axis=0), -1, 1)

    rgb = np.zeros((latent.shape[1], latent.shape[2], 3), dtype=np.uint8)

    # For negative values: Blue (fade from black to full blue)
    mask_neg = averaged_clipped < 0
    rgb[mask_neg, 2] = (np.abs(averaged_clipped[mask_neg]) * 255).astype(np.uint8)

    # For positive values: Red (fade from black to full red)
    mask_pos = averaged_clipped > 0
    rgb[mask_pos, 0] = (averaged_clipped[mask_pos] * 255).astype(np.uint8)

    # Scale up 8x both for clarity and to match the original image size
    scale_factor = 8
    scaled_rgb = cv2.resize(
        rgb,
        (rgb.shape[1] * scale_factor, rgb.shape[0] * scale_factor),  # (width, height)
        interpolation=cv2.INTER_NEAREST,
    )

    window_name = f"{image_label}: blue -'ve, red +'ve."
    try:
        cv2.imshow(window_name, cv2.cvtColor(scaled_rgb, cv2.COLOR_RGB2BGR))
        while True:  # Wait until window is closed or escape key is pressed
            key = cv2.waitKey(1) & 0xFF
            if key == 27 or cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
                break
    except cv2.error as e:
        # e.g. a headless OpenCV build; the visualizer is a diagnostic aid and must not abort training
        logger.warning(f"Could not display the latent visualizer window: {e}")
    finally:
        cv2.destroyAllWindows()


def check_latent_means_and_stds_against_thresholds(thresholds_string, latent_threshold_visualizer, image_data):
    """Warn about cached latents whose mean or std is outside the configured tolerances.

    Args:
        thresholds_string: Either 'mean_thresh,std_thresh' or 'disable'/'disabled'. The mean
            is tested against [-mean_thresh, +mean_thresh] and the std against
            [1.0 / std_thresh, std_thresh].
        latent_threshold_visualizer: If truthy and at least one latent is out of tolerance,
            the worst one is displayed in an OpenCV window (this blocks until it is closed).
        image_data: Mapping of image filename to an image-info object providing
            ``latents_npz`` (path of the cached latents file) and ``bucket_reso``.

    Returns:
        None. Findings are reported through the module logger; an invalid
        ``thresholds_string`` is logged as an error and the check is skipped.
    """
    import os  # Local import keeps this block self-contained

    max_reported = 3  # Number of worst offenders listed in the warning

    thresholds = _parse_latent_thresholds(thresholds_string)
    if thresholds is None:
        return  # Check disabled, or the setting was invalid (already logged)
    mean_thresh, std_thresh_min, std_thresh_max = thresholds

    if not image_data:
        return  # Nothing to check

    # One entry per latent: (score, mean, std, filename without path, full filename key).
    # The score is how notable this latent's mean and std is considered to be.
    results = []
    skipped = 0

    # Load and check each latent in turn
    logger.info('Checking latent means/stds:')
    for image_filename in tqdm(image_data):
        image_info = image_data[image_filename]

        # presumably latents_npz is unset when latents are not cached to disk - verify against caller
        if not image_info.latents_npz:
            skipped += 1
            continue

        latent = _load_unflipped_latent(image_info)  # Only checking the unflipped latent
        mean = np.average(latent)
        std = np.std(latent)
        score = _latent_mean_excess(mean, mean_thresh) + _latent_std_excess(std, std_thresh_min, std_thresh_max)

        # basename also strips Windows-style separators when running on Windows
        results.append((score, mean, std, os.path.basename(image_filename), image_filename))

    if skipped:
        logger.info(f"Skipped latent mean/std check for {skipped} image(s) with no cached latent file on disk")

    # Most notably out of threshold first; the sort is stable, so ties keep dataset order
    results.sort(key=lambda entry: entry[0], reverse=True)
    failures = [entry for entry in results[:max_reported] if entry[0] > 0.0]
    if not failures:
        return  # Every checked latent is within tolerance

    logger.warning("Images are being trained on that have out-of-tolerance latent mean or std values. "
                   "Training may improve if these images are changed/deleted. Remember to delete the images' _flux.npz "
                   "files by hand if you modify the image, as they will not necessarily be automatically regenerated. "
                   "Here is a list (of up to three) out-of-tolerance images: (Consider using --latent_threshold_visualizer "
                   "to diagnose)")
    for _, mean, std, filename_no_path, _ in failures:
        # Same logger as the header so the list cannot be separated from it by stream buffering
        logger.warning(f'Mean,std = [{mean:.3f}, {std:.3f}]: {filename_no_path}')

    # Show one latent test failure result visually in a window?
    if latent_threshold_visualizer:
        worst = failures[0]
        # Re-fetch the latent for the 'worst' test fail latent (latents are not kept in memory)
        worst_latent = _load_unflipped_latent(image_data[worst[4]])
        _show_latent_visualization(worst_latent, worst[3])
|
||||
|
||||
# endregion
|
||||
|
||||
def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
@@ -617,3 +748,20 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
)
|
||||
|
||||
# Latent mean/std analysis tools
|
||||
parser.add_argument(
|
||||
"--latent_threshold_warn_levels",
|
||||
type=str,
|
||||
default="0.16,1.35",
|
||||
help='Flux may train better if the training latents have a mean of 0.0 and an std of 1.0. '
|
||||
'Set this parameter to "mean_thresh,std_thresh" to warn if tolerances are exceeded, or "disabled" to skip checks. '
|
||||
'Mean is tested to be in [-mean_thresh..+mean_thresh] range, std in [1.0/std_thresh..std_thresh] range'
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_threshold_visualizer",
|
||||
action="store_true",
|
||||
help="If --latent_threshold_warn_levels detects at least one out-of-threshold latent, one of them is "
|
||||
"shown on screen with red/blue blocks to show +'ve / -'ve latent values. This can help to identify "
|
||||
"why this image has mean and std values that differ significantly from 0.0 and 1.0"
|
||||
)
|
||||
|
||||
@@ -753,6 +753,16 @@ class NetworkTrainer:
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# Warn user if any latents have mean values that are further than a threshold level away
|
||||
# from 0.0, or that have standard deviations outside a threshold scale from 1.0.
|
||||
if args.latent_threshold_warn_levels is not None:
|
||||
# (Flux only for now, but this could be updated to support e.g. SDXL or SD3)
|
||||
from library.flux_train_utils import check_latent_means_and_stds_against_thresholds
|
||||
check_latent_means_and_stds_against_thresholds(
|
||||
args.latent_threshold_warn_levels,
|
||||
args.latent_threshold_visualizer,
|
||||
train_dataset_group.image_data)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
|
||||
Reference in New Issue
Block a user