From 7b9e92a8cc566a6ce51cb0b8efe9e0dca203fb6d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Fri, 11 Apr 2025 20:39:31 -0400 Subject: [PATCH] Fix band weights via toml. Add more logging --- library/train_util.py | 4 ++++ train_network.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e152f30f..4bb2d6c1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4657,6 +4657,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar ignore_nesting_dict[section_name] = section_dict continue + if section_name == "wavelet_loss_band_weights": + ignore_nesting_dict[section_name] = section_dict + continue + # if value is dict, save all key and value into one dict for key, value in section_dict.items(): ignore_nesting_dict[key] = value diff --git a/train_network.py b/train_network.py index 68123a8e..7c881246 100644 --- a/train_network.py +++ b/train_network.py @@ -1318,11 +1318,13 @@ class NetworkTrainer: logger.info("Wavelet Loss:") logger.info(f"\tLevel: {args.wavelet_loss_level}") + logger.info(f"\tAlpha: {args.wavelet_loss_alpha}") + logger.info(f"\tTransform: {args.wavelet_loss_transform}") logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}") if args.wavelet_loss_ll_level_threshold is not None: - logger.info(f"\tLL level threshold: {args.wavelet_loss_band_weights}") + logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}") if args.wavelet_loss_band_weights is not None: - logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}") + logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}") del train_dataset_group if val_dataset_group is not None: