diff --git a/wuerstchen/wuerstchen_train.py b/wuerstchen/wuerstchen_train.py index 69505f43..2c9fd022 100644 --- a/wuerstchen/wuerstchen_train.py +++ b/wuerstchen/wuerstchen_train.py @@ -239,8 +239,8 @@ def train(args): current_epoch = Value("i", 0) current_step = Value("i", 0) - ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None - collator = train_util.collator_class(current_epoch, current_step, ds_for_collater) + ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) train_dataset_group.verify_bucket_reso_steps(32)