Fix band weights via toml. Add more logging

This commit is contained in:
rockerBOO
2025-04-11 20:39:31 -04:00
parent 20a99771bf
commit 7b9e92a8cc
2 changed files with 8 additions and 2 deletions

View File

@@ -4657,6 +4657,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
ignore_nesting_dict[section_name] = section_dict
continue
if section_name == "wavelet_loss_band_weights":
ignore_nesting_dict[section_name] = section_dict
continue
# if value is dict, save all key and value into one dict
for key, value in section_dict.items():
ignore_nesting_dict[key] = value

View File

@@ -1318,11 +1318,13 @@ class NetworkTrainer:
logger.info("Wavelet Loss:")
logger.info(f"\tLevel: {args.wavelet_loss_level}")
logger.info(f"\tAlpha: {args.wavelet_loss_alpha}")
logger.info(f"\tTransform: {args.wavelet_loss_transform}")
logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}")
if args.wavelet_loss_ll_level_threshold is not None:
logger.info(f"\tLL level threshold: {args.wavelet_loss_band_weights}")
logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}")
if args.wavelet_loss_band_weights is not None:
logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}")
logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}")
del train_dataset_group
if val_dataset_group is not None: