mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Fix band weights via toml. Add more logging
This commit is contained in:
@@ -4657,6 +4657,10 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
if section_name == "wavelet_loss_band_weights":
|
||||
ignore_nesting_dict[section_name] = section_dict
|
||||
continue
|
||||
|
||||
# if value is dict, save all key and value into one dict
|
||||
for key, value in section_dict.items():
|
||||
ignore_nesting_dict[key] = value
|
||||
|
||||
@@ -1318,11 +1318,13 @@ class NetworkTrainer:
|
||||
|
||||
logger.info("Wavelet Loss:")
|
||||
logger.info(f"\tLevel: {args.wavelet_loss_level}")
|
||||
logger.info(f"\tAlpha: {args.wavelet_loss_alpha}")
|
||||
logger.info(f"\tTransform: {args.wavelet_loss_transform}")
|
||||
logger.info(f"\tWavelet: {args.wavelet_loss_wavelet}")
|
||||
if args.wavelet_loss_ll_level_threshold is not None:
|
||||
logger.info(f"\tLL level threshold: {args.wavelet_loss_band_weights}")
|
||||
logger.info(f"\tLL level threshold: {args.wavelet_loss_ll_level_threshold}")
|
||||
if args.wavelet_loss_band_weights is not None:
|
||||
logger.info(f"\tBand Weights: {args.wavelet_loss_band_weights}")
|
||||
logger.info(f"\tBand weights: {args.wavelet_loss_band_weights}")
|
||||
|
||||
del train_dataset_group
|
||||
if val_dataset_group is not None:
|
||||
|
||||
Reference in New Issue
Block a user