diff --git a/train_network.py b/train_network.py index 086b314a..cab0ec52 100644 --- a/train_network.py +++ b/train_network.py @@ -313,6 +313,7 @@ class NetworkTrainer: collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) if args.debug_dataset: + train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly train_util.debug_dataset(train_dataset_group) return if len(train_dataset_group) == 0: