diff --git a/library/train_util.py b/library/train_util.py index e152f30f..4bb2d6c1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4657,6 +4657,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar ignore_nesting_dict[section_name] = section_dict continue + if section_name == "wavelet_loss_band_weights": + ignore_nesting_dict[section_name] = section_dict + continue + # if value is dict, save all key and value into one dict for key, value in section_dict.items(): ignore_nesting_dict[key] = value diff --git a/train_network.py b/train_network.py index 68123a8e..7c881246 100644 --- a/train_network.py +++ b/train_network.py @@ -1318,11 +1318,13 @@ class NetworkTrainer: logger.info("Wavelet Loss:") logger.info(f"\tLevel: {args.wavelet_loss_level}") + logger.info(f"\tAlpha: {args.wavelet_loss_alpha}") + logger.info(f"\tTransform: {args.wavelet_loss_transform}") logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}") if args.wavelet_loss_ll_level_threshold is not None: - logger.info(f"\tLL level threshold: {args.wavelet_loss_band_weights}") + logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}") if args.wavelet_loss_band_weights is not None: - logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}") + logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}") del train_dataset_group if val_dataset_group is not None: